1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/graph_constructor.h"
17
18#include <algorithm>
19#include <set>
20#include <string>
21#include <unordered_map>
22#include <unordered_set>
23#include <vector>
24
25#include "absl/algorithm/container.h"
26#include "absl/container/flat_hash_set.h"
27#include "absl/strings/str_cat.h"
28#include "absl/strings/string_view.h"
29#include "tensorflow/core/common_runtime/shape_refiner.h"
30#include "tensorflow/core/framework/function.h"
31#include "tensorflow/core/framework/function.pb.h"
32#include "tensorflow/core/framework/graph.pb.h"
33#include "tensorflow/core/framework/node_def.pb.h"
34#include "tensorflow/core/framework/node_def_util.h"
35#include "tensorflow/core/framework/tensor_shape.pb.h"
36#include "tensorflow/core/framework/types.h"
37#include "tensorflow/core/framework/versions.h"
38#include "tensorflow/core/framework/versions.pb.h"
39#include "tensorflow/core/graph/algorithm.h"
40#include "tensorflow/core/graph/graph.h"
41#include "tensorflow/core/graph/tensor_id.h"
42#include "tensorflow/core/lib/core/errors.h"
43#include "tensorflow/core/lib/gtl/flatmap.h"
44#include "tensorflow/core/lib/gtl/flatset.h"
45#include "tensorflow/core/lib/gtl/inlined_vector.h"
46#include "tensorflow/core/lib/strings/scanner.h"
47#include "tensorflow/core/lib/strings/str_util.h"
48#include "tensorflow/core/platform/errors.h"
49#include "tensorflow/core/platform/logging.h"
50#include "tensorflow/core/platform/macros.h"
51#include "tensorflow/core/public/version.h"
52
53namespace tensorflow {
54
55namespace {
56
// We remove duplicate control inputs before adding edges to the Graph, so we
// can skip expensive duplicates check in 'AddControlEdge'.
// NOTE: `constexpr` alone suffices here; `static` is redundant inside an
// anonymous namespace and `const` is implied by `constexpr`.
constexpr bool kDoNotCheckDuplicates = true;
60
61inline bool IsMerge(const NodeDef& node_def) {
62 return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
63 node_def.op() == "_XlaMerge";
64}
65
66inline bool IsNextIteration(const NodeDef& node_def) {
67 return node_def.op() == "NextIteration" ||
68 node_def.op() == "RefNextIteration";
69}
70
71bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
72 using ::tensorflow::strings::Scanner;
73 Scanner scanner(s);
74 scanner
75 .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
76 : Scanner::LETTER_DIGIT_DOT)
77 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
78
79 while (true) {
80 if (!scanner.GetResult()) // Some error in previous iteration.
81 return false;
82 if (scanner.empty()) // No error, but nothing left, good.
83 return true;
84
85 // Absorb another piece, starting with a '>'
86 scanner.One(Scanner::RANGLE)
87 .One(Scanner::LETTER_DIGIT_DOT)
88 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
89 }
90}
91
// Builds a Graph from a sequence of NodeDefs, implementing both
// ConvertGraphDefToGraph and ImportGraphDef semantics (selected via Options).
// Subclasses supply access to the underlying NodeDef storage; see
// NodeDefCopyingGraphConstructor and NodeDefMovingGraphConstructor below.
class GraphConstructor {
 public:
  struct Options {
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_nodes(in.validate_nodes),
          validate_colocation_constraints(false),
          add_default_attributes(in.add_default_attributes) {}
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          // Normalize the prefix so that, when non-empty, it always ends in
          // a single '/'.
          prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map.begin(), in.input_map.end()),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
          return_nodes(in.return_nodes),
          importing(true),
          validate_nodes(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape),
          default_device(in.default_device) {}

    bool allow_internal_ops;
    bool expect_device_spec;

    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    // If true, validates that nodes being converted have all expected attrs
    // set and no unknown attrs set by calling ValidateNodeDef().
    // `validate_nodes` is always true when `importing` is set.
    bool validate_nodes;
    bool validate_colocation_constraints;
    bool validate_shape = true;

    // If true, GraphConstructor will add attributes with their default
    // value to the Node when they are missing from the NodeDef.
    bool add_default_attributes = true;

    string default_device;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

  static Status Construct(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys);

 protected:
  GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        g_(g),
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  virtual ~GraphConstructor() {}

  // Runs the full import pipeline. The stages below are order-dependent;
  // on failure the caller (Construct()) invokes Undo() to roll back `g_`.
  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());

    // NOTE: Convert() invokes `consume_node_def()` on each node in the input
    // graph, so `get_node_def()` is no longer usable once it is called.
    TF_RETURN_IF_ERROR(Convert());

    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return OkStatus();
  }

 private:
  // Pipeline stages, in the order TryImport() runs them.
  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  // Rolls back changes made to `g_`; called from Construct() when TryImport()
  // fails.
  void Undo();

  // Prints cycles in the graph.
  void PrintCycles();
  // Performs DFS starting at `cur_node` and prints any cycles found.
  void DFS(int cur_node, std::vector<int>* cur_branch,
           std::vector<bool>* is_on_cur_branch,
           absl::flat_hash_set<int>* unvisited,
           const std::vector<absl::string_view>& node_names);
  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(NodeDef&& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will mark inputs that are remapped to
  // true.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function will add and mark control inputs as
  // true.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // Decrement pending count for users of `processed` and add the ones that now
  // have all of their pending inputs satisfied to `ready_`.
  void UpdatePendingCountAndReady(int processed, bool is_next_iteration);

  // Subclasses override the following virtual methods to provide efficient
  // access to the original protocol buffer-based graph.

  // Returns the number of nodes in the graph.
  virtual size_t node_def_count() const = 0;
  // Returns the i^th node in the graph. Must not be called after
  // consume_node_def(i).
  virtual const NodeDef& get_node_def(int i) const = 0;
  // Destructively reads the i^th node in the graph, avoiding a copy if
  // possible. After calling this method, the result of get_node_def(i) is
  // undefined.
  virtual NodeDef consume_node_def(int i) = 0;
  // Returns the version information for the graph, or nullptr if none is
  // available.
  virtual const VersionDef* versions() const = 0;
  // Returns the function information for the graph, or nullptr if none is
  // available.
  virtual const FunctionDefLibrary* library() const = 0;

  // From constructor
  const Options opts_;
  Graph* g_;
  // Snapshot of g_'s version info taken before the import touches it.
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<SafeTensorId>* missing_unused_input_map_keys_;

  // Intermediate datastructure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Intermediate datastructure used to track the destinations of back edges.
  absl::flat_hash_set<int> merge_node_indices_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // Containers require that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  absl::flat_hash_map<std::string, NodeInfo> gdef_nodes_;

  // Prefixes already used in the GraphDef being imported.
  absl::flat_hash_set<StringPiece> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  absl::flat_hash_map<StringPiece, Node*> existing_nodes_;

  // Prefixes already used in the graph.
  absl::flat_hash_set<StringPiece> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  gtl::FlatMap<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted. We use a
  // (sorted) set so nodes are created in the order defined in the GraphDef.
  std::set<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within node_defs_ and the index within node_defs_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;

    static bool IsControlInput(const InputInfo& input) {
      return input.index == Graph::kControlSlot;
    }
    static int CompareName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name < rhs.name;
    }
    static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) {
      return lhs.name == rhs.name;
    }
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  std::vector<EdgeInfo> back_edges_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor);
};
392
// Implementation of GraphConstructor that does not take ownership of the
// input NodeDef messages and thus copies the nodes into the constructed Graph*.
//
// NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which
// avoids copying each NodeDef into the constructed Graph*.
class NodeDefCopyingGraphConstructor : public GraphConstructor {
 public:
  NodeDefCopyingGraphConstructor(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        node_defs_(node_defs),
        versions_(versions),
        library_(library) {}

 private:
  size_t node_def_count() const override { return node_defs_.size(); }
  const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; }
  // "Consuming" a borrowed NodeDef still requires a copy, since the slice's
  // elements are const and owned by the caller.
  NodeDef consume_node_def(int i) override { return *node_defs_[i]; }
  const VersionDef* versions() const override { return versions_; }
  const FunctionDefLibrary* library() const override { return library_; }

  // Borrowed from the caller; `versions_` and `library_` may be nullptr.
  const NodeDefSlice node_defs_;
  const VersionDef* const versions_;
  const FunctionDefLibrary* const library_;
};
423
// Implementation of GraphConstructor that takes ownership of the input
// GraphDef, and can perform destructive reads.
class NodeDefMovingGraphConstructor : public GraphConstructor {
 public:
  NodeDefMovingGraphConstructor(
      const Options& opts, GraphDef&& graph_def, Graph* g,
      ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<SafeTensorId>* missing_unused_input_map_keys)
      : GraphConstructor(opts, g, refiner, return_tensors, return_nodes,
                         missing_unused_input_map_keys),
        graph_def_(std::move(graph_def)),
        // NOTE: members are initialized in declaration order, so `graph_def_`
        // is fully constructed before `node_size()` is read here, even though
        // the `graph_def` parameter has already been moved from.
        is_consumed_(graph_def_.node_size(), false) {}

 private:
  size_t node_def_count() const override { return graph_def_.node().size(); }
  const NodeDef& get_node_def(int i) const override {
    // Reading a consumed NodeDef would observe a moved-from message.
    CHECK(!is_consumed_[i])
        << "NodeDef " << i << " accessed after it was consumed.";
    return graph_def_.node(i);
  }
  NodeDef consume_node_def(int i) override {
    CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice.";
    is_consumed_[i] = true;
    return std::move(*graph_def_.mutable_node(i));
  }
  const VersionDef* versions() const override { return &graph_def_.versions(); }
  const FunctionDefLibrary* library() const override {
    return &graph_def_.library();
  }

  GraphDef graph_def_;
  // is_consumed_[i] is true iff consume_node_def(i) has been called.
  std::vector<bool> is_consumed_;
};
458
459bool ForwardCompatibilityWindowPassed(const VersionDef& versions) {
460 // TF_GRAPH_DEF_VERSION is incremented daily.
461 // TF has a 3 week forward compatibility guarantee.
462 return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21;
463}
464
465Status MaybeAppendVersionWarning(const VersionDef* versions,
466 const Status& import_status) {
467 if (versions && ForwardCompatibilityWindowPassed(*versions)) {
468 return Status(
469 import_status.code(),
470 absl::StrCat(
471 "Converting GraphDef to Graph has failed with an error: '",
472 import_status.error_message(),
473 "' The binary trying to import the GraphDef was built when "
474 "GraphDef version was ",
475 TF_GRAPH_DEF_VERSION,
476 ". The GraphDef was produced by a binary built when GraphDef "
477 "version was ",
478 versions->producer(),
479 ". The difference between these versions is larger than "
480 "TensorFlow's forward compatibility guarantee, and might be the "
481 "root cause for failing to import the GraphDef."));
482 }
483 return import_status;
484}
485
486/* static */ Status GraphConstructor::Construct(
487 const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
488 const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
489 std::vector<std::pair<Node*, int>>* return_tensors,
490 std::vector<Node*>* return_nodes,
491 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
492 if (versions) {
493 TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
494 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
495 "GraphDef", "graph"));
496 }
497 NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
498 refiner, return_tensors, return_nodes,
499 missing_unused_input_map_keys);
500 Status s = c.TryImport();
501 if (!s.ok()) {
502 c.Undo();
503 s = MaybeAppendVersionWarning(versions, s);
504 }
505 return s;
506}
507
508/* static */ Status GraphConstructor::Construct(
509 const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner,
510 std::vector<std::pair<Node*, int>>* return_tensors,
511 std::vector<Node*>* return_nodes,
512 std::vector<SafeTensorId>* missing_unused_input_map_keys) {
513 TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
514 TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
515 "GraphDef", "graph"));
516 VersionDef version_def = graph_def.versions();
517 NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
518 return_tensors, return_nodes,
519 missing_unused_input_map_keys);
520 Status s = c.TryImport();
521 if (!s.ok()) {
522 c.Undo();
523 s = MaybeAppendVersionWarning(&version_def, s);
524 }
525 return s;
526}
527
528void GraphConstructor::UpdatePendingCountAndReady(int processed,
529 bool is_next_iteration) {
530 for (size_t i = 0; i < outputs_[processed].size(); ++i) {
531 const int output = outputs_[processed][i];
532 // We didn't consider NextIteration->Merge edges when computing
533 // pending_counts_ so we should not have to consider it here either.
534 bool is_next_iteration_to_merge_edge =
535 is_next_iteration && merge_node_indices_.count(output) == 1;
536 if (!is_next_iteration_to_merge_edge) {
537 int* current_pending_count = &pending_count_[output];
538 CHECK_GT(*current_pending_count, 0);
539 (*current_pending_count)--;
540 if (*current_pending_count == 0) {
541 ready_.insert(output);
542 }
543 }
544 }
545}
546
547// This could be expensive but we don't expect to call it often, if at all (only
548// if there are multiple nodes in g_ with the same name)
549bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
550 const StringPiece& node_name) {
551 for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
552 if (iter->second.first == node_name) return true;
553 }
554 return false;
555}
556
557bool NodeNameInValues(const std::vector<string>& control_dependencies,
558 const StringPiece& node_name) {
559 return std::find(control_dependencies.begin(), control_dependencies.end(),
560 node_name) != control_dependencies.end();
561}
562
563// Adds any prefixes of `node_name` (not including the full name itself) to
564// `prefixes`.
565void AddPrefixes(StringPiece node_name,
566 absl::flat_hash_set<StringPiece>* prefixes) {
567 size_t idx = -1;
568 while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
569 prefixes->insert(node_name.substr(0, idx));
570 }
571}
572
573Status GraphConstructor::EnsureNoNameCollisions() {
574 existing_nodes_.reserve(g_->num_nodes());
575 // Populate existing_nodes_ and existing_prefixes_.
576 for (Node* n : g_->nodes()) {
577 bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
578 if (already_exists) {
579 if (NodeNameInValues(opts_.input_map, n->name())) {
580 return errors::InvalidArgument(
581 "cannot resolve input_map because multiple nodes exist with name '",
582 n->name(), "'");
583 }
584 if (NodeNameInValues(opts_.control_dependencies, n->name())) {
585 return errors::InvalidArgument(
586 "cannot resolve control_dependencies because multiple nodes exist "
587 "with name '",
588 n->name(), "'");
589 }
590 }
591 AddPrefixes(n->name(), &existing_prefixes_);
592 }
593 if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
594 for (size_t i = 0; i < node_def_count(); ++i) {
595 const string& name = get_node_def(i).name();
596 if (NameExistsInGraph(name)) {
597 return errors::InvalidArgument("Node name '", name,
598 "' already exists in the Graph");
599 }
600 }
601 } else if (!prefix_.empty()) {
602 StringPiece prefix_no_slash(prefix_);
603 prefix_no_slash.remove_suffix(1);
604 if (!IsValidNodeName(prefix_no_slash, false)) {
605 return errors::InvalidArgument("Imported node name prefix '", prefix_,
606 "' would lead to invalid node names");
607 }
608 if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
609 prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
610 }
611 }
612 return OkStatus();
613}
614
615Status GraphConstructor::ValidateInputMapAndControlDependencies() {
616 for (const auto& mapping : opts_.input_map) {
617 TensorId src = mapping.first;
618 TensorId dst = mapping.second;
619 if (existing_nodes_.count(dst.first) == 0) {
620 return errors::InvalidArgument(
621 "node '", dst.first, "' in input_map does not exist in graph ",
622 "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
623 }
624 if ((src.second == Graph::kControlSlot) !=
625 (dst.second == Graph::kControlSlot)) {
626 return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
627 dst.ToString(), " between ",
628 "control edge and non-control edge");
629 }
630 }
631 for (const string& node : opts_.control_dependencies) {
632 if (existing_nodes_.count(node) == 0) {
633 return errors::InvalidArgument(
634 "node '", node,
635 "' in control_dependencies does not exist in "
636 "graph");
637 }
638 }
639 return OkStatus();
640}
641
// Indexes every NodeDef by name into gdef_nodes_ (and records its name
// prefixes in gdef_prefixes_), validating names, op types, device specs,
// Merge membership, and control-input ordering along the way.
Status GraphConstructor::BuildNodeIndex() {
  // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "': Node name contains invalid characters");
    }
    if (!gdef_nodes_.insert(std::make_pair(node_def.name(), NodeInfo(n)))
             .second) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is not unique");
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' does not specify an operation");
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is missing a device specification");
    }
    if (IsMerge(node_def)) {
      // Record Merge nodes; they receive special pending-count handling for
      // while-loop back edges (see InitFromEdges()).
      merge_node_indices_.insert(n);
    }
    // Validate control edges at end
    bool in_control_dependence = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      // Control inputs are spelled "^node_name".
      if (!input_name.empty() && absl::StartsWith(input_name, "^")) {
        in_control_dependence = true;
      } else if (in_control_dependence) {
        return errors::InvalidArgument(
            "Node '", node_def.name(),
            "': Control dependencies must come after regular dependencies");
      }
    }
    // Update gdef_prefixes_.
    AddPrefixes(node_def.name(), &gdef_prefixes_);
  }
  return OkStatus();
}
685
// Computes, for every NodeDef, how many of its inputs must be converted
// before the node itself can be converted (pending_count_), records the
// input->consumer adjacency in outputs_, and seeds ready_ with nodes that
// have no unconverted inputs. Merge nodes fed by a NextIteration node get a
// reduced pending count so that while-loop cycles do not stall conversion.
Status GraphConstructor::InitFromEdges() {
  const int num_nodes = node_def_count();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  // First pass: collect NextIteration node names so back edges into Merge
  // nodes can be recognized below.
  gtl::FlatSet<string> next_iteration_nodes;
  for (int n = 0; n < node_def_count(); ++n) {
    const NodeDef& node_def = get_node_def(n);
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = get_node_def(n);
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32_t num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (absl::StartsWith(input_name, "^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes.find(string(id.first)) !=
              next_iteration_nodes.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.insert(n);
    }
    pending_count_.push_back(pending_count);
  }
  return OkStatus();
}
752
753Status GraphConstructor::ValidateColocationConstraints(
754 const NodeDef& node_def) {
755 if (!opts_.validate_colocation_constraints || !opts_.importing)
756 return OkStatus();
757 const auto iter = node_def.attr().find(kColocationAttrName);
758 if (iter == node_def.attr().end()) return OkStatus();
759 for (const string& c : iter->second.list().s()) {
760 StringPiece s(c);
761 if (absl::ConsumePrefix(&s, kColocationGroupPrefix) &&
762 gdef_nodes_.find(s) == gdef_nodes_.end()) {
763 return errors::InvalidArgument(
764 "Node '", node_def.name(),
765 "' expects to be colocated with unknown node '", s, "'");
766 }
767 }
768 return OkStatus();
769}
770
771Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) {
772 // Add the node to the graph.
773 Status status;
774 *node = g_->AddNode(std::move(node_def), &status);
775 if (!status.ok()) return status;
776 if (opts_.expect_device_spec) {
777 (*node)->set_assigned_device_name((*node)->def().device());
778 }
779 return OkStatus();
780}
781
782Status GraphConstructor::ValidateShape(Node* node) {
783 if (!opts_.importing || !opts_.validate_shape) return OkStatus();
784 TF_RETURN_IF_ERROR(refiner_->AddNode(node));
785 // For nodes with the _output_shapes attribute, override the shape.
786 std::vector<const TensorShapeProto*> shape_attrs;
787 const char* kAttrName = "_output_shapes";
788 if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) {
789 // No _output_shapes attribute, the AddNode call above was sufficient.
790 return OkStatus();
791 }
792 auto* ic = refiner_->GetContext(node);
793 DCHECK(ic != nullptr)
794 << "ShapeRefiner::AddNode() should have created the InferenceContext";
795 if (shape_attrs.size() < node->num_outputs()) {
796 return errors::InvalidArgument(
797 "Node '", node->name(), "' has ", node->num_outputs(),
798 " outputs but the ", kAttrName, " attribute specifies shapes for ",
799 shape_attrs.size(), " outputs");
800 }
801 // NOTE(skyewm): we don't raise an error here because some users depend on
802 // this behavior, even though it's unsafe.
803 // TODO(b/74619486): raise an error.
804 if (shape_attrs.size() > node->num_outputs()) {
805 LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs()
806 << " outputs but the " << kAttrName
807 << " attribute specifies shapes for " << shape_attrs.size()
808 << " outputs. Output shapes may be inaccurate.";
809 }
810 for (int i = 0; i < node->num_outputs(); ++i) {
811 const TensorShapeProto& p = *shape_attrs[i];
812 shape_inference::ShapeHandle h;
813 Status s = ic->MakeShapeFromShapeProto(p, &h);
814 if (!s.ok()) {
815 return errors::InvalidArgument("Node '", node->name(), " has an invalid ",
816 kAttrName, " attribute (shape #", i,
817 " error:'", s.error_message(), "'");
818 }
819 s = refiner_->SetShape(node, i, h);
820 if (!s.ok()) {
821 return errors::InvalidArgument(
822 "Node '", node->name(), "' has an ", kAttrName,
823 " attribute inconsistent with the GraphDef for output #", i, ": ",
824 s.error_message());
825 }
826 }
827 node->ClearAttr(kAttrName);
828 return OkStatus();
829}
830
831Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
832 const OpDef* op_def;
833 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
834 AddDefaultsToNodeDef(*op_def, node_def);
835 TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
836 if (versions()) {
837 TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer()));
838 }
839 return OkStatus();
840}
841
842void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
843 std::vector<bool>* input_already_exists) {
844 // Remove 'inputs_to_remove' from 'node_def'
845 NodeDef copy;
846 copy.mutable_input()->Reserve(node_def->input_size() -
847 inputs_to_remove.size());
848 for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
849 if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
850 ++j;
851 } else {
852 copy.add_input()->swap(*node_def->mutable_input(i));
853 }
854 }
855 node_def->mutable_input()->Swap(copy.mutable_input());
856 // Remove 'inputs_to_remove' from 'input_already_exists'
857 for (int idx : inputs_to_remove) {
858 input_already_exists->erase(input_already_exists->begin() + idx);
859 }
860 DCHECK_EQ(input_already_exists->size(), node_def->input_size());
861}
862
863void GraphConstructor::RemapNodeDefInputs(
864 NodeDef* node_def, std::vector<bool>* input_already_exists) {
865 DCHECK_EQ(input_already_exists->size(), node_def->input_size());
866 std::set<TensorId> control_inputs;
867 std::vector<int> inputs_to_remove;
868
869 for (int i = 0; i < node_def->input_size(); ++i) {
870 auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
871 if (iter == opts_.input_map.end()) continue;
872 used_input_map_keys_.insert(iter->first);
873
874 TensorId new_input = iter->second;
875 if (new_input.second == Graph::kControlSlot) {
876 // Check if we've already remapped a different input to new_input, and if
877 // so remove this input.
878 if (control_inputs.count(new_input) > 0) {
879 inputs_to_remove.push_back(i);
880 continue;
881 }
882 control_inputs.insert(new_input);
883 }
884 node_def->set_input(i, new_input.ToString());
885 (*input_already_exists)[i] = true;
886 }
887 if (!inputs_to_remove.empty()) {
888 RemoveInputs(inputs_to_remove, node_def, input_already_exists);
889 }
890}
891
// Appends opts_.control_dependencies as control inputs of `node_def`, unless
// the node will (transitively) inherit them through another imported node.
// `input_already_exists` is extended with `true` for each appended input.
void GraphConstructor::AddControlDependencies(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  // To avoid adding redundant control dependencies to every imported node, skip
  // nodes that will inherit the dependencies from another imported node.
  bool inherits_deps = false;
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Assume we won't inherit dependencies from remapped inputs that already
    // exist in the graph. Even if we're wrong, we'll only add redundant
    // dependencies.
    if ((*input_already_exists)[i]) continue;

    // If this input is a backedge, assume we won't inherit the dependencies.
    // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
    // worth optimizing these.
    TensorId id(ParseTensorName(node_def->input(i)));
    auto iter = gdef_nodes_.find(id.first);
    DCHECK(iter != gdef_nodes_.end()) << id.first;
    if (iter->second.node == nullptr) {
      // Input hasn't been created yet, indicating it's a backedge.
      continue;
    }
    // At least one input is another imported (already-created) node, which
    // received these control dependencies itself when it was imported.
    inherits_deps = true;
  }
  if (inherits_deps) return;

  // node_def either has no inputs or all remapped inputs, add the control
  // dependencies
  for (const string& control_dep : opts_.control_dependencies) {
    string input = TensorId(control_dep, Graph::kControlSlot).ToString();
    bool found = false;
    // Scan backwards from the end, where control inputs ("^name") live.
    for (int i = node_def->input_size() - 1; i >= 0; --i) {
      const string& node_input = node_def->input(i);
      if (node_input[0] != '^') {
        // Control inputs are at the end. Break when we reach the non-control
        // inputs.
        break;
      }
      if (node_input == input) {
        // Control dependency already exists
        found = true;
        break;
      }
    }
    if (found) {
      continue;
    }
    // Append the new control input and keep input_already_exists aligned;
    // the dependency refers to a node that already exists in g_.
    node_def->add_input(input);
    input_already_exists->push_back(true);
  }
}
942
943void GraphConstructor::AddPrefixToNodeDef(
944 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
945 if (prefix_.empty()) return;
946 node_def->set_name(strings::StrCat(prefix_, node_def->name()));
947 // Update names of input nodes
948 for (int i = 0; i < node_def->input_size(); ++i) {
949 // Skip remapped inputs (which already exist in g_ and are not being
950 // imported).
951 if (input_already_exists[i]) continue;
952 StringPiece input(node_def->input(i));
953 if (absl::ConsumePrefix(&input, "^")) {
954 node_def->set_input(i, strings::StrCat("^", prefix_, input));
955 } else {
956 node_def->set_input(i, strings::StrCat(prefix_, input));
957 }
958 }
959 // Update names of colocation groups
960 if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
961 auto* list =
962 node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
963 for (int i = 0; i < list->s_size(); ++i) {
964 StringPiece v(list->s(i));
965 if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) {
966 list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
967 }
968 }
969 }
970}
971
972void GraphConstructor::UniquifyNames(
973 const std::vector<bool>& input_already_exists, NodeDef* node_def) {
974 if (NameExistsInGraph(node_def->name())) {
975 string old_name = node_def->name();
976 node_def->set_name(FindUniqueName(node_def->name()));
977 uniquified_names_[old_name] = node_def->name();
978 // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
979 // `name` because we guarantee the original NodeDef names are unique,
980 // meaning we won't generate this name again.
981 }
982 for (int i = 0; i < node_def->input_size(); ++i) {
983 // Skip remapped inputs (which already exist in g_ and are not being
984 // imported).
985 if (input_already_exists[i]) continue;
986 TensorId id = ParseTensorName(node_def->input(i));
987 // We require that UniquifyNames() is called on all NodeDefs in topological
988 // order. This guarantees that node_def's inputs will already be uniquified
989 // if necessary.
990 auto iter = uniquified_names_.find(string(id.first));
991 if (iter == uniquified_names_.end()) continue;
992 id.first = iter->second;
993 node_def->set_input(i, id.ToString());
994 }
995}
996
997void GraphConstructor::UpdateUniquifiedColocationNames() {
998 for (const auto& pair : gdef_nodes_) {
999 Node* node = pair.second.node;
1000 if (node == nullptr) continue;
1001 std::vector<string> coloc_values;
1002 if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values))
1003 continue;
1004 bool updated = false;
1005 for (size_t i = 0; i < coloc_values.size(); ++i) {
1006 StringPiece val(coloc_values[i]);
1007 if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) {
1008 auto name_pair = uniquified_names_.find(string(val));
1009 if (name_pair == uniquified_names_.end()) continue;
1010 updated = true;
1011 coloc_values[i] =
1012 strings::StrCat(kColocationGroupPrefix, name_pair->second);
1013 }
1014 }
1015 if (updated) {
1016 node->AddAttr(kColocationAttrName, std::move(coloc_values));
1017 }
1018 }
1019}
1020
1021bool GraphConstructor::NameExistsInGraph(StringPiece name) {
1022 if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
1023 if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
1024 return false;
1025}
1026
1027bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
1028 if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
1029 if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
1030 return false;
1031}
1032
1033string GraphConstructor::FindUniqueName(StringPiece original_name) {
1034 string name(original_name);
1035 int count = 0;
1036 // Check that any generated names don't collide with imported NodeDefs (as
1037 // well as nodes in g_).
1038 while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
1039 name = strings::StrCat(original_name, "_", ++count);
1040 }
1041 return name;
1042}
1043
1044Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
1045 bool* is_node_mapped) {
1046 const OpDef* op_def;
1047 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
1048 for (int i = 0; i < op_def->output_arg_size(); ++i) {
1049 if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
1050 *is_node_mapped = false;
1051 return OkStatus();
1052 }
1053 }
1054 *is_node_mapped = true;
1055 return OkStatus();
1056}
1057
// Depth-first traversal over outputs_ that logs any cycle it encounters.
// `cur_branch` holds the nodes on the current DFS path in traversal order;
// `is_on_cur_branch` mirrors it for O(1) membership tests; `unvisited` is
// shared across calls so each node is fully explored only once;
// `node_names` maps gdef indices to names for logging.
void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch,
                           std::vector<bool>* is_on_cur_branch,
                           absl::flat_hash_set<int>* unvisited,
                           const std::vector<absl::string_view>& node_names) {
  cur_branch->push_back(cur_node);
  is_on_cur_branch->at(cur_node) = true;
  for (auto next_node : outputs_[cur_node]) {
    if (unvisited->find(next_node) != unvisited->end()) {
      if (is_on_cur_branch->at(next_node)) {
        // next_node is an ancestor on the current path: we found a cycle.
        // Log every node from that ancestor down to the current node.
        auto iter =
            std::find(cur_branch->begin(), cur_branch->end(), next_node);
        LOG(WARNING) << "Cycle detected:";
        while (iter != cur_branch->end()) {
          const absl::string_view name = node_names[*iter];
          DCHECK(!name.empty());
          LOG(WARNING) << "node id=" << *iter << ", name=" << name;
          ++iter;
        }
        LOG(WARNING) << "End of cycle";
      } else {
        DFS(next_node, cur_branch, is_on_cur_branch, unvisited, node_names);
      }
    }
  }
  // Backtrack: pop cur_node from the path and mark it fully visited.
  cur_branch->pop_back();
  is_on_cur_branch->at(cur_node) = false;
  unvisited->erase(cur_node);
}
1086
1087void GraphConstructor::PrintCycles() {
1088 int num_nodes = outputs_.size();
1089
1090 std::vector<absl::string_view> node_names;
1091 node_names.resize(num_nodes);
1092 for (const auto& named_node : gdef_nodes_) {
1093 DCHECK_GE(named_node.second.gdef_index, 0);
1094 DCHECK_LT(named_node.second.gdef_index, num_nodes);
1095 node_names[named_node.second.gdef_index] = named_node.first;
1096 }
1097
1098 absl::flat_hash_set<int> unvisited;
1099 for (int i = 0; i < num_nodes; i++) {
1100 unvisited.insert(i);
1101 }
1102
1103 while (!unvisited.empty()) {
1104 int cur_node = *unvisited.begin();
1105 // Nodes on the current branch of DFS in traversal order. This is used for
1106 // printing the nodes in the cycle.
1107 std::vector<int> cur_branch;
1108 // This is just to make lookups O(1).
1109 // is_on_cur_branch[i] ==
1110 // (std::find(cur_branch.start(),
1111 // cur_branch.end(), i) != cur_branch.end())
1112 std::vector<bool> is_on_cur_branch(num_nodes, false);
1113 DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited, node_names);
1114 }
1115}
1116
// Core conversion loop: processes the NodeDefs in topological order,
// applying import-time rewrites (input remapping, control dependencies,
// prefixing, uniquification), creating a Node for each NodeDef, and wiring
// up edges. Back edges (data inputs whose producer does not exist yet) are
// deferred to AddBackEdges(). Fails if any nodes remain unprocessed, which
// indicates a cycle in the GraphDef.
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_counts_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_). Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    // Save the original name: the NodeDef may be renamed below, but
    // gdef_nodes_ is keyed by the original names.
    std::string node_name = node_def.name();

    if (opts_.importing) {
      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve each input to its source node, looking in the imported nodes
    // or the preexisting graph depending on input_already_exists[i].
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexistng node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    gdef_nodes_[node_name].node = node;

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  // If not every NodeDef was processed, the remaining ones form one or more
  // cycles; log them before failing.
  if (processed < node_def_count()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64_t i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return OkStatus();
}
1298
1299Status GraphConstructor::AddBackEdges() {
1300 // Add the back edges after all nodes are created.
1301 for (const auto& e : back_edges_) {
1302 Node* src_node = gdef_nodes_[e.src_name].node;
1303 if (e.src_index == Graph::kControlSlot) {
1304 g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates);
1305 } else {
1306 TF_RETURN_IF_ERROR(
1307 MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
1308 }
1309
1310 VLOG(2) << "Add back edge: " << src_node->name() << " -> "
1311 << e.dst_node->name();
1312 }
1313 return OkStatus();
1314}
1315
1316Status GraphConstructor::UpdateVersionDef() {
1317 if (versions() == nullptr) return OkStatus();
1318
1319 if (!opts_.importing) {
1320 g_->set_versions(*versions());
1321 return OkStatus();
1322 }
1323 VersionDef g_versions = g_->versions();
1324 g_versions.set_producer(
1325 std::min(g_versions.producer(), versions()->producer()));
1326 g_versions.set_min_consumer(
1327 std::max(g_versions.min_consumer(), versions()->min_consumer()));
1328 if (versions()->bad_consumers_size() > 0) {
1329 std::set<int> bad(g_versions.bad_consumers().begin(),
1330 g_versions.bad_consumers().end());
1331 bad.insert(versions()->bad_consumers().begin(),
1332 versions()->bad_consumers().end());
1333 g_versions.clear_bad_consumers();
1334 for (int v : bad) {
1335 g_versions.add_bad_consumers(v);
1336 }
1337 }
1338 g_->set_versions(g_versions);
1339 return OkStatus();
1340}
1341
// Resolves opts_.return_tensors into (Node*, index) pairs appended to
// return_tensors_. A requested tensor may resolve either to an imported node
// or, if it was remapped via opts_.input_map, to a preexisting node in g_.
Status GraphConstructor::PopulateReturnTensors() {
  if (opts_.return_tensors.empty()) return OkStatus();
  for (const TensorId& id : opts_.return_tensors) {
    auto iter = opts_.input_map.find(id);
    if (iter == opts_.input_map.end()) {
      // Locate id in imported nodes
      // (this inner `iter` intentionally shadows the input_map iterator).
      auto iter = gdef_nodes_.find(id.first);
      if (iter == gdef_nodes_.end()) {
        return errors::InvalidArgument("Requested return tensor '",
                                       id.ToString(),
                                       "' not found in graph def");
      }
      // Validate the output index; kControlSlot (-1) is explicitly allowed.
      int num_outputs = iter->second.node->num_outputs();
      if ((id.second < 0 || id.second >= num_outputs) &&
          id.second != Graph::kControlSlot) {
        return errors::InvalidArgument("Invalid return output ", id.second,
                                       " of node '", id.first, "', which has ",
                                       num_outputs, " output(s)");
      }
      return_tensors_->push_back({iter->second.node, id.second});
    } else {
      // id was remapped to existing node
      TensorId remapped_id = iter->second;
      DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
      Node* node = existing_nodes_[remapped_id.first];
      return_tensors_->push_back({node, remapped_id.second});
    }
  }
  return OkStatus();
}
1372
1373Status GraphConstructor::PopulateReturnNodes() {
1374 if (opts_.return_nodes.empty()) return OkStatus();
1375 for (StringPiece name : opts_.return_nodes) {
1376 auto iter = gdef_nodes_.find(name);
1377 if (iter == gdef_nodes_.end()) {
1378 return errors::InvalidArgument("Requested return node '", name,
1379 "' not found in graph def");
1380 }
1381 return_nodes_->push_back(iter->second.node);
1382 }
1383 return OkStatus();
1384}
1385
// Records, in missing_unused_input_map_keys_, every opts_.input_map key that
// was never used during the import because it refers to a node or output
// index that does not exist in the GraphDef.
Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
  if (missing_unused_input_map_keys_ == nullptr) return OkStatus();
  for (const auto& input_map_pair : opts_.input_map) {
    TensorId key = input_map_pair.first;
    // Keys that were actually consumed by RemapNodeDefInputs are not missing.
    if (used_input_map_keys_.count(key) > 0) continue;

    auto pair = gdef_nodes_.find(key.first);
    if (pair == gdef_nodes_.end()) {
      // key's node doesn't exist in GraphDef
      missing_unused_input_map_keys_->push_back(key);
      continue;
    }

    // Check that key's index is in bounds. Get the number of outputs from the
    // NodeDef, rather than the imported Node, since the Node may not exist if
    // opts_.skip_mapped_nodes is true.
    const NodeDef& node_def = get_node_def(pair->second.gdef_index);
    const OpDef* op_def;
    TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
    int num_outputs;
    TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
    if (key.second >= num_outputs) {
      // key's index out of bounds
      missing_unused_input_map_keys_->push_back(key);
    }
  }
  return OkStatus();
}
1414
1415void GraphConstructor::Undo() {
1416 for (const auto& iter : gdef_nodes_) {
1417 if (iter.second.node != nullptr) {
1418 g_->RemoveNode(iter.second.node);
1419 }
1420 }
1421 g_->set_versions(original_versions_);
1422}
1423
// Adds a data edge src:output_index -> dst:input_index to g_, after
// validating that both endpoints exist and that their dtypes are compatible
// (per TypesCompatible).
Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
                                  int input_index) {
  if (output_index >= src->num_outputs()) {
    return errors::InvalidArgument(
        "Output ", output_index, " of node ", src->name(),
        " does not exist. Node only has ", src->num_outputs(), " outputs.");
  }
  if (input_index >= dst->num_inputs()) {
    return errors::InvalidArgument(
        "Input ", input_index, " of node ", dst->name(),
        " does not exist. Node only has ", dst->num_inputs(), " inputs.");
  }

  DataType src_out = src->output_type(output_index);
  DataType dst_in = dst->input_type(input_index);
  if (!TypesCompatible(dst_in, src_out)) {
    return errors::InvalidArgument(
        "Input ", input_index, " of node ", dst->name(), " was passed ",
        DataTypeString(src_out), " from ", src->name(), ":", output_index,
        " incompatible with expected ", DataTypeString(dst_in), ".");
  }
  g_->AddEdge(src, output_index, dst, input_index);
  return OkStatus();
}
1448
1449} // namespace
1450
// Converts `gdef` into `*g`, running shape inference at the GraphDef's
// recorded producer version. No import results are collected.
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              const GraphDef& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(
      opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
      /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
      /*missing_unused_input_map_keys=*/nullptr);
}
1459
// Rvalue overload: same as above but consumes `gdef`, letting the
// constructor take the NodeDefs by move.
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              GraphDef&& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner,
                                     /*return_tensors=*/nullptr,
                                     /*return_nodes=*/nullptr,
                                     /*missing_unused_input_map_keys=*/nullptr);
}
1468
1469Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
1470 gtl::ArraySlice<NodeDef> nodes, Graph* g) {
1471 ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
1472 // TODO(irving): Copy will go away once NodeInfo exists
1473 std::vector<const NodeDef*> node_defs;
1474 node_defs.reserve(nodes.size());
1475 for (const auto& n : nodes) {
1476 node_defs.push_back(&n);
1477 }
1478 return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
1479 &refiner, /*return_tensors=*/nullptr,
1480 /*return_nodes=*/nullptr,
1481 /*missing_unused_input_map_keys=*/nullptr);
1482}
1483
// Imports `gdef` into the (possibly non-empty) graph `*g`. Validates the
// opts/results combination, sets up a shape refiner (using a default one if
// the caller passed none), reconciles the refiner's graph-def version with
// the imported producer version, and delegates to GraphConstructor.
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
                      Graph* g, ShapeRefiner* refiner,
                      ImportGraphDefResults* results) {
  if (!opts.return_tensors.empty()) {
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_tensors is non-empty");
    }
  }

  if (!opts.return_nodes.empty()) {
    if (opts.skip_mapped_nodes) {
      return errors::InvalidArgument(
          "Requesting return_nodes with skip_mapped_nodes set is not currently "
          "supported");
    }
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_nodes is non-empty");
    }
  }

  if (results != nullptr) {
    // The import appends to these fields; require a clean slate so results
    // from a previous call can't be mistaken for this call's output.
    if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
        !results->missing_unused_input_map_keys.empty()) {
      return errors::InvalidArgument(
          "All fields in results argument to ImportGraphDef() must be empty.");
    }
  }

  ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
  if (refiner == nullptr) {
    refiner = &default_refiner;
  } else {
    // Log a warning if we are importing a GraphDef at an older
    // producer version after already having added non-source/sink
    // nodes to the graph in the past.
    // (g->num_nodes() > 2 skips the implicit source and sink nodes.)
    if (gdef.versions().producer() > 0 &&
        gdef.versions().producer() < refiner->graph_def_version() &&
        g->num_nodes() > 2) {
      LOG(WARNING) << "Importing a graph with a lower producer version "
                   << gdef.versions().producer()
                   << " into an existing graph with producer version "
                   << refiner->graph_def_version() << ". Shape inference will "
                   << "have run different parts of the graph with different "
                   << "producer versions.";
    }
  }

  // Set the graph def version of the refiner as the min of the
  // current value and the version from the graph we are about to
  // import.
  //
  // Note: to match Run() semantics, we should re-run shape inference
  // on the entire graph if the producer version has changed. For now
  // we log the warning above.
  refiner->set_graph_def_version(
      std::min(refiner->graph_def_version(), gdef.versions().producer()));

  if (results == nullptr) {
    return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
                                       &gdef.library(), g, refiner, nullptr,
                                       nullptr, nullptr);
  } else {
    return GraphConstructor::Construct(
        opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
        &results->return_tensors, &results->return_nodes,
        &results->missing_unused_input_map_keys);
  }
}
1556
1557void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
1558
1559} // namespace tensorflow
1560