1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/grappler/grappler_item.h" |
17 | |
18 | #include <unordered_map> |
19 | #include <unordered_set> |
20 | #include <vector> |
21 | |
22 | #include "absl/container/flat_hash_set.h" |
23 | #include "absl/strings/str_join.h" |
24 | #include "tensorflow/core/framework/attr_value.pb.h" |
25 | #include "tensorflow/core/framework/node_def.pb.h" |
26 | #include "tensorflow/core/grappler/op_types.h" |
27 | #include "tensorflow/core/grappler/utils.h" |
28 | #include "tensorflow/core/grappler/utils/transitive_fanin.h" |
29 | #include "tensorflow/core/util/device_name_utils.h" |
30 | |
31 | namespace tensorflow { |
32 | namespace grappler { |
33 | |
34 | GrapplerItem::OptimizationOptions CreateOptOptionsForEager() { |
35 | GrapplerItem::OptimizationOptions optimization_options; |
36 | // Tensorflow 2.0 in eager mode with automatic control dependencies will |
37 | // prune all nodes that are not in the transitive fanin of the fetch nodes. |
38 | // However because the function will be executed via FunctionLibraryRuntime, |
39 | // and current function implementation does not prune stateful and dataset |
40 | // ops, we rely on Grappler to do the correct graph pruning. |
41 | optimization_options.allow_pruning_stateful_and_dataset_ops = true; |
42 | |
43 | optimization_options.is_eager_mode = true; |
44 | |
45 | // All the nested function calls will be executed and optimized via |
46 | // PartitionedCallOp, there is no need to optimize functions now. |
47 | optimization_options.optimize_function_library = false; |
48 | |
49 | return optimization_options; |
50 | } |
51 | |
// Returns a copy of this item with its graph replaced by `graph_def`.
// All metadata (feeds, fetches, init/keep ops, save/restore info, queue
// runners, devices, optimization options) is copied from *this; only the
// graph itself differs. `graph_def` is consumed (swapped out), which avoids
// a deep copy of the proto; the moved-from argument is left holding the old
// (empty) graph of `item`.
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
  GrapplerItem item;
  item.id = id;
  item.feed = feed;
  item.fetch = fetch;
  item.init_ops = init_ops;
  item.keep_ops = keep_ops;
  item.expected_init_time = expected_init_time;
  item.save_op = save_op;
  item.restore_op = restore_op;
  item.save_restore_loc_tensor = save_restore_loc_tensor;
  item.queue_runners = queue_runners;
  item.devices_ = devices_;
  item.optimization_options_ = optimization_options_;
  // Swap instead of copy: takes ownership of the caller's graph cheaply.
  item.graph.Swap(&graph_def);
  return item;
}
69 | |
70 | std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const { |
71 | std::vector<const NodeDef*> fanin_nodes; |
72 | TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes)); |
73 | return fanin_nodes; |
74 | } |
75 | |
76 | std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const { |
77 | std::vector<string> enqueue_ops; |
78 | for (const auto& queue_runner : queue_runners) { |
79 | for (const string& enqueue_op : queue_runner.enqueue_op_name()) { |
80 | enqueue_ops.push_back(enqueue_op); |
81 | } |
82 | } |
83 | std::vector<const NodeDef*> fanin_nodes; |
84 | TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes)); |
85 | return fanin_nodes; |
86 | } |
87 | |
88 | std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const { |
89 | std::vector<const NodeDef*> fanin_nodes; |
90 | TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin_nodes)); |
91 | return fanin_nodes; |
92 | } |
93 | |
94 | std::vector<const NodeDef*> GrapplerItem::MainVariables() const { |
95 | std::vector<const NodeDef*> fanin; |
96 | TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin)); |
97 | std::vector<const NodeDef*> vars; |
98 | for (const NodeDef* node : fanin) { |
99 | if (IsVariable(*node)) { |
100 | vars.push_back(node); |
101 | } |
102 | } |
103 | return vars; |
104 | } |
105 | |
106 | std::unordered_set<string> GrapplerItem::NodesToPreserve() const { |
107 | std::unordered_set<string> result; |
108 | for (const string& f : fetch) { |
109 | VLOG(1) << "Add fetch " << f; |
110 | result.insert(NodeName(f)); |
111 | } |
112 | for (const auto& f : feed) { |
113 | VLOG(1) << "Add feed " << f.first; |
114 | result.insert(NodeName(f.first)); |
115 | } |
116 | for (const auto& node : init_ops) { |
117 | result.insert(NodeName(node)); |
118 | } |
119 | for (const auto& node : keep_ops) { |
120 | result.insert(NodeName(node)); |
121 | } |
122 | if (!save_op.empty()) { |
123 | result.insert(NodeName(save_op)); |
124 | } |
125 | if (!restore_op.empty()) { |
126 | result.insert(NodeName(restore_op)); |
127 | } |
128 | if (!save_restore_loc_tensor.empty()) { |
129 | result.insert(NodeName(save_restore_loc_tensor)); |
130 | } |
131 | |
132 | for (const auto& queue_runner : queue_runners) { |
133 | for (const string& enqueue_op : queue_runner.enqueue_op_name()) { |
134 | result.insert(NodeName(enqueue_op)); |
135 | } |
136 | if (!queue_runner.close_op_name().empty()) { |
137 | result.insert(NodeName(queue_runner.close_op_name())); |
138 | } |
139 | if (!queue_runner.cancel_op_name().empty()) { |
140 | result.insert(NodeName(queue_runner.cancel_op_name())); |
141 | } |
142 | } |
143 | |
144 | absl::optional<FunctionLibraryDefinition> fn_library; |
145 | if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) { |
146 | fn_library.emplace(OpRegistry::Global(), graph.library()); |
147 | } |
148 | for (const NodeDef& node : graph.node()) { |
149 | const auto attrs = AttrSlice(&node.attr()); |
150 | |
151 | // Tensorflow functions do not prune stateful or dataset-output ops from |
152 | // the function body (see PruneFunctionBody in common_runtime/function.cc). |
153 | if (!optimization_options_.allow_pruning_stateful_and_dataset_ops && |
154 | (IsStateful(node, &*fn_library) || IsDataset(node))) { |
155 | result.insert(node.name()); |
156 | } |
157 | |
158 | // Do not remove ops with attribute _grappler_do_not_remove. This is useful |
159 | // for debugging. |
160 | bool do_not_remove; |
161 | if (TryGetNodeAttr(attrs, "_grappler_do_not_remove" , &do_not_remove) && |
162 | do_not_remove) { |
163 | result.insert(node.name()); |
164 | } |
165 | } |
166 | |
167 | return result; |
168 | } |
169 | |
// Returns the set of fully-qualified device names known to this item
// (populated via AddDevice/AddDevices/InferDevicesFromGraph).
const std::unordered_set<string>& GrapplerItem::devices() const {
  return devices_;
}
173 | |
174 | Status GrapplerItem::AddDevice(const string& device) { |
175 | DeviceNameUtils::ParsedName name; |
176 | |
177 | if (!DeviceNameUtils::ParseFullName(device, &name)) { |
178 | return errors::InvalidArgument("Invalid device name: device=" , device); |
179 | |
180 | } else if (!name.has_job || !name.has_replica || !name.has_task || |
181 | !name.has_type || !name.has_id) { |
182 | return errors::InvalidArgument("Not a fully defined device name: device=" , |
183 | device); |
184 | } |
185 | |
186 | devices_.insert(DeviceNameUtils::ParsedNameToString(name)); |
187 | return OkStatus(); |
188 | } |
189 | |
190 | Status GrapplerItem::AddDevices(const GrapplerItem& other) { |
191 | std::vector<absl::string_view> invalid_devices; |
192 | for (const string& device : other.devices()) { |
193 | Status added = AddDevice(device); |
194 | if (!added.ok()) invalid_devices.emplace_back(device); |
195 | } |
196 | return invalid_devices.empty() |
197 | ? OkStatus() |
198 | : errors::InvalidArgument("Skipped invalid devices: [" , |
199 | absl::StrJoin(invalid_devices, ", " ), |
200 | "]" ); |
201 | } |
202 | |
203 | Status GrapplerItem::InferDevicesFromGraph() { |
204 | absl::flat_hash_set<absl::string_view> invalid_devices; |
205 | for (const NodeDef& node : graph.node()) { |
206 | Status added = AddDevice(node.device()); |
207 | if (!added.ok()) invalid_devices.insert(node.device()); |
208 | } |
209 | VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", " ) << "]" ; |
210 | return invalid_devices.empty() |
211 | ? OkStatus() |
212 | : errors::InvalidArgument("Skipped invalid devices: [" , |
213 | absl::StrJoin(invalid_devices, ", " ), |
214 | "]" ); |
215 | } |
216 | |
// Removes every device previously added or inferred.
void GrapplerItem::ClearDevices() { devices_.clear(); }
218 | |
// Read-only accessor for this item's optimization options.
const GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options()
    const {
  return optimization_options_;
}
223 | |
// Mutable accessor for this item's optimization options.
GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
  return optimization_options_;
}
227 | |
228 | } // end namespace grappler |
229 | } // end namespace tensorflow |
230 | |