/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/grappler_item.h"

#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

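// Builds the optimization options Grappler uses when optimizing function
// graphs for TF2 eager execution.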
GrapplerItem::OptimizationOptions CreateOptOptionsForEager() {
  GrapplerItem::OptimizationOptions optimization_options;
  // TensorFlow 2.0 in eager mode with automatic control dependencies prunes
  // all nodes that are not in the transitive fanin of the fetch nodes.
  // However, because the function will be executed via FunctionLibraryRuntime,
  // and the current function implementation does not prune stateful and
  // dataset ops, we rely on Grappler to do the correct graph pruning.
  optimization_options.allow_pruning_stateful_and_dataset_ops = true;

  optimization_options.is_eager_mode = true;

  // All nested function calls will be executed and optimized via
  // PartitionedCallOp, so there is no need to optimize the function library
  // now.
  optimization_options.optimize_function_library = false;

  return optimization_options;
}

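// Returns a copy of this item whose metadata (feeds, fetches, init ops,
// devices, optimization options, etc.) is copied from *this and whose graph
// is swapped in from 'graph_def'.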
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
  GrapplerItem item;
  item.id = id;
  item.feed = feed;
  item.fetch = fetch;
  item.init_ops = init_ops;
  item.keep_ops = keep_ops;
  item.expected_init_time = expected_init_time;
  item.save_op = save_op;
  item.restore_op = restore_op;
  item.save_restore_loc_tensor = save_restore_loc_tensor;
  item.queue_runners = queue_runners;
  item.devices_ = devices_;
  item.optimization_options_ = optimization_options_;
  item.graph.Swap(&graph_def);
  return item;
}

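// Returns the transitive fanin of the fetch nodes, i.e. every node the main
// computation depends on.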
std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
  std::vector<const NodeDef*> fanin_nodes;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes));
  return fanin_nodes;
}

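// Returns the transitive fanin of the queue runners' enqueue ops.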
std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
  std::vector<string> enqueue_ops;
  for (const auto& queue_runner : queue_runners) {
    for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
      enqueue_ops.push_back(enqueue_op);
    }
  }
  std::vector<const NodeDef*> fanin_nodes;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, enqueue_ops, &fanin_nodes));
  return fanin_nodes;
}

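// Returns the transitive fanin of the initialization ops.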
std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
  std::vector<const NodeDef*> fanin_nodes;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin_nodes));
  return fanin_nodes;
}

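// Returns the variable nodes in the transitive fanin of the initialization
// ops.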
std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
  std::vector<const NodeDef*> fanin;
  TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin));
  std::vector<const NodeDef*> vars;
  for (const NodeDef* node : fanin) {
    if (IsVariable(*node)) {
      vars.push_back(node);
    }
  }
  return vars;
}

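// Returns the names of the nodes that optimizers must keep: feeds, fetches,
// init/keep ops, save/restore ops, queue runner ops, stateful and dataset ops
// (when pruning them is not allowed), and nodes tagged with the
// _grappler_do_not_remove attribute.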
std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
  std::unordered_set<string> result;
  for (const string& f : fetch) {
    VLOG(1) << "Add fetch " << f;
    result.insert(NodeName(f));
  }
  for (const auto& f : feed) {
    VLOG(1) << "Add feed " << f.first;
    result.insert(NodeName(f.first));
  }
  for (const auto& node : init_ops) {
    result.insert(NodeName(node));
  }
  for (const auto& node : keep_ops) {
    result.insert(NodeName(node));
  }
  if (!save_op.empty()) {
    result.insert(NodeName(save_op));
  }
  if (!restore_op.empty()) {
    result.insert(NodeName(restore_op));
  }
  if (!save_restore_loc_tensor.empty()) {
    result.insert(NodeName(save_restore_loc_tensor));
  }

  for (const auto& queue_runner : queue_runners) {
    for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
      result.insert(NodeName(enqueue_op));
    }
    if (!queue_runner.close_op_name().empty()) {
      result.insert(NodeName(queue_runner.close_op_name()));
    }
    if (!queue_runner.cancel_op_name().empty()) {
      result.insert(NodeName(queue_runner.cancel_op_name()));
    }
  }

  absl::optional<FunctionLibraryDefinition> fn_library;
  if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
    fn_library.emplace(OpRegistry::Global(), graph.library());
  }
  for (const NodeDef& node : graph.node()) {
    const auto attrs = AttrSlice(&node.attr());

    // TensorFlow functions do not prune stateful or dataset-output ops from
    // the function body (see PruneFunctionBody in common_runtime/function.cc).
    if (!optimization_options_.allow_pruning_stateful_and_dataset_ops &&
        (IsStateful(node, &*fn_library) || IsDataset(node))) {
      result.insert(node.name());
    }

    // Do not remove ops with the attribute _grappler_do_not_remove. This is
    // useful for debugging.
    bool do_not_remove;
    if (TryGetNodeAttr(attrs, "_grappler_do_not_remove", &do_not_remove) &&
        do_not_remove) {
      result.insert(node.name());
    }
  }

  return result;
}

const std::unordered_set<string>& GrapplerItem::devices() const {
  return devices_;
}

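// Adds a device to the set of devices available to this item. The name must
// be fully defined (job, replica, task, type and id), e.g.
// "/job:localhost/replica:0/task:0/device:CPU:0"; anything else is rejected
// with InvalidArgument.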
Status GrapplerItem::AddDevice(const string& device) {
  DeviceNameUtils::ParsedName name;

  if (!DeviceNameUtils::ParseFullName(device, &name)) {
    return errors::InvalidArgument("Invalid device name: device=", device);

  } else if (!name.has_job || !name.has_replica || !name.has_task ||
             !name.has_type || !name.has_id) {
    return errors::InvalidArgument("Not a fully defined device name: device=",
                                   device);
  }

  devices_.insert(DeviceNameUtils::ParsedNameToString(name));
  return OkStatus();
}

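// Adds every valid device from 'other'; returns InvalidArgument listing the
// device names that could not be added.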
Status GrapplerItem::AddDevices(const GrapplerItem& other) {
  std::vector<absl::string_view> invalid_devices;
  for (const string& device : other.devices()) {
    Status added = AddDevice(device);
    if (!added.ok()) invalid_devices.emplace_back(device);
  }
  return invalid_devices.empty()
             ? OkStatus()
             : errors::InvalidArgument("Skipped invalid devices: [",
                                       absl::StrJoin(invalid_devices, ", "),
                                       "]");
}

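// Populates the device set from the devices assigned to the nodes of the
// graph, skipping and reporting names that are invalid or not fully defined.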
Status GrapplerItem::InferDevicesFromGraph() {
  absl::flat_hash_set<absl::string_view> invalid_devices;
  for (const NodeDef& node : graph.node()) {
    Status added = AddDevice(node.device());
    if (!added.ok()) invalid_devices.insert(node.device());
  }
  VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
  return invalid_devices.empty()
             ? OkStatus()
             : errors::InvalidArgument("Skipped invalid devices: [",
                                       absl::StrJoin(invalid_devices, ", "),
                                       "]");
}

void GrapplerItem::ClearDevices() { devices_.clear(); }

const GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options()
    const {
  return optimization_options_;
}

GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
  return optimization_options_;
}

}  // end namespace grappler
}  // end namespace tensorflow