1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/graph_execution_state.h" |
17 | |
18 | #include <memory> |
19 | #include <set> |
20 | #include <string> |
21 | #include <unordered_set> |
22 | #include <utility> |
23 | #include <vector> |
24 | |
25 | #include "absl/container/flat_hash_set.h" |
26 | #include "absl/memory/memory.h" |
27 | #include "absl/strings/str_join.h" |
28 | #include "tensorflow/core/common_runtime/device.h" |
29 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
30 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
31 | #include "tensorflow/core/common_runtime/placer.h" |
32 | #include "tensorflow/core/framework/attr_value.pb.h" |
33 | #include "tensorflow/core/framework/device_factory.h" |
34 | #include "tensorflow/core/framework/function.h" |
35 | #include "tensorflow/core/framework/function.pb.h" |
36 | #include "tensorflow/core/framework/graph.pb.h" |
37 | #include "tensorflow/core/framework/graph_def_util.h" |
38 | #include "tensorflow/core/framework/metrics.h" |
39 | #include "tensorflow/core/framework/node_def.pb.h" |
40 | #include "tensorflow/core/framework/op.h" |
41 | #include "tensorflow/core/framework/tensor.pb.h" |
42 | #include "tensorflow/core/framework/versions.pb.h" |
43 | #include "tensorflow/core/graph/algorithm.h" |
44 | #include "tensorflow/core/graph/collective_order.h" |
45 | #include "tensorflow/core/graph/graph.h" |
46 | #include "tensorflow/core/graph/subgraph.h" |
47 | #include "tensorflow/core/graph/tensor_id.h" |
48 | #include "tensorflow/core/graph/validate.h" |
49 | #include "tensorflow/core/lib/core/errors.h" |
50 | #include "tensorflow/core/lib/core/status.h" |
51 | #include "tensorflow/core/lib/gtl/flatset.h" |
52 | #include "tensorflow/core/lib/strings/stringprintf.h" |
53 | #include "tensorflow/core/platform/logging.h" |
54 | #include "tensorflow/core/platform/types.h" |
55 | #include "tensorflow/core/util/device_name_utils.h" |
56 | #include "tensorflow/core/util/util.h" |
57 | |
58 | #ifndef IS_MOBILE_PLATFORM |
59 | #include "tensorflow/core/grappler/clusters/virtual_cluster.h" |
60 | #include "tensorflow/core/grappler/grappler_item.h" |
61 | #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" |
62 | #endif // IS_MOBILE_PLATFORM |
63 | |
64 | namespace tensorflow { |
65 | |
66 | namespace { |
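// Returns true if `op` is one of the "V2" collective ops, which receive
// their group and instance configuration as runtime inputs rather than as
// node attributes (so no `instance_key` attr can be read off such nodes).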
67 | bool IsCollectiveV2(const string& op) { |
68 | return op == "CollectiveReduceV2" || op == "CollectiveGatherV2" || |
69 | op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2" ; |
70 | } |
71 | } // namespace |
72 | |
73 | GraphExecutionState::GraphExecutionState( |
74 | std::unique_ptr<GraphDef>&& graph_def, |
75 | std::unique_ptr<FunctionLibraryDefinition>&& flib_def, |
76 | const GraphExecutionStateOptions& options) |
77 | : stateful_placements_(options.stateful_placements), |
78 | original_graph_def_(std::move(graph_def)), |
79 | device_set_(options.device_set), |
80 | session_options_(options.session_options), |
81 | session_handle_(options.session_handle), |
82 | flib_def_(std::move(flib_def)), |
83 | graph_(nullptr) {} |
84 | |
85 | GraphExecutionState::~GraphExecutionState() { |
86 | node_name_to_cost_id_map_.clear(); |
87 | delete graph_; |
88 | } |
89 | |
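// A minimal usage sketch for building a base execution state; `device_set`
// and `session_options` are assumed to be supplied by the caller (they are
// not defined in this file):
//
//   GraphExecutionStateOptions opts;
//   opts.device_set = &device_set;            // not owned
//   opts.session_options = &session_options;  // not owned
//   std::unique_ptr<GraphExecutionState> state;
//   TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
//       std::move(graph_def), opts, &state));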
90 | /* static */ Status GraphExecutionState::MakeForBaseGraph( |
91 | GraphDef&& graph_def, const GraphExecutionStateOptions& options, |
92 | std::unique_ptr<GraphExecutionState>* out_state) { |
93 | #ifndef __ANDROID__ |
94 | VLOG(4) << "Graph proto is \n" << graph_def.DebugString(); |
95 | #endif // __ANDROID__ |
96 | |
97 | auto flib_def = std::make_unique<FunctionLibraryDefinition>( |
98 | OpRegistry::Global(), graph_def.library()); |
99 | |
100 | TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&graph_def, *flib_def, 0)); |
101 | |
102 | if (options.session_options->config.graph_options().place_pruned_graph() || |
103 | !options.session_options->config.experimental() |
104 | .optimize_for_static_graph()) { |
105 | auto ret = absl::WrapUnique(new GraphExecutionState( |
106 | std::make_unique<GraphDef>(std::move(graph_def)), std::move(flib_def), |
107 | options)); |
108 | |
109 | // When place_pruned_graph is true, a different Graph* will be initialized |
110 | // each time we prune the original graph, so there is no need to |
111 | // construct a Graph* in this case. |
112 | if (!options.session_options->config.graph_options().place_pruned_graph()) { |
113 | auto base_graph = std::make_unique<Graph>(OpRegistry::Global()); |
114 | TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *ret->original_graph_def_, |
115 | base_graph.get())); |
116 | TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph))); |
117 | } |
118 | *out_state = std::move(ret); |
119 | } else { |
120 | auto ret = absl::WrapUnique( |
121 | new GraphExecutionState(nullptr, std::move(flib_def), options)); |
122 | auto base_graph = std::make_unique<Graph>(OpRegistry::Global()); |
123 | TF_RETURN_IF_ERROR( |
124 | ConvertGraphDefToGraph({}, std::move(graph_def), base_graph.get())); |
125 | TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph))); |
126 | *out_state = std::move(ret); |
127 | } |
128 | return OkStatus(); |
129 | } |
130 | |
131 | /* static */ Status GraphExecutionState::MakeForPrunedGraph( |
132 | const GraphExecutionState& base_execution_state, |
133 | const GraphExecutionStateOptions& options, |
134 | const BuildGraphOptions& subgraph_options, |
135 | std::unique_ptr<GraphExecutionState>* out_state, |
136 | std::unique_ptr<ClientGraph>* out_client_graph) { |
137 | if (!(base_execution_state.session_options_->config.graph_options() |
138 | .place_pruned_graph() && |
139 | options.session_options->config.graph_options().place_pruned_graph())) { |
    return errors::Internal(
        "MakeForPrunedGraph is only supported when the `place_pruned_graph` "
        "option is true.");
143 | } |
144 | if (!base_execution_state.original_graph_def_) { |
145 | // NOTE(mrry): By adding this restriction, which matches the only current |
146 | // usage of this (fairly obscure) method, we do not need to store a |
147 | // redundant copy of the original graph in `*out_state`. |
    return errors::Internal(
        "MakeForPrunedGraph is only supported when `base_execution_state` is "
        "the Session-level `GraphExecutionState`.");
151 | } |
152 | |
153 | // NOTE(mrry): This makes a copy of `graph_def`, which is |
154 | // regrettable. We could make `GraphDef` objects shareable between |
155 | // execution states to optimize pruned graph execution, but since |
156 | // this case is primarily used for interactive sessions, we make the |
157 | // bet that graph construction is not performance-critical. (Note |
158 | // also that the previous version used `Extend()`, which is strictly |
159 | // more expensive than copying a `GraphDef`.) |
160 | GraphDef temp(*base_execution_state.original_graph_def_); |
161 | auto flib_def = std::make_unique<FunctionLibraryDefinition>( |
162 | OpRegistry::Global(), temp.library()); |
163 | TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&temp, *flib_def, 0)); |
164 | auto ret = absl::WrapUnique( |
165 | new GraphExecutionState(nullptr, std::move(flib_def), options)); |
166 | |
167 | auto base_graph = std::make_unique<Graph>(OpRegistry::Global()); |
168 | TF_RETURN_IF_ERROR( |
169 | ConvertGraphDefToGraph({}, std::move(temp), base_graph.get())); |
170 | |
171 | // Rewrite the graph before placement. |
172 | ret->rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata); |
173 | TF_RETURN_IF_ERROR(ret->PruneGraph(subgraph_options, base_graph.get(), |
174 | ret->rewrite_metadata_.get())); |
175 | TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph))); |
176 | TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph)); |
177 | *out_state = std::move(ret); |
178 | return OkStatus(); |
179 | } |
180 | |
181 | Status GraphExecutionState::Extend( |
182 | const GraphDef& extension_def, |
183 | std::unique_ptr<GraphExecutionState>* out) const { |
184 | if (session_options_->config.experimental().optimize_for_static_graph()) { |
    return errors::FailedPrecondition(
        "Extending the graph is not supported when "
        "`optimize_for_static_graph` is true.");
188 | } |
189 | |
190 | GraphDef gdef; |
191 | |
  // 1. Merge the extension's function library into the session's and copy
  // the merged result.
193 | TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library())); |
194 | *gdef.mutable_library() = flib_def_->ToProto(); |
195 | |
196 | // 2. Build an index of the new node names. |
197 | std::unordered_set<string> new_names; |
198 | for (const NodeDef& node : extension_def.node()) { |
199 | new_names.insert(node.name()); |
200 | } |
201 | |
202 | // 3. Add the non-duplicates from the old graph to the new graph. |
203 | // Return an error if the same node name appears in both the |
204 | // old graph and the extension. |
205 | for (const NodeDef& node : original_graph_def_->node()) { |
206 | if (new_names.count(node.name()) == 0) { |
207 | *gdef.add_node() = node; |
208 | } else { |
      return errors::InvalidArgument(
          "GraphDef argument to Extend includes node '", node.name(),
          "', which was created by a previous call to Create or Extend in this "
          "session.");
213 | } |
214 | } |
215 | |
  // 4. Add the extension's nodes and merge the versions field.
217 | int old_node_size = gdef.node_size(); |
218 | gdef.mutable_node()->MergeFrom(extension_def.node()); |
219 | TF_RETURN_IF_ERROR( |
220 | AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size)); |
221 | // Merge versions |
222 | if (gdef.has_versions()) { |
223 | if (gdef.versions().producer() != extension_def.versions().producer()) { |
      return errors::InvalidArgument(
          "Can't extend GraphDef at version ", gdef.versions().producer(),
          " with graph at version ", extension_def.versions().producer());
227 | } |
228 | VersionDef* versions = gdef.mutable_versions(); |
229 | versions->set_min_consumer(std::max( |
230 | versions->min_consumer(), extension_def.versions().min_consumer())); |
231 | if (extension_def.versions().bad_consumers_size()) { |
232 | // Add new bad_consumers that aren't already marked bad. |
233 | // |
      // Note: This implementation is quadratic time if there are many calls
      // to Extend() with many bad consumers. Since this is unlikely, and
      // fixing it would require data structures outside of this routine,
      // quadratic time it is.
238 | auto* bad_consumers = versions->mutable_bad_consumers(); |
239 | const std::unordered_set<int> existing(bad_consumers->begin(), |
240 | bad_consumers->end()); |
241 | for (const int v : extension_def.versions().bad_consumers()) { |
242 | if (existing.find(v) == existing.end()) { |
243 | bad_consumers->Add(v); |
244 | } |
245 | } |
246 | } |
  } else {
249 | gdef.mutable_versions()->CopyFrom(extension_def.versions()); |
250 | } |
251 | |
252 | // 5. Validate that the final graphdef is valid. |
253 | if (gdef.versions().producer() >= 5) { |
254 | // Validate the graph: we assume that merging two valid graphs |
255 | // should maintain graph validity. |
256 | TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_)); |
257 | } |
258 | |
259 | // 6. Add the extension. |
260 | GraphExecutionStateOptions combined_options; |
261 | combined_options.device_set = device_set_; |
262 | combined_options.session_options = session_options_; |
263 | combined_options.session_handle = session_handle_; |
264 | combined_options.stateful_placements = stateful_placements_; |
265 | |
266 | TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *flib_def_, 0)); |
267 | auto flib_def = std::make_unique<FunctionLibraryDefinition>( |
268 | OpRegistry::Global(), gdef.library()); |
269 | auto new_execution_state = absl::WrapUnique( |
270 | new GraphExecutionState(std::make_unique<GraphDef>(std::move(gdef)), |
271 | std::move(flib_def), combined_options)); |
272 | |
273 | if (!session_options_->config.graph_options().place_pruned_graph()) { |
274 | auto base_graph = std::make_unique<Graph>(OpRegistry::Global()); |
275 | TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( |
276 | {}, *new_execution_state->original_graph_def_, base_graph.get())); |
277 | TF_RETURN_IF_ERROR( |
278 | new_execution_state->InitBaseGraph(std::move(base_graph))); |
279 | } |
280 | *out = std::move(new_execution_state); |
281 | |
282 | // NOTE(mrry): Extend() is likely to be used for non-throughput-sensitive |
283 | // interactive workloads, but in future we may want to transfer other |
284 | // parts of the placement and/or cost model. |
285 | return OkStatus(); |
286 | } |
287 | |
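// Records the assigned device of every stateful node in `graph`, so that a
// later placement (e.g. after Extend()) keeps ops such as variables and
// queues on the devices they were originally assigned to.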
288 | void GraphExecutionState::SaveStatefulNodes(Graph* graph) { |
289 | for (Node* n : graph->nodes()) { |
290 | if (n->op_def().is_stateful()) { |
291 | VLOG(2) << "Saving " << n->DebugString(); |
292 | stateful_placements_[n->name()] = n->assigned_device_name(); |
293 | } |
294 | } |
295 | } |
296 | |
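// Applies the placements captured by SaveStatefulNodes() to the matching
// stateful nodes in `graph`, pinning them before the Placer runs again.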
297 | void GraphExecutionState::RestoreStatefulNodes(Graph* graph) { |
298 | for (Node* n : graph->nodes()) { |
299 | if (n->op_def().is_stateful()) { |
300 | auto iter = stateful_placements_.find(n->name()); |
301 | if (iter != stateful_placements_.end()) { |
302 | n->set_assigned_device_name(iter->second); |
303 | VLOG(2) << "Restored " << n->DebugString(); |
304 | } |
305 | } |
306 | } |
307 | } |
308 | |
309 | namespace { |
310 | |
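// Implements a feed that is satisfied by a TensorConnection: rather than
// materializing a fed tensor, it forwards `from_tensor` through a freshly
// created Identity node on the fed node's device, first verifying that the
// connection would not introduce a cycle.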
311 | class TensorConnectionPruneRewrite : public subgraph::PruneRewrite { |
312 | public: |
313 | TensorConnectionPruneRewrite(const string* endpoint_name, |
314 | NodeBuilder::NodeOut from_tensor) |
315 | : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */), |
316 | from_tensor_(std::move(from_tensor)) {} |
317 | |
318 | Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, |
319 | Node** out_node) override { |
320 | Status s; |
321 | auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) { |
322 | if (n == feed_tensor.node) { |
        s.Update(errors::InvalidArgument(
            "Requested Tensor connection between nodes \"",
            feed_tensor.node->name(), "\" and \"", from_tensor_.node->name(),
            "\" would create a cycle."));
327 | } |
328 | }; |
329 | ReverseDFSFrom(*g, {from_tensor_.node}, std::move(check_no_cycle_fn), |
330 | nullptr); |
331 | TF_RETURN_IF_ERROR(s); |
332 | |
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_identity_", feed_tensor.node->name(),
                                    "_", feed_tensor.index),
                    "Identity")
            .Input(from_tensor_)
            .Attr("T",
339 | BaseType(from_tensor_.node->output_type(from_tensor_.index))) |
340 | .Finalize(g, out_node)); |
341 | |
342 | (*out_node)->set_assigned_device_name( |
343 | feed_tensor.node->assigned_device_name()); |
344 | return OkStatus(); |
345 | } |
346 | |
347 | private: |
348 | NodeBuilder::NodeOut from_tensor_; |
349 | }; |
350 | |
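// Resolves the device on which `tensor_name` should be fed or fetched,
// falling back to the client device when `tensor2device` contains no entry
// for the tensor.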
351 | template <class Map> |
352 | Status LookupDevice(const DeviceSet& device_set, const string& tensor_name, |
353 | const Map& tensor2device, |
354 | const tensorflow::DeviceAttributes** out_device_attrs) { |
355 | *out_device_attrs = nullptr; |
356 | if (tensor2device.empty()) { |
357 | *out_device_attrs = &device_set.client_device()->attributes(); |
358 | return OkStatus(); |
359 | } |
360 | const auto it = tensor2device.find(tensor_name); |
361 | if (it == tensor2device.end()) { |
362 | *out_device_attrs = &device_set.client_device()->attributes(); |
363 | return OkStatus(); |
364 | } |
365 | DeviceNameUtils::ParsedName parsed_name; |
366 | if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) { |
367 | return errors::InvalidArgument("Invalid device name ('" , it->second, |
368 | "') provided for the tensor '" , tensor_name, |
369 | "' in CallableOptions" ); |
370 | } |
371 | Device* device = device_set.FindDeviceByName( |
372 | DeviceNameUtils::ParsedNameToString(parsed_name)); |
373 | if (device == nullptr) { |
374 | return errors::InvalidArgument("Device '" , it->second, |
375 | "' specified for tensor '" , tensor_name, |
376 | "' in CallableOptions does not exist" ); |
377 | } |
378 | *out_device_attrs = &device->attributes(); |
379 | return OkStatus(); |
380 | } |
381 | |
382 | struct TensorAndDevice { |
  // WARNING: backing memory for the 'tensor' field is NOT owned.
384 | const TensorId tensor; |
385 | // WARNING: device pointer is not owned, so must outlive TensorAndDevice. |
386 | const DeviceAttributes* device; |
387 | }; |
388 | |
// Tensors of some DataTypes cannot be placed in device memory as feeds or
// fetches. Validate against an allowlist of those known to work.
391 | bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) { |
392 | // The mechanism for supporting feeds of device-backed Tensors requires |
393 | // the _Arg kernel to be registered for the corresponding type (and that |
394 | // the input to the kernel be in device and not host memory). |
395 | // |
396 | // The mechanism for supporting fetches of device-backed Tensors requires |
397 | // the _Retval kernel to be registered for the corresponding type (and |
398 | // that the output is produced in device and not host memory). |
399 | // |
400 | // For now, we return true iff there are _Arg AND _Retval kernels for dtype on |
401 | // the device. False negatives are okay, false positives would be bad. |
402 | // |
  // TODO(ashankar): Instead of an allowlist here, perhaps we could query
404 | // the kernel registry for _Arg and _Retval kernels instead. |
405 | if (device_type == DEVICE_CPU) return true; |
406 | if (device_type != DEVICE_GPU && |
407 | !DeviceFactory::IsPluggableDevice(device_type)) |
408 | return false; |
409 | switch (dtype) { |
410 | case DT_BFLOAT16: |
411 | case DT_BOOL: |
412 | case DT_COMPLEX128: |
413 | case DT_COMPLEX64: |
414 | case DT_DOUBLE: |
415 | case DT_FLOAT: |
416 | case DT_HALF: |
417 | case DT_INT16: |
418 | case DT_INT64: |
419 | case DT_INT8: |
420 | case DT_UINT16: |
421 | case DT_UINT8: |
422 | return true; |
423 | default: |
424 | return false; |
425 | } |
426 | } |
427 | |
428 | Status ValidateFeedAndFetchDevices( |
429 | const Graph& graph, |
430 | const std::vector<TensorAndDevice>& tensors_and_devices) { |
431 | if (tensors_and_devices.empty()) return OkStatus(); |
432 | std::vector<bool> found(tensors_and_devices.size(), false); |
433 | for (const Node* node : graph.nodes()) { |
434 | // Linearly looping through all nodes and then all feed+fetch tensors isn't |
435 | // quite efficient. At the time of this writing, the expectation was that |
436 | // tensors_and_devices.size() is really small in practice, so this won't be |
437 | // problematic. |
    // Revisit and make a more efficient lookup possible if needed (e.g., perhaps
439 | // Graph can maintain a map from node name to Node*). |
440 | for (int i = 0; i < tensors_and_devices.size(); ++i) { |
441 | const TensorAndDevice& td = tensors_and_devices[i]; |
442 | if (td.tensor.first != node->name()) continue; |
443 | found[i] = true; |
444 | TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second)); |
445 | const DataType dtype = node->output_type(td.tensor.second); |
446 | if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) { |
        return errors::Unimplemented(
            "Cannot feed or fetch tensor '", td.tensor.ToString(),
            "' from device ", td.device->name(), " as feeding/fetching from ",
            td.device->device_type(), " devices is not yet supported for ",
            DataTypeString(dtype), " tensors");
452 | } |
453 | } |
454 | } |
455 | for (int i = 0; i < found.size(); ++i) { |
456 | if (!found[i]) { |
      return errors::InvalidArgument(
          "Tensor ", tensors_and_devices[i].tensor.ToString(),
          ", specified in either feed_devices or fetch_devices was not found "
          "in the Graph");
461 | } |
462 | } |
463 | return OkStatus(); |
464 | } |
465 | |
466 | Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node, |
467 | PartialTensorShape* shape, |
468 | DataType* type) { |
469 | static const gtl::FlatSet<string>* const kHasExplicitShapeAttribute = |
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Placeholder", "PlaceholderV2", "PlaceholderWithDefault",
          "ParallelConcat", "ImmutableConst", "_ParallelConcatStart",
          "InfeedDequeue", "OutfeedDequeue", "CollectiveBcastSend",
          "CollectiveBcastRecv", "AccumulateNV2", "VariableV2", "Variable",
          "TemporaryVariable", "NcclBroadcast", "_ScopedAllocator",
          "_ScopedAllocatorConcat"}));
477 | |
478 | // All the node types handled here have their output datatype set in |
479 | // either attribute 'dtype' or 'T'. |
  if (!TryGetNodeAttr(node, "dtype", type) &&
      !TryGetNodeAttr(node, "T", type)) {
    return errors::InvalidArgument(
        "Could not determine output type for feed node: ", node.name(),
        " of type ", node.op());
485 | } |
486 | |
487 | // First handle the case of feeding a const node. |
  if (node.op() == "Const" && HasNodeAttr(node, "value")) {
    *shape =
        PartialTensorShape(node.attr().at("value").tensor().tensor_shape());
  } else if (kHasExplicitShapeAttribute->find(node.op()) !=
             kHasExplicitShapeAttribute->end()) {
    TF_RETURN_IF_ERROR(GetNodeAttr(node, "shape", shape));
  } else {
    return errors::InvalidArgument("Could not determine shape for feed node: ",
                                   node.name(), " of type ", node.op());
497 | } |
498 | return OkStatus(); |
499 | } |
500 | |
501 | } // namespace |
502 | |
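// Rewrites `graph` in place so that it computes exactly the feeds, fetches,
// and targets named in `options.callable_options`. Depending on
// `use_function_convention`, feeds and fetches are rewritten to _Arg/_Retval
// nodes or to _Recv/_Send pairs; tensor connections always become Identity
// forwards (see TensorConnectionPruneRewrite above).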
503 | Status GraphExecutionState::PruneGraph( |
504 | const BuildGraphOptions& options, Graph* graph, |
505 | subgraph::RewriteGraphMetadata* out_rewrite_metadata) { |
506 | std::vector<std::unique_ptr<subgraph::PruneRewrite>> feed_rewrites; |
507 | feed_rewrites.reserve(options.callable_options.feed_size()); |
508 | std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites; |
509 | fetch_rewrites.reserve(options.callable_options.fetch_size()); |
510 | if (options.use_function_convention) { |
511 | std::vector<TensorAndDevice> tensors_and_devices; |
512 | for (int i = 0; i < options.callable_options.feed_size(); ++i) { |
513 | // WARNING: feed MUST be a reference, since ArgFeedRewrite and |
514 | // tensors_and_devices holds on to its address. |
515 | const string& feed = options.callable_options.feed(i); |
516 | const DeviceAttributes* device_info; |
517 | TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed, |
518 | options.callable_options.feed_devices(), |
519 | &device_info)); |
520 | feed_rewrites.emplace_back( |
521 | new subgraph::ArgFeedRewrite(&feed, device_info, i)); |
522 | tensors_and_devices.push_back({ParseTensorName(feed), device_info}); |
523 | } |
524 | if (!options.callable_options.fetch_devices().empty() && |
525 | !options.callable_options.fetch_skip_sync()) { |
      return errors::Unimplemented(
          "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
          "can set it to true instead, but MUST ensure that Device::Sync() is "
          "invoked on the Device corresponding to the fetched tensor before "
          "dereferencing the Tensor's memory.");
531 | } |
532 | for (int i = 0; i < options.callable_options.fetch_size(); ++i) { |
533 | // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and |
534 | // tensors_and_devices holds on to its address. |
535 | const string& fetch = options.callable_options.fetch(i); |
536 | const DeviceAttributes* device_info; |
537 | TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch, |
538 | options.callable_options.fetch_devices(), |
539 | &device_info)); |
540 | fetch_rewrites.emplace_back( |
541 | new subgraph::RetvalFetchRewrite(&fetch, device_info, i)); |
542 | tensors_and_devices.push_back({ParseTensorName(fetch), device_info}); |
543 | } |
544 | TF_RETURN_IF_ERROR( |
545 | ValidateFeedAndFetchDevices(*graph, tensors_and_devices)); |
546 | } else { |
547 | if (!options.callable_options.feed_devices().empty() || |
548 | !options.callable_options.fetch_devices().empty()) { |
      return errors::Unimplemented(
          "CallableOptions::feed_devices and CallableOptions::fetch_devices "
          "to configure feeding/fetching tensors to/from device memory is not "
          "yet supported when using a remote session.");
553 | } |
554 | const DeviceAttributes* device_info = |
555 | &device_set_->client_device()->attributes(); |
556 | for (const string& feed : options.callable_options.feed()) { |
557 | feed_rewrites.emplace_back( |
558 | new subgraph::RecvFeedRewrite(&feed, device_info)); |
559 | } |
560 | for (const string& fetch : options.callable_options.fetch()) { |
561 | fetch_rewrites.emplace_back( |
562 | new subgraph::SendFetchRewrite(&fetch, device_info)); |
563 | } |
564 | } |
565 | |
566 | for (const TensorConnection& tensor_connection : |
567 | options.callable_options.tensor_connection()) { |
568 | Node* from_node = nullptr; |
569 | TensorId from_id(ParseTensorName(tensor_connection.from_tensor())); |
570 | |
571 | for (Node* n : graph->nodes()) { |
572 | if (n->name() == from_id.first) { |
573 | from_node = n; |
574 | break; |
575 | } |
576 | } |
577 | if (from_node == nullptr) { |
      return errors::InvalidArgument(
          "Requested tensor connection from unknown node: \"",
          tensor_connection.from_tensor(), "\".");
581 | } |
582 | if (from_id.second >= from_node->num_outputs()) { |
      return errors::InvalidArgument(
          "Requested tensor connection from unknown edge: \"",
          tensor_connection.from_tensor(),
          "\" (actual number of outputs = ", from_node->num_outputs(), ").");
587 | } |
588 | |
589 | feed_rewrites.emplace_back(new TensorConnectionPruneRewrite( |
590 | &tensor_connection.to_tensor(), {from_node, from_id.second})); |
591 | } |
592 | |
593 | std::vector<string> target_node_names( |
594 | options.callable_options.target().begin(), |
595 | options.callable_options.target().end()); |
596 | TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( |
597 | graph, feed_rewrites, fetch_rewrites, target_node_names, |
598 | out_rewrite_metadata)); |
599 | |
600 | CHECK_EQ(out_rewrite_metadata->feed_types.size(), |
601 | options.callable_options.feed_size() + |
602 | options.callable_options.tensor_connection_size()); |
603 | for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) { |
604 | out_rewrite_metadata->feed_types.pop_back(); |
605 | } |
606 | return OkStatus(); |
607 | } |
608 | |
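// Runs the PRE_PLACEMENT optimization passes, places `new_graph` on the
// available devices, runs the POST_PLACEMENT passes, and installs the result
// as this state's base graph.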
609 | Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) { |
  // Restore any previously-saved stateful placements before placing.
  RestoreStatefulNodes(new_graph.get());
612 | |
613 | GraphOptimizationPassOptions optimization_options; |
614 | optimization_options.session_handle = session_handle_; |
615 | optimization_options.session_options = session_options_; |
616 | optimization_options.graph = &new_graph; |
617 | optimization_options.flib_def = flib_def_.get(); |
618 | optimization_options.device_set = device_set_; |
619 | |
620 | TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
621 | OptimizationPassRegistry::PRE_PLACEMENT, optimization_options)); |
622 | |
623 | Placer placer(new_graph.get(), "" , flib_def_.get(), device_set_, |
624 | /* default_local_device= */ nullptr, |
625 | session_options_ == nullptr || |
626 | session_options_->config.allow_soft_placement(), |
627 | session_options_ != nullptr && |
628 | session_options_->config.log_device_placement()); |
629 | // TODO(mrry): Consider making the Placer cancellable. |
630 | TF_RETURN_IF_ERROR(placer.Run()); |
631 | |
632 | TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
633 | OptimizationPassRegistry::POST_PLACEMENT, optimization_options)); |
634 | |
635 | for (const Node* n : new_graph->nodes()) { |
636 | VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id(); |
637 | node_name_to_cost_id_map_[n->name()] = n->cost_id(); |
638 | } |
639 | |
640 | SaveStatefulNodes(new_graph.get()); |
641 | graph_ = new_graph.release(); |
642 | return OkStatus(); |
643 | } |
644 | |
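// Runs the Grappler MetaOptimizer over `graph` when it is enabled in the
// session config, producing an optimized graph together with a function
// library that merges any newly specialized functions into `flib_def`.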
645 | Status GraphExecutionState::OptimizeGraph( |
646 | const BuildGraphOptions& options, const Graph& graph, |
647 | const FunctionLibraryDefinition* flib_def, |
648 | std::unique_ptr<Graph>* optimized_graph, |
649 | std::unique_ptr<FunctionLibraryDefinition>* optimized_flib) { |
650 | #ifdef IS_MOBILE_PLATFORM |
  return errors::InvalidArgument("Mobile platforms not supported");
652 | #else |
653 | if (session_options_->config.graph_options().place_pruned_graph()) { |
    return errors::InvalidArgument("Can't optimize a pruned graph");
655 | } |
656 | |
657 | if (grappler::MetaOptimizerEnabled(session_options_->config)) { |
658 | // Here we build the GrapplerItem before calling the optimizer. |
659 | grappler::GrapplerItem item; |
    item.id = "tf_graph";
661 | |
662 | // Add devices to the GrapplerItem |
663 | // It's ok to skip invalid device annotations in Grappler. |
664 | for (const Device* d : device_set_->devices()) { |
665 | Status added_device = item.AddDevice(d->name()); |
666 | if (!added_device.ok()) VLOG(3) << added_device.error_message(); |
667 | } |
668 | VLOG(3) << "Grappler available devices: " |
            << absl::StrJoin(item.devices(), ", ");
670 | |
671 | // Add fetches to the GrapplerItem. |
672 | item.fetch.insert(item.fetch.end(), |
673 | options.callable_options.fetch().begin(), |
674 | options.callable_options.fetch().end()); |
675 | item.fetch.insert(item.fetch.end(), |
676 | options.callable_options.target().begin(), |
677 | options.callable_options.target().end()); |
678 | |
679 | for (const TensorConnection& tensor_connection : |
680 | options.callable_options.tensor_connection()) { |
681 | item.fetch.push_back(tensor_connection.from_tensor()); |
682 | } |
683 | |
684 | // Add feeds to the GrapplerItem if we know them. |
685 | absl::flat_hash_set<absl::string_view> node_names; |
686 | if (!(options.callable_options.feed().empty() && |
687 | options.callable_options.tensor_connection().empty())) { |
688 | std::vector<SafeTensorId> feeds; |
689 | |
690 | for (const string& feed : options.callable_options.feed()) { |
691 | feeds.emplace_back(ParseTensorName(feed)); |
692 | } |
693 | for (const TensorConnection& tensor_connection : |
694 | options.callable_options.tensor_connection()) { |
695 | feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor())); |
696 | } |
697 | |
698 | // For feeds with tensor index 0 we try to find the corresponding node in |
699 | // the graph to infer feed data type and shape. |
700 | absl::flat_hash_set<absl::string_view> feed_nodes; |
701 | |
702 | // For feeds with tensor index larger than 0, we can't infer data type or |
703 | // shape from the graph. Currently we only support type and shape |
704 | // inference from a small set of node types: Placeholder, Const, etc... |
705 | for (const SafeTensorId& feed : feeds) { |
706 | if (feed.index() > 0) { |
707 | VLOG(3) << "Add undefined feed for: " << feed.ToString(); |
708 | Tensor fake_input(DT_INVALID, {0}); |
709 | item.feed.emplace_back(feed.ToString(), fake_input); |
710 | } else { |
711 | VLOG(3) << "Add node for feed inference: " << feed.ToString(); |
712 | feed_nodes.insert(feed.node()); |
713 | continue; |
714 | } |
715 | } |
716 | |
717 | // For feeds with tensor index == 0 we try to infer data type and tensor |
718 | // shape from the graph, by looking at the fed node attributes. |
719 | node_names.reserve(graph.num_nodes()); |
720 | for (const Node* node : graph.nodes()) { |
721 | node_names.insert(node->name()); |
722 | if (feed_nodes.find(node->name()) == feed_nodes.end()) continue; |
723 | |
724 | // Try to get the type and shape of the feed node. |
725 | PartialTensorShape partial_shape; |
726 | DataType type; |
727 | Status st = GetFeedShapeAndTypeFromAttribute(node->def(), |
728 | &partial_shape, &type); |
729 | |
730 | // Failed to get type and shape of the feed node. |
731 | if (!st.ok()) { |
732 | VLOG(3) << "Failed to infer feed node type and shape." |
733 | << " Add undefined feed for: " << node->name(); |
734 | Tensor fake_input(DT_INVALID, {0}); |
735 | item.feed.emplace_back(node->name(), fake_input); |
736 | continue; |
737 | } |
738 | |
739 | // If the shape of the placeholder is only partially known, we are free |
740 | // to set unknown dimensions of its shape to any value we desire. We |
741 | // choose 0 to minimize the memory impact. Note that this only matters |
742 | // if an optimizer chooses to run the graph. |
743 | TensorShape shape; |
744 | if (partial_shape.unknown_rank()) { |
745 | shape = TensorShape({0}); |
746 | } else { |
747 | for (int i = 0; i < partial_shape.dims(); ++i) { |
748 | if (partial_shape.dim_size(i) < 0) { |
749 | partial_shape.set_dim(i, 0); |
750 | } |
751 | } |
752 | if (!partial_shape.AsTensorShape(&shape)) { |
          return errors::InvalidArgument(
              "Could not derive shape for feed node: ",
              node->def().DebugString());
756 | } |
757 | } |
758 | |
759 | VLOG(3) << "Add feed for: " << node->name() << "; type: " << type |
760 | << "; shape: " << shape; |
761 | Tensor fake_input(type, shape); |
762 | item.feed.emplace_back(node->name(), fake_input); |
763 | } |
764 | } |
765 | |
766 | // Validate that the feeds and fetches are valid. |
767 | if (node_names.empty()) { |
768 | // Collect all node names in the graph if we didn't already. |
769 | node_names.reserve(graph.num_nodes()); |
770 | for (const Node* node : graph.nodes()) { |
771 | node_names.insert(node->name()); |
772 | } |
773 | } |
774 | for (const auto& feed : item.feed) { |
775 | SafeTensorId tensor_id = ParseTensorName(feed.first); |
776 | if (node_names.find(tensor_id.node()) == node_names.end()) { |
        return errors::InvalidArgument("Invalid feed, no such node in graph: ",
                                       feed.first);
779 | } |
780 | } |
781 | for (const auto& fetch : item.fetch) { |
782 | SafeTensorId tensor_id = ParseTensorName(fetch); |
783 | if (node_names.find(tensor_id.node()) == node_names.end()) { |
        return errors::InvalidArgument(
            "Invalid fetch, no such node in graph: ", fetch);
786 | } |
787 | } |
788 | |
789 | // Convert Graph to GraphDef and add it to the GrapplerItem. |
790 | graph.ToGraphDef(&item.graph); |
791 | // TODO(b/114748242): Add a unit test to test this bug fix. |
792 | if (flib_def) { |
793 | *item.graph.mutable_library() = flib_def->ToProto(); |
794 | } |
795 | |
796 | // Construct a virtual cluster and find the cpu_device, which the |
797 | // ConstantFolding optimizer will use for partial evaluation of the graph. |
798 | grappler::VirtualCluster cluster(device_set_); |
799 | Device* cpu_device = nullptr; |
800 | for (const auto& device : device_set_->devices()) { |
801 | if (device->parsed_name().id == 0 && |
802 | StringPiece(device->parsed_name().type) == "CPU" && |
803 | device->GetAllocator(AllocatorAttributes()) != nullptr) { |
804 | cpu_device = device; |
805 | } |
806 | } |
807 | |
808 | // Now we can run the MetaOptimizer on the constructed GrapplerItem. |
809 | GraphDef new_graph; |
810 | TF_RETURN_IF_ERROR( |
811 | grappler::RunMetaOptimizer(std::move(item), session_options_->config, |
812 | cpu_device, &cluster, &new_graph)); |
813 | |
    // Merge the optimized graph's function library with the original
    // library. The optimized graph might have new functions specialized for
    // its instantiation context (see the Grappler function optimizer), and
    // modified function bodies for the existing functions.
818 | optimized_flib->reset(new FunctionLibraryDefinition(*flib_def)); |
819 | |
820 | for (const FunctionDef& fdef : new_graph.library().function()) { |
821 | const string& func_name = fdef.signature().name(); |
822 | |
823 | if ((*optimized_flib)->Contains(func_name)) { |
824 | VLOG(3) << "Replace function: name=" << func_name; |
825 | TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef)); |
826 | } else { |
827 | VLOG(3) << "Add new function: name=" << func_name; |
828 | TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef)); |
829 | } |
830 | } |
831 | optimized_graph->reset(new Graph(OpRegistry::Global())); |
832 | |
833 | // Convert the optimized GraphDef back to a Graph. |
834 | GraphConstructorOptions opts; |
835 | opts.allow_internal_ops = true; |
836 | TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(new_graph), |
837 | optimized_graph->get())); |
    // The graph conversion sets the requested device names but not the
    // assigned device names. However, since at this point the graph is
    // placed, TF expects an assigned device name for every node. Therefore
    // we copy the requested device into the assigned device field.
842 | for (Node* node : optimized_graph->get()->nodes()) { |
843 | node->set_assigned_device_name(node->requested_device()); |
844 | } |
845 | return OkStatus(); |
846 | } else { |
    return errors::InvalidArgument("Meta Optimizer disabled");
848 | } |
849 | #endif // IS_MOBILE_PLATFORM |
850 | } |
851 | |
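// Builds a ClientGraph for `options`: optimizes the base graph with Grappler
// (falling back to a plain copy on failure), prunes it to the requested
// feeds/fetches/targets unless it was pre-pruned at construction time, runs
// the POST_REWRITE_FOR_EXEC passes, and derives a collective graph key.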
852 | Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options, |
853 | std::unique_ptr<ClientGraph>* out) { |
  VLOG(1) << "BuildGraph";
855 | const uint64 start_time_usecs = Env::Default()->NowMicros(); |
856 | if (!graph_) { |
857 | // It is only valid to call this method directly when the original graph |
858 | // was created with the option `place_pruned_graph == false`. |
    return errors::Internal(
        "Attempted to prune a graph that has not been fully initialized.");
861 | } |
862 | |
  // Grappler optimization might change the structure of the graph itself,
  // and it can also add or prune functions in the library.
865 | std::unique_ptr<Graph> optimized_graph; |
866 | std::unique_ptr<FunctionLibraryDefinition> optimized_flib; |
867 | |
868 | Status s = OptimizeGraph(options, *graph_, flib_def_.get(), &optimized_graph, |
869 | &optimized_flib); |
870 | if (!s.ok()) { |
871 | VLOG(2) << "Grappler optimization failed. Error: " << s.error_message(); |
872 | // Simply copy the original graph and the function library if we couldn't |
873 | // optimize it. |
874 | optimized_graph.reset(new Graph(flib_def_.get())); |
875 | CopyGraph(*graph_, optimized_graph.get()); |
876 | optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_)); |
877 | } |
878 | |
879 | subgraph::RewriteGraphMetadata rewrite_metadata; |
880 | if (session_options_ == nullptr || |
881 | !session_options_->config.graph_options().place_pruned_graph()) { |
882 | TF_RETURN_IF_ERROR( |
883 | PruneGraph(options, optimized_graph.get(), &rewrite_metadata)); |
884 | } else { |
885 | // This GraphExecutionState represents a graph that was |
886 | // pruned when this was constructed, so we copy the metadata from |
887 | // a member variable. |
888 | CHECK(rewrite_metadata_); |
889 | rewrite_metadata = *rewrite_metadata_; |
890 | } |
891 | |
892 | CHECK_EQ(options.callable_options.feed_size(), |
893 | rewrite_metadata.feed_types.size()); |
894 | CHECK_EQ(options.callable_options.fetch_size(), |
895 | rewrite_metadata.fetch_types.size()); |
896 | |
897 | // TODO(andydavis): Clarify optimization pass requirements around CostModel. |
898 | GraphOptimizationPassOptions optimization_options; |
899 | optimization_options.session_options = session_options_; |
900 | optimization_options.graph = &optimized_graph; |
901 | optimization_options.flib_def = optimized_flib.get(); |
902 | optimization_options.device_set = device_set_; |
903 | |
904 | TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
905 | OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); |
906 | |
907 | int64_t collective_graph_key = options.collective_graph_key; |
908 | if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) { |
909 | // BuildGraphOptions does not specify a collective_graph_key. Check all |
910 | // nodes in the Graph and FunctionLibraryDefinition for collective ops and |
911 | // if found, initialize a collective_graph_key as a hash of the ordered set |
912 | // of instance keys. |
913 | std::set<int32> instance_key_set; |
914 | bool has_collective_v2 = false; |
915 | for (Node* node : optimized_graph->nodes()) { |
916 | if (node->IsCollective()) { |
917 | int32_t instance_key; |
918 | TF_RETURN_IF_ERROR( |
            GetNodeAttr(node->attrs(), "instance_key", &instance_key));
920 | instance_key_set.emplace(instance_key); |
921 | } else if (IsCollectiveV2(node->type_string())) { |
922 | has_collective_v2 = true; |
923 | } else { |
924 | const FunctionDef* fdef = optimized_flib->Find(node->def().op()); |
925 | if (fdef != nullptr) { |
926 | for (const NodeDef& ndef : fdef->node_def()) { |
            if (ndef.op() == "CollectiveReduce" ||
                ndef.op() == "CollectiveBcastSend" ||
                ndef.op() == "CollectiveBcastRecv" ||
                ndef.op() == "CollectiveGather") {
              int32_t instance_key;
              TF_RETURN_IF_ERROR(
                  GetNodeAttr(ndef, "instance_key", &instance_key));
934 | instance_key_set.emplace(instance_key); |
935 | } else if (IsCollectiveV2(ndef.op())) { |
936 | has_collective_v2 = true; |
937 | } |
938 | } |
939 | } |
940 | } |
941 | } |
942 | if (!instance_key_set.empty()) { |
943 | uint64 hash = 0x8774aa605c729c72ULL; |
944 | for (int32_t instance_key : instance_key_set) { |
945 | hash = Hash64Combine(instance_key, hash); |
946 | } |
947 | collective_graph_key = hash; |
948 | } else if (has_collective_v2) { |
949 | collective_graph_key = 0x8774aa605c729c72ULL; |
950 | } |
951 | } |
952 | |
953 | // Make collective execution order deterministic if needed. |
954 | if (options.collective_order != GraphCollectiveOrder::kNone) { |
955 | TF_RETURN_IF_ERROR( |
956 | OrderCollectives(optimized_graph.get(), options.collective_order)); |
957 | } |
958 | |
959 | // Copy the extracted graph in order to make its node ids dense, |
960 | // since the local CostModel used to record its stats is sized by |
961 | // the largest node id. |
962 | std::unique_ptr<ClientGraph> dense_copy( |
963 | new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types, |
964 | rewrite_metadata.fetch_types, collective_graph_key)); |
965 | CopyGraph(*optimized_graph, &dense_copy->graph); |
966 | |
967 | // TODO(vrv): We should check invariants of the graph here. |
968 | metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs); |
969 | *out = std::move(dense_copy); |
970 | return OkStatus(); |
971 | } |
972 | |
973 | } // namespace tensorflow |
974 | |