1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/graph_execution_state.h"
17
18#include <memory>
19#include <set>
20#include <string>
21#include <unordered_set>
22#include <utility>
23#include <vector>
24
25#include "absl/container/flat_hash_set.h"
26#include "absl/memory/memory.h"
27#include "absl/strings/str_join.h"
28#include "tensorflow/core/common_runtime/device.h"
29#include "tensorflow/core/common_runtime/graph_constructor.h"
30#include "tensorflow/core/common_runtime/optimization_registry.h"
31#include "tensorflow/core/common_runtime/placer.h"
32#include "tensorflow/core/framework/attr_value.pb.h"
33#include "tensorflow/core/framework/device_factory.h"
34#include "tensorflow/core/framework/function.h"
35#include "tensorflow/core/framework/function.pb.h"
36#include "tensorflow/core/framework/graph.pb.h"
37#include "tensorflow/core/framework/graph_def_util.h"
38#include "tensorflow/core/framework/metrics.h"
39#include "tensorflow/core/framework/node_def.pb.h"
40#include "tensorflow/core/framework/op.h"
41#include "tensorflow/core/framework/tensor.pb.h"
42#include "tensorflow/core/framework/versions.pb.h"
43#include "tensorflow/core/graph/algorithm.h"
44#include "tensorflow/core/graph/collective_order.h"
45#include "tensorflow/core/graph/graph.h"
46#include "tensorflow/core/graph/subgraph.h"
47#include "tensorflow/core/graph/tensor_id.h"
48#include "tensorflow/core/graph/validate.h"
49#include "tensorflow/core/lib/core/errors.h"
50#include "tensorflow/core/lib/core/status.h"
51#include "tensorflow/core/lib/gtl/flatset.h"
52#include "tensorflow/core/lib/strings/stringprintf.h"
53#include "tensorflow/core/platform/logging.h"
54#include "tensorflow/core/platform/types.h"
55#include "tensorflow/core/util/device_name_utils.h"
56#include "tensorflow/core/util/util.h"
57
58#ifndef IS_MOBILE_PLATFORM
59#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
60#include "tensorflow/core/grappler/grappler_item.h"
61#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
62#endif // IS_MOBILE_PLATFORM
63
64namespace tensorflow {
65
66namespace {
// Returns true iff `op` names one of the V2 collective ops.
bool IsCollectiveV2(const string& op) {
  static constexpr const char* kCollectiveV2Ops[] = {
      "CollectiveReduceV2", "CollectiveGatherV2", "CollectiveBcastRecvV2",
      "CollectiveBcastSendV2"};
  for (const char* candidate : kCollectiveV2Ops) {
    if (op == candidate) return true;
  }
  return false;
}
71} // namespace
72
// Takes ownership of `graph_def` and `flib_def` and records the session
// configuration carried by `options`. `graph_def` may be null (the
// static-graph path of MakeForBaseGraph() and MakeForPrunedGraph() pass
// nullptr), in which case no copy of the original GraphDef is retained.
// `graph_` starts out null and is installed later by InitBaseGraph().
GraphExecutionState::GraphExecutionState(
    std::unique_ptr<GraphDef>&& graph_def,
    std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
    const GraphExecutionStateOptions& options)
    : stateful_placements_(options.stateful_placements),
      original_graph_def_(std::move(graph_def)),
      device_set_(options.device_set),
      session_options_(options.session_options),
      session_handle_(options.session_handle),
      flib_def_(std::move(flib_def)),
      graph_(nullptr) {}
84
// Frees the Graph installed by InitBaseGraph(). `graph_` is a raw owning
// pointer (assigned from `new_graph.release()`), so it has to be deleted
// explicitly. It is still null if InitBaseGraph() never ran (e.g. when
// `place_pruned_graph` is set for a base state), and deleting null is a no-op.
GraphExecutionState::~GraphExecutionState() {
  // NOTE(review): the cost-id map is cleared before `graph_` is deleted. If
  // its keys refer to node names owned by `graph_` (check the key type in the
  // header), keep this order so the keys never outlive that storage.
  node_name_to_cost_id_map_.clear();
  delete graph_;
}
89
90/* static */ Status GraphExecutionState::MakeForBaseGraph(
91 GraphDef&& graph_def, const GraphExecutionStateOptions& options,
92 std::unique_ptr<GraphExecutionState>* out_state) {
93#ifndef __ANDROID__
94 VLOG(4) << "Graph proto is \n" << graph_def.DebugString();
95#endif // __ANDROID__
96
97 auto flib_def = std::make_unique<FunctionLibraryDefinition>(
98 OpRegistry::Global(), graph_def.library());
99
100 TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&graph_def, *flib_def, 0));
101
102 if (options.session_options->config.graph_options().place_pruned_graph() ||
103 !options.session_options->config.experimental()
104 .optimize_for_static_graph()) {
105 auto ret = absl::WrapUnique(new GraphExecutionState(
106 std::make_unique<GraphDef>(std::move(graph_def)), std::move(flib_def),
107 options));
108
109 // When place_pruned_graph is true, a different Graph* will be initialized
110 // each time we prune the original graph, so there is no need to
111 // construct a Graph* in this case.
112 if (!options.session_options->config.graph_options().place_pruned_graph()) {
113 auto base_graph = std::make_unique<Graph>(OpRegistry::Global());
114 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *ret->original_graph_def_,
115 base_graph.get()));
116 TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
117 }
118 *out_state = std::move(ret);
119 } else {
120 auto ret = absl::WrapUnique(
121 new GraphExecutionState(nullptr, std::move(flib_def), options));
122 auto base_graph = std::make_unique<Graph>(OpRegistry::Global());
123 TF_RETURN_IF_ERROR(
124 ConvertGraphDefToGraph({}, std::move(graph_def), base_graph.get()));
125 TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
126 *out_state = std::move(ret);
127 }
128 return OkStatus();
129}
130
131/* static */ Status GraphExecutionState::MakeForPrunedGraph(
132 const GraphExecutionState& base_execution_state,
133 const GraphExecutionStateOptions& options,
134 const BuildGraphOptions& subgraph_options,
135 std::unique_ptr<GraphExecutionState>* out_state,
136 std::unique_ptr<ClientGraph>* out_client_graph) {
137 if (!(base_execution_state.session_options_->config.graph_options()
138 .place_pruned_graph() &&
139 options.session_options->config.graph_options().place_pruned_graph())) {
140 return errors::Internal(
141 "MakeForPrunedGraph is only supported when the `place_pruned_graph` "
142 "option is true.");
143 }
144 if (!base_execution_state.original_graph_def_) {
145 // NOTE(mrry): By adding this restriction, which matches the only current
146 // usage of this (fairly obscure) method, we do not need to store a
147 // redundant copy of the original graph in `*out_state`.
148 return errors::Internal(
149 "MakeForPrunedGraph is only supported when `base_execution_state` is "
150 "the Session-level `GraphExecutionState`.");
151 }
152
153 // NOTE(mrry): This makes a copy of `graph_def`, which is
154 // regrettable. We could make `GraphDef` objects shareable between
155 // execution states to optimize pruned graph execution, but since
156 // this case is primarily used for interactive sessions, we make the
157 // bet that graph construction is not performance-critical. (Note
158 // also that the previous version used `Extend()`, which is strictly
159 // more expensive than copying a `GraphDef`.)
160 GraphDef temp(*base_execution_state.original_graph_def_);
161 auto flib_def = std::make_unique<FunctionLibraryDefinition>(
162 OpRegistry::Global(), temp.library());
163 TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&temp, *flib_def, 0));
164 auto ret = absl::WrapUnique(
165 new GraphExecutionState(nullptr, std::move(flib_def), options));
166
167 auto base_graph = std::make_unique<Graph>(OpRegistry::Global());
168 TF_RETURN_IF_ERROR(
169 ConvertGraphDefToGraph({}, std::move(temp), base_graph.get()));
170
171 // Rewrite the graph before placement.
172 ret->rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
173 TF_RETURN_IF_ERROR(ret->PruneGraph(subgraph_options, base_graph.get(),
174 ret->rewrite_metadata_.get()));
175 TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
176 TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
177 *out_state = std::move(ret);
178 return OkStatus();
179}
180
181Status GraphExecutionState::Extend(
182 const GraphDef& extension_def,
183 std::unique_ptr<GraphExecutionState>* out) const {
184 if (session_options_->config.experimental().optimize_for_static_graph()) {
185 return errors::FailedPrecondition(
186 "Extending the graph is not supported when "
187 "`optimize_for_static_graph` is true.");
188 }
189
190 GraphDef gdef;
191
192 // 1. Copy the function library.
193 TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library()));
194 *gdef.mutable_library() = flib_def_->ToProto();
195
196 // 2. Build an index of the new node names.
197 std::unordered_set<string> new_names;
198 for (const NodeDef& node : extension_def.node()) {
199 new_names.insert(node.name());
200 }
201
202 // 3. Add the non-duplicates from the old graph to the new graph.
203 // Return an error if the same node name appears in both the
204 // old graph and the extension.
205 for (const NodeDef& node : original_graph_def_->node()) {
206 if (new_names.count(node.name()) == 0) {
207 *gdef.add_node() = node;
208 } else {
209 return errors::InvalidArgument(
210 "GraphDef argument to Extend includes node '", node.name(),
211 "', which was created by a previous call to Create or Extend in this "
212 "session.");
213 }
214 }
215
216 // 4. Merge the versions field.
217 int old_node_size = gdef.node_size();
218 gdef.mutable_node()->MergeFrom(extension_def.node());
219 TF_RETURN_IF_ERROR(
220 AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size));
221 // Merge versions
222 if (gdef.has_versions()) {
223 if (gdef.versions().producer() != extension_def.versions().producer()) {
224 return errors::InvalidArgument(
225 "Can't extend GraphDef at version ", gdef.versions().producer(),
226 " with graph at version ", extension_def.versions().producer());
227 }
228 VersionDef* versions = gdef.mutable_versions();
229 versions->set_min_consumer(std::max(
230 versions->min_consumer(), extension_def.versions().min_consumer()));
231 if (extension_def.versions().bad_consumers_size()) {
232 // Add new bad_consumers that aren't already marked bad.
233 //
234 // Note: This implementation is quadratic time if there are many calls to
235 // ExtendLocked with many bad consumers. Since this is unlikely, and
236 // fixing it would require data structures outside of this routine,
237 // quadratic time it is.
238 auto* bad_consumers = versions->mutable_bad_consumers();
239 const std::unordered_set<int> existing(bad_consumers->begin(),
240 bad_consumers->end());
241 for (const int v : extension_def.versions().bad_consumers()) {
242 if (existing.find(v) == existing.end()) {
243 bad_consumers->Add(v);
244 }
245 }
246 }
247
248 } else {
249 gdef.mutable_versions()->CopyFrom(extension_def.versions());
250 }
251
252 // 5. Validate that the final graphdef is valid.
253 if (gdef.versions().producer() >= 5) {
254 // Validate the graph: we assume that merging two valid graphs
255 // should maintain graph validity.
256 TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_));
257 }
258
259 // 6. Add the extension.
260 GraphExecutionStateOptions combined_options;
261 combined_options.device_set = device_set_;
262 combined_options.session_options = session_options_;
263 combined_options.session_handle = session_handle_;
264 combined_options.stateful_placements = stateful_placements_;
265
266 TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *flib_def_, 0));
267 auto flib_def = std::make_unique<FunctionLibraryDefinition>(
268 OpRegistry::Global(), gdef.library());
269 auto new_execution_state = absl::WrapUnique(
270 new GraphExecutionState(std::make_unique<GraphDef>(std::move(gdef)),
271 std::move(flib_def), combined_options));
272
273 if (!session_options_->config.graph_options().place_pruned_graph()) {
274 auto base_graph = std::make_unique<Graph>(OpRegistry::Global());
275 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
276 {}, *new_execution_state->original_graph_def_, base_graph.get()));
277 TF_RETURN_IF_ERROR(
278 new_execution_state->InitBaseGraph(std::move(base_graph)));
279 }
280 *out = std::move(new_execution_state);
281
282 // NOTE(mrry): Extend() is likely to be used for non-throughput-sensitive
283 // interactive workloads, but in future we may want to transfer other
284 // parts of the placement and/or cost model.
285 return OkStatus();
286}
287
288void GraphExecutionState::SaveStatefulNodes(Graph* graph) {
289 for (Node* n : graph->nodes()) {
290 if (n->op_def().is_stateful()) {
291 VLOG(2) << "Saving " << n->DebugString();
292 stateful_placements_[n->name()] = n->assigned_device_name();
293 }
294 }
295}
296
297void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
298 for (Node* n : graph->nodes()) {
299 if (n->op_def().is_stateful()) {
300 auto iter = stateful_placements_.find(n->name());
301 if (iter != stateful_placements_.end()) {
302 n->set_assigned_device_name(iter->second);
303 VLOG(2) << "Restored " << n->DebugString();
304 }
305 }
306 }
307}
308
309namespace {
310
311class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
312 public:
313 TensorConnectionPruneRewrite(const string* endpoint_name,
314 NodeBuilder::NodeOut from_tensor)
315 : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */),
316 from_tensor_(std::move(from_tensor)) {}
317
318 Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
319 Node** out_node) override {
320 Status s;
321 auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) {
322 if (n == feed_tensor.node) {
323 s.Update(errors::InvalidArgument(
324 "Requested Tensor connection between nodes \"",
325 feed_tensor.node->name(), "\" and \"", from_tensor_.node->name(),
326 "\" would create a cycle."));
327 }
328 };
329 ReverseDFSFrom(*g, {from_tensor_.node}, std::move(check_no_cycle_fn),
330 nullptr);
331 TF_RETURN_IF_ERROR(s);
332
333 TF_RETURN_IF_ERROR(
334 NodeBuilder(strings::StrCat("_identity_", feed_tensor.node->name(), "_",
335 feed_tensor.index),
336 "Identity")
337 .Input(from_tensor_)
338 .Attr("T",
339 BaseType(from_tensor_.node->output_type(from_tensor_.index)))
340 .Finalize(g, out_node));
341
342 (*out_node)->set_assigned_device_name(
343 feed_tensor.node->assigned_device_name());
344 return OkStatus();
345 }
346
347 private:
348 NodeBuilder::NodeOut from_tensor_;
349};
350
351template <class Map>
352Status LookupDevice(const DeviceSet& device_set, const string& tensor_name,
353 const Map& tensor2device,
354 const tensorflow::DeviceAttributes** out_device_attrs) {
355 *out_device_attrs = nullptr;
356 if (tensor2device.empty()) {
357 *out_device_attrs = &device_set.client_device()->attributes();
358 return OkStatus();
359 }
360 const auto it = tensor2device.find(tensor_name);
361 if (it == tensor2device.end()) {
362 *out_device_attrs = &device_set.client_device()->attributes();
363 return OkStatus();
364 }
365 DeviceNameUtils::ParsedName parsed_name;
366 if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) {
367 return errors::InvalidArgument("Invalid device name ('", it->second,
368 "') provided for the tensor '", tensor_name,
369 "' in CallableOptions");
370 }
371 Device* device = device_set.FindDeviceByName(
372 DeviceNameUtils::ParsedNameToString(parsed_name));
373 if (device == nullptr) {
374 return errors::InvalidArgument("Device '", it->second,
375 "' specified for tensor '", tensor_name,
376 "' in CallableOptions does not exist");
377 }
378 *out_device_attrs = &device->attributes();
379 return OkStatus();
380}
381
// A feed or fetch tensor paired with the device it is fed on / fetched from.
// PruneGraph() collects these and passes them to
// ValidateFeedAndFetchDevices().
struct TensorAndDevice {
  // WARNING: backing memory for the 'tensor' field is NOT owned; the string
  // it was parsed from must outlive this struct.
  const TensorId tensor;
  // WARNING: device pointer is not owned, so must outlive TensorAndDevice.
  const DeviceAttributes* device;
};
388
389// Tensors of some DataTypes cannot placed in device memory as feeds or
390// fetches. Validate against a allowlist of those known to work.
391bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
392 // The mechanism for supporting feeds of device-backed Tensors requires
393 // the _Arg kernel to be registered for the corresponding type (and that
394 // the input to the kernel be in device and not host memory).
395 //
396 // The mechanism for supporting fetches of device-backed Tensors requires
397 // the _Retval kernel to be registered for the corresponding type (and
398 // that the output is produced in device and not host memory).
399 //
400 // For now, we return true iff there are _Arg AND _Retval kernels for dtype on
401 // the device. False negatives are okay, false positives would be bad.
402 //
403 // TODO(ashankar): Instead of a allowlist here, perhaps we could query
404 // the kernel registry for _Arg and _Retval kernels instead.
405 if (device_type == DEVICE_CPU) return true;
406 if (device_type != DEVICE_GPU &&
407 !DeviceFactory::IsPluggableDevice(device_type))
408 return false;
409 switch (dtype) {
410 case DT_BFLOAT16:
411 case DT_BOOL:
412 case DT_COMPLEX128:
413 case DT_COMPLEX64:
414 case DT_DOUBLE:
415 case DT_FLOAT:
416 case DT_HALF:
417 case DT_INT16:
418 case DT_INT64:
419 case DT_INT8:
420 case DT_UINT16:
421 case DT_UINT8:
422 return true;
423 default:
424 return false;
425 }
426}
427
428Status ValidateFeedAndFetchDevices(
429 const Graph& graph,
430 const std::vector<TensorAndDevice>& tensors_and_devices) {
431 if (tensors_and_devices.empty()) return OkStatus();
432 std::vector<bool> found(tensors_and_devices.size(), false);
433 for (const Node* node : graph.nodes()) {
434 // Linearly looping through all nodes and then all feed+fetch tensors isn't
435 // quite efficient. At the time of this writing, the expectation was that
436 // tensors_and_devices.size() is really small in practice, so this won't be
437 // problematic.
438 // Revist and make a more efficient lookup possible if needed (e.g., perhaps
439 // Graph can maintain a map from node name to Node*).
440 for (int i = 0; i < tensors_and_devices.size(); ++i) {
441 const TensorAndDevice& td = tensors_and_devices[i];
442 if (td.tensor.first != node->name()) continue;
443 found[i] = true;
444 TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second));
445 const DataType dtype = node->output_type(td.tensor.second);
446 if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) {
447 return errors::Unimplemented(
448 "Cannot feed or fetch tensor '", td.tensor.ToString(),
449 "' from device ", td.device->name(), " as feeding/fetching from ",
450 td.device->device_type(), " devices is not yet supported for ",
451 DataTypeString(dtype), " tensors");
452 }
453 }
454 }
455 for (int i = 0; i < found.size(); ++i) {
456 if (!found[i]) {
457 return errors::InvalidArgument(
458 "Tensor ", tensors_and_devices[i].tensor.ToString(),
459 ", specified in either feed_devices or fetch_devices was not found "
460 "in the Graph");
461 }
462 }
463 return OkStatus();
464}
465
466Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node,
467 PartialTensorShape* shape,
468 DataType* type) {
469 static const gtl::FlatSet<string>* const kHasExplicitShapeAttribute =
470 CHECK_NOTNULL((new gtl::FlatSet<string>{
471 "Placeholder", "PlaceholderV2", "PlaceholderWithDefault",
472 "ParallelConcat", "ImmutableConst", "_ParallelConcatStart",
473 "InfeedDequeue", "OutfeedDequeue", "CollectiveBcastSend",
474 "CollectiveBcastRecv", "AccumulateNV2", "VariableV2", "Variable",
475 "TemporaryVariable", "NcclBroadcast", "_ScopedAllocator",
476 "_ScopedAllocatorConcat"}));
477
478 // All the node types handled here have their output datatype set in
479 // either attribute 'dtype' or 'T'.
480 if (!TryGetNodeAttr(node, "dtype", type) &&
481 !TryGetNodeAttr(node, "T", type)) {
482 return errors::InvalidArgument(
483 "Could not determine output type for feed node: ", node.name(),
484 " of type ", node.op());
485 }
486
487 // First handle the case of feeding a const node.
488 if (node.op() == "Const" && HasNodeAttr(node, "value")) {
489 *shape =
490 PartialTensorShape(node.attr().at("value").tensor().tensor_shape());
491 } else if (kHasExplicitShapeAttribute->find(node.op()) !=
492 kHasExplicitShapeAttribute->end()) {
493 TF_RETURN_IF_ERROR(GetNodeAttr(node, "shape", shape));
494 } else {
495 return errors::InvalidArgument("Could not determine shape for feed node: ",
496 node.name(), " of type ", node.op());
497 }
498 return OkStatus();
499}
500
501} // namespace
502
503Status GraphExecutionState::PruneGraph(
504 const BuildGraphOptions& options, Graph* graph,
505 subgraph::RewriteGraphMetadata* out_rewrite_metadata) {
506 std::vector<std::unique_ptr<subgraph::PruneRewrite>> feed_rewrites;
507 feed_rewrites.reserve(options.callable_options.feed_size());
508 std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
509 fetch_rewrites.reserve(options.callable_options.fetch_size());
510 if (options.use_function_convention) {
511 std::vector<TensorAndDevice> tensors_and_devices;
512 for (int i = 0; i < options.callable_options.feed_size(); ++i) {
513 // WARNING: feed MUST be a reference, since ArgFeedRewrite and
514 // tensors_and_devices holds on to its address.
515 const string& feed = options.callable_options.feed(i);
516 const DeviceAttributes* device_info;
517 TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed,
518 options.callable_options.feed_devices(),
519 &device_info));
520 feed_rewrites.emplace_back(
521 new subgraph::ArgFeedRewrite(&feed, device_info, i));
522 tensors_and_devices.push_back({ParseTensorName(feed), device_info});
523 }
524 if (!options.callable_options.fetch_devices().empty() &&
525 !options.callable_options.fetch_skip_sync()) {
526 return errors::Unimplemented(
527 "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
528 "can set it to true instead, but MUST ensure that Device::Sync() is "
529 "invoked on the Device corresponding to the fetched tensor before "
530 "dereferencing the Tensor's memory.");
531 }
532 for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
533 // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and
534 // tensors_and_devices holds on to its address.
535 const string& fetch = options.callable_options.fetch(i);
536 const DeviceAttributes* device_info;
537 TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch,
538 options.callable_options.fetch_devices(),
539 &device_info));
540 fetch_rewrites.emplace_back(
541 new subgraph::RetvalFetchRewrite(&fetch, device_info, i));
542 tensors_and_devices.push_back({ParseTensorName(fetch), device_info});
543 }
544 TF_RETURN_IF_ERROR(
545 ValidateFeedAndFetchDevices(*graph, tensors_and_devices));
546 } else {
547 if (!options.callable_options.feed_devices().empty() ||
548 !options.callable_options.fetch_devices().empty()) {
549 return errors::Unimplemented(
550 "CallableOptions::feed_devices and CallableOptions::fetch_devices "
551 "to configure feeding/fetching tensors to/from device memory is not "
552 "yet supported when using a remote session.");
553 }
554 const DeviceAttributes* device_info =
555 &device_set_->client_device()->attributes();
556 for (const string& feed : options.callable_options.feed()) {
557 feed_rewrites.emplace_back(
558 new subgraph::RecvFeedRewrite(&feed, device_info));
559 }
560 for (const string& fetch : options.callable_options.fetch()) {
561 fetch_rewrites.emplace_back(
562 new subgraph::SendFetchRewrite(&fetch, device_info));
563 }
564 }
565
566 for (const TensorConnection& tensor_connection :
567 options.callable_options.tensor_connection()) {
568 Node* from_node = nullptr;
569 TensorId from_id(ParseTensorName(tensor_connection.from_tensor()));
570
571 for (Node* n : graph->nodes()) {
572 if (n->name() == from_id.first) {
573 from_node = n;
574 break;
575 }
576 }
577 if (from_node == nullptr) {
578 return errors::InvalidArgument(
579 "Requested tensor connection from unknown node: \"",
580 tensor_connection.to_tensor(), "\".");
581 }
582 if (from_id.second >= from_node->num_outputs()) {
583 return errors::InvalidArgument(
584 "Requested tensor connection from unknown edge: \"",
585 tensor_connection.to_tensor(),
586 "\" (actual number of outputs = ", from_node->num_outputs(), ").");
587 }
588
589 feed_rewrites.emplace_back(new TensorConnectionPruneRewrite(
590 &tensor_connection.to_tensor(), {from_node, from_id.second}));
591 }
592
593 std::vector<string> target_node_names(
594 options.callable_options.target().begin(),
595 options.callable_options.target().end());
596 TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
597 graph, feed_rewrites, fetch_rewrites, target_node_names,
598 out_rewrite_metadata));
599
600 CHECK_EQ(out_rewrite_metadata->feed_types.size(),
601 options.callable_options.feed_size() +
602 options.callable_options.tensor_connection_size());
603 for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) {
604 out_rewrite_metadata->feed_types.pop_back();
605 }
606 return OkStatus();
607}
608
609Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) {
610 // Save stateful placements before placing.
611 RestoreStatefulNodes(new_graph.get());
612
613 GraphOptimizationPassOptions optimization_options;
614 optimization_options.session_handle = session_handle_;
615 optimization_options.session_options = session_options_;
616 optimization_options.graph = &new_graph;
617 optimization_options.flib_def = flib_def_.get();
618 optimization_options.device_set = device_set_;
619
620 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
621 OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
622
623 Placer placer(new_graph.get(), "", flib_def_.get(), device_set_,
624 /* default_local_device= */ nullptr,
625 session_options_ == nullptr ||
626 session_options_->config.allow_soft_placement(),
627 session_options_ != nullptr &&
628 session_options_->config.log_device_placement());
629 // TODO(mrry): Consider making the Placer cancellable.
630 TF_RETURN_IF_ERROR(placer.Run());
631
632 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
633 OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
634
635 for (const Node* n : new_graph->nodes()) {
636 VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
637 node_name_to_cost_id_map_[n->name()] = n->cost_id();
638 }
639
640 SaveStatefulNodes(new_graph.get());
641 graph_ = new_graph.release();
642 return OkStatus();
643}
644
645Status GraphExecutionState::OptimizeGraph(
646 const BuildGraphOptions& options, const Graph& graph,
647 const FunctionLibraryDefinition* flib_def,
648 std::unique_ptr<Graph>* optimized_graph,
649 std::unique_ptr<FunctionLibraryDefinition>* optimized_flib) {
650#ifdef IS_MOBILE_PLATFORM
651 return errors::InvalidArgument("Mobile platforms not supported");
652#else
653 if (session_options_->config.graph_options().place_pruned_graph()) {
654 return errors::InvalidArgument("Can't optimize a pruned graph");
655 }
656
657 if (grappler::MetaOptimizerEnabled(session_options_->config)) {
658 // Here we build the GrapplerItem before calling the optimizer.
659 grappler::GrapplerItem item;
660 item.id = "tf_graph";
661
662 // Add devices to the GrapplerItem
663 // It's ok to skip invalid device annotations in Grappler.
664 for (const Device* d : device_set_->devices()) {
665 Status added_device = item.AddDevice(d->name());
666 if (!added_device.ok()) VLOG(3) << added_device.error_message();
667 }
668 VLOG(3) << "Grappler available devices: "
669 << absl::StrJoin(item.devices(), ", ");
670
671 // Add fetches to the GrapplerItem.
672 item.fetch.insert(item.fetch.end(),
673 options.callable_options.fetch().begin(),
674 options.callable_options.fetch().end());
675 item.fetch.insert(item.fetch.end(),
676 options.callable_options.target().begin(),
677 options.callable_options.target().end());
678
679 for (const TensorConnection& tensor_connection :
680 options.callable_options.tensor_connection()) {
681 item.fetch.push_back(tensor_connection.from_tensor());
682 }
683
684 // Add feeds to the GrapplerItem if we know them.
685 absl::flat_hash_set<absl::string_view> node_names;
686 if (!(options.callable_options.feed().empty() &&
687 options.callable_options.tensor_connection().empty())) {
688 std::vector<SafeTensorId> feeds;
689
690 for (const string& feed : options.callable_options.feed()) {
691 feeds.emplace_back(ParseTensorName(feed));
692 }
693 for (const TensorConnection& tensor_connection :
694 options.callable_options.tensor_connection()) {
695 feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor()));
696 }
697
698 // For feeds with tensor index 0 we try to find the corresponding node in
699 // the graph to infer feed data type and shape.
700 absl::flat_hash_set<absl::string_view> feed_nodes;
701
702 // For feeds with tensor index larger than 0, we can't infer data type or
703 // shape from the graph. Currently we only support type and shape
704 // inference from a small set of node types: Placeholder, Const, etc...
705 for (const SafeTensorId& feed : feeds) {
706 if (feed.index() > 0) {
707 VLOG(3) << "Add undefined feed for: " << feed.ToString();
708 Tensor fake_input(DT_INVALID, {0});
709 item.feed.emplace_back(feed.ToString(), fake_input);
710 } else {
711 VLOG(3) << "Add node for feed inference: " << feed.ToString();
712 feed_nodes.insert(feed.node());
713 continue;
714 }
715 }
716
717 // For feeds with tensor index == 0 we try to infer data type and tensor
718 // shape from the graph, by looking at the fed node attributes.
719 node_names.reserve(graph.num_nodes());
720 for (const Node* node : graph.nodes()) {
721 node_names.insert(node->name());
722 if (feed_nodes.find(node->name()) == feed_nodes.end()) continue;
723
724 // Try to get the type and shape of the feed node.
725 PartialTensorShape partial_shape;
726 DataType type;
727 Status st = GetFeedShapeAndTypeFromAttribute(node->def(),
728 &partial_shape, &type);
729
730 // Failed to get type and shape of the feed node.
731 if (!st.ok()) {
732 VLOG(3) << "Failed to infer feed node type and shape."
733 << " Add undefined feed for: " << node->name();
734 Tensor fake_input(DT_INVALID, {0});
735 item.feed.emplace_back(node->name(), fake_input);
736 continue;
737 }
738
739 // If the shape of the placeholder is only partially known, we are free
740 // to set unknown dimensions of its shape to any value we desire. We
741 // choose 0 to minimize the memory impact. Note that this only matters
742 // if an optimizer chooses to run the graph.
743 TensorShape shape;
744 if (partial_shape.unknown_rank()) {
745 shape = TensorShape({0});
746 } else {
747 for (int i = 0; i < partial_shape.dims(); ++i) {
748 if (partial_shape.dim_size(i) < 0) {
749 partial_shape.set_dim(i, 0);
750 }
751 }
752 if (!partial_shape.AsTensorShape(&shape)) {
753 return errors::InvalidArgument(
754 "Could not derive shape for feed node: ",
755 node->def().DebugString());
756 }
757 }
758
759 VLOG(3) << "Add feed for: " << node->name() << "; type: " << type
760 << "; shape: " << shape;
761 Tensor fake_input(type, shape);
762 item.feed.emplace_back(node->name(), fake_input);
763 }
764 }
765
766 // Validate that the feeds and fetches are valid.
767 if (node_names.empty()) {
768 // Collect all node names in the graph if we didn't already.
769 node_names.reserve(graph.num_nodes());
770 for (const Node* node : graph.nodes()) {
771 node_names.insert(node->name());
772 }
773 }
774 for (const auto& feed : item.feed) {
775 SafeTensorId tensor_id = ParseTensorName(feed.first);
776 if (node_names.find(tensor_id.node()) == node_names.end()) {
777 return errors::InvalidArgument("Invalid feed, no such node in graph: ",
778 feed.first);
779 }
780 }
781 for (const auto& fetch : item.fetch) {
782 SafeTensorId tensor_id = ParseTensorName(fetch);
783 if (node_names.find(tensor_id.node()) == node_names.end()) {
784 return errors::InvalidArgument("Invalid fetch, no such node in graph: ",
785 fetch);
786 }
787 }
788
789 // Convert Graph to GraphDef and add it to the GrapplerItem.
790 graph.ToGraphDef(&item.graph);
791 // TODO(b/114748242): Add a unit test to test this bug fix.
792 if (flib_def) {
793 *item.graph.mutable_library() = flib_def->ToProto();
794 }
795
796 // Construct a virtual cluster and find the cpu_device, which the
797 // ConstantFolding optimizer will use for partial evaluation of the graph.
798 grappler::VirtualCluster cluster(device_set_);
799 Device* cpu_device = nullptr;
800 for (const auto& device : device_set_->devices()) {
801 if (device->parsed_name().id == 0 &&
802 StringPiece(device->parsed_name().type) == "CPU" &&
803 device->GetAllocator(AllocatorAttributes()) != nullptr) {
804 cpu_device = device;
805 }
806 }
807
808 // Now we can run the MetaOptimizer on the constructed GrapplerItem.
809 GraphDef new_graph;
810 TF_RETURN_IF_ERROR(
811 grappler::RunMetaOptimizer(std::move(item), session_options_->config,
812 cpu_device, &cluster, &new_graph));
813
814 // Merge optimized graph function library with an original library.
815 // Optimized graph might have new functions specialized for it's
816 // instantiation context (see Grappler function optimizer), and modified
817 // function body for the existing functions.
818 optimized_flib->reset(new FunctionLibraryDefinition(*flib_def));
819
820 for (const FunctionDef& fdef : new_graph.library().function()) {
821 const string& func_name = fdef.signature().name();
822
823 if ((*optimized_flib)->Contains(func_name)) {
824 VLOG(3) << "Replace function: name=" << func_name;
825 TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef));
826 } else {
827 VLOG(3) << "Add new function: name=" << func_name;
828 TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
829 }
830 }
831 optimized_graph->reset(new Graph(OpRegistry::Global()));
832
833 // Convert the optimized GraphDef back to a Graph.
834 GraphConstructorOptions opts;
835 opts.allow_internal_ops = true;
836 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(new_graph),
837 optimized_graph->get()));
838 // The graph conversion sets the requested device names but not the
839 // assigned device names. However, since at this point the graph is placed
840 // TF expects an assigned device name for every node. Therefore we copy
841 // the requested device into the assigned device field.
842 for (Node* node : optimized_graph->get()->nodes()) {
843 node->set_assigned_device_name(node->requested_device());
844 }
845 return OkStatus();
846 } else {
847 return errors::InvalidArgument("Meta Optimizer disabled");
848 }
849#endif // IS_MOBILE_PLATFORM
850}
851
852Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
853 std::unique_ptr<ClientGraph>* out) {
854 VLOG(1) << "BuildGraph";
855 const uint64 start_time_usecs = Env::Default()->NowMicros();
856 if (!graph_) {
857 // It is only valid to call this method directly when the original graph
858 // was created with the option `place_pruned_graph == false`.
859 return errors::Internal(
860 "Attempted to prune a graph that has not been fully initialized.");
861 }
862
863 // Grappler optimization might change the structure of a graph itself, and
864 // also it can add/prune functions to/from the library.
865 std::unique_ptr<Graph> optimized_graph;
866 std::unique_ptr<FunctionLibraryDefinition> optimized_flib;
867
868 Status s = OptimizeGraph(options, *graph_, flib_def_.get(), &optimized_graph,
869 &optimized_flib);
870 if (!s.ok()) {
871 VLOG(2) << "Grappler optimization failed. Error: " << s.error_message();
872 // Simply copy the original graph and the function library if we couldn't
873 // optimize it.
874 optimized_graph.reset(new Graph(flib_def_.get()));
875 CopyGraph(*graph_, optimized_graph.get());
876 optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
877 }
878
879 subgraph::RewriteGraphMetadata rewrite_metadata;
880 if (session_options_ == nullptr ||
881 !session_options_->config.graph_options().place_pruned_graph()) {
882 TF_RETURN_IF_ERROR(
883 PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
884 } else {
885 // This GraphExecutionState represents a graph that was
886 // pruned when this was constructed, so we copy the metadata from
887 // a member variable.
888 CHECK(rewrite_metadata_);
889 rewrite_metadata = *rewrite_metadata_;
890 }
891
892 CHECK_EQ(options.callable_options.feed_size(),
893 rewrite_metadata.feed_types.size());
894 CHECK_EQ(options.callable_options.fetch_size(),
895 rewrite_metadata.fetch_types.size());
896
897 // TODO(andydavis): Clarify optimization pass requirements around CostModel.
898 GraphOptimizationPassOptions optimization_options;
899 optimization_options.session_options = session_options_;
900 optimization_options.graph = &optimized_graph;
901 optimization_options.flib_def = optimized_flib.get();
902 optimization_options.device_set = device_set_;
903
904 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
905 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
906
907 int64_t collective_graph_key = options.collective_graph_key;
908 if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
909 // BuildGraphOptions does not specify a collective_graph_key. Check all
910 // nodes in the Graph and FunctionLibraryDefinition for collective ops and
911 // if found, initialize a collective_graph_key as a hash of the ordered set
912 // of instance keys.
913 std::set<int32> instance_key_set;
914 bool has_collective_v2 = false;
915 for (Node* node : optimized_graph->nodes()) {
916 if (node->IsCollective()) {
917 int32_t instance_key;
918 TF_RETURN_IF_ERROR(
919 GetNodeAttr(node->attrs(), "instance_key", &instance_key));
920 instance_key_set.emplace(instance_key);
921 } else if (IsCollectiveV2(node->type_string())) {
922 has_collective_v2 = true;
923 } else {
924 const FunctionDef* fdef = optimized_flib->Find(node->def().op());
925 if (fdef != nullptr) {
926 for (const NodeDef& ndef : fdef->node_def()) {
927 if (ndef.op() == "CollectiveReduce" ||
928 ndef.op() == "CollectiveBcastSend" ||
929 ndef.op() == "CollectiveBcastRecv" ||
930 ndef.op() == "CollectiveGather") {
931 int32_t instance_key;
932 TF_RETURN_IF_ERROR(
933 GetNodeAttr(ndef, "instance_key", &instance_key));
934 instance_key_set.emplace(instance_key);
935 } else if (IsCollectiveV2(ndef.op())) {
936 has_collective_v2 = true;
937 }
938 }
939 }
940 }
941 }
942 if (!instance_key_set.empty()) {
943 uint64 hash = 0x8774aa605c729c72ULL;
944 for (int32_t instance_key : instance_key_set) {
945 hash = Hash64Combine(instance_key, hash);
946 }
947 collective_graph_key = hash;
948 } else if (has_collective_v2) {
949 collective_graph_key = 0x8774aa605c729c72ULL;
950 }
951 }
952
953 // Make collective execution order deterministic if needed.
954 if (options.collective_order != GraphCollectiveOrder::kNone) {
955 TF_RETURN_IF_ERROR(
956 OrderCollectives(optimized_graph.get(), options.collective_order));
957 }
958
959 // Copy the extracted graph in order to make its node ids dense,
960 // since the local CostModel used to record its stats is sized by
961 // the largest node id.
962 std::unique_ptr<ClientGraph> dense_copy(
963 new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
964 rewrite_metadata.fetch_types, collective_graph_key));
965 CopyGraph(*optimized_graph, &dense_copy->graph);
966
967 // TODO(vrv): We should check invariants of the graph here.
968 metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs);
969 *out = std::move(dense_copy);
970 return OkStatus();
971}
972
973} // namespace tensorflow
974