1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
17 | |
18 | #include <algorithm> |
19 | #include <set> |
20 | #include <string> |
21 | #include <unordered_map> |
22 | #include <unordered_set> |
23 | #include <vector> |
24 | |
25 | #include "absl/algorithm/container.h" |
26 | #include "absl/container/flat_hash_set.h" |
27 | #include "absl/strings/str_cat.h" |
28 | #include "absl/strings/string_view.h" |
29 | #include "tensorflow/core/common_runtime/shape_refiner.h" |
30 | #include "tensorflow/core/framework/function.h" |
31 | #include "tensorflow/core/framework/function.pb.h" |
32 | #include "tensorflow/core/framework/graph.pb.h" |
33 | #include "tensorflow/core/framework/node_def.pb.h" |
34 | #include "tensorflow/core/framework/node_def_util.h" |
35 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
36 | #include "tensorflow/core/framework/types.h" |
37 | #include "tensorflow/core/framework/versions.h" |
38 | #include "tensorflow/core/framework/versions.pb.h" |
39 | #include "tensorflow/core/graph/algorithm.h" |
40 | #include "tensorflow/core/graph/graph.h" |
41 | #include "tensorflow/core/graph/tensor_id.h" |
42 | #include "tensorflow/core/lib/core/errors.h" |
43 | #include "tensorflow/core/lib/gtl/flatmap.h" |
44 | #include "tensorflow/core/lib/gtl/flatset.h" |
45 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
46 | #include "tensorflow/core/lib/strings/scanner.h" |
47 | #include "tensorflow/core/lib/strings/str_util.h" |
48 | #include "tensorflow/core/platform/errors.h" |
49 | #include "tensorflow/core/platform/logging.h" |
50 | #include "tensorflow/core/platform/macros.h" |
51 | #include "tensorflow/core/public/version.h" |
52 | |
53 | namespace tensorflow { |
54 | |
55 | namespace { |
56 | |
// Duplicate control inputs are stripped before any edge is added to the
// Graph, so the (expensive) duplicate check in 'AddControlEdge' can be
// skipped. `constexpr` already implies `const`, and the enclosing anonymous
// namespace already gives internal linkage, so neither `static` nor `const`
// is spelled out.
constexpr bool kDoNotCheckDuplicates = true;
60 | |
61 | inline bool IsMerge(const NodeDef& node_def) { |
62 | return node_def.op() == "Merge" || node_def.op() == "RefMerge" || |
63 | node_def.op() == "_XlaMerge" ; |
64 | } |
65 | |
66 | inline bool IsNextIteration(const NodeDef& node_def) { |
67 | return node_def.op() == "NextIteration" || |
68 | node_def.op() == "RefNextIteration" ; |
69 | } |
70 | |
71 | bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { |
72 | using ::tensorflow::strings::Scanner; |
73 | Scanner scanner(s); |
74 | scanner |
75 | .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE |
76 | : Scanner::LETTER_DIGIT_DOT) |
77 | .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); |
78 | |
79 | while (true) { |
80 | if (!scanner.GetResult()) // Some error in previous iteration. |
81 | return false; |
82 | if (scanner.empty()) // No error, but nothing left, good. |
83 | return true; |
84 | |
85 | // Absorb another piece, starting with a '>' |
86 | scanner.One(Scanner::RANGLE) |
87 | .One(Scanner::LETTER_DIGIT_DOT) |
88 | .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); |
89 | } |
90 | } |
91 | |
92 | class GraphConstructor { |
93 | public: |
94 | struct Options { |
95 | Options(const GraphConstructorOptions& in) // NOLINT(runtime/explicit) |
96 | : allow_internal_ops(in.allow_internal_ops), |
97 | expect_device_spec(in.expect_device_spec), |
98 | importing(false), |
99 | validate_nodes(in.validate_nodes), |
100 | validate_colocation_constraints(false), |
101 | add_default_attributes(in.add_default_attributes) {} |
102 | Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit) |
103 | : allow_internal_ops(false), |
104 | expect_device_spec(false), |
105 | prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/" ) |
106 | ? in.prefix |
107 | : in.prefix + "/" ), |
108 | uniquify_names(in.uniquify_names), |
109 | uniquify_prefix(in.uniquify_prefix), |
110 | input_map(in.input_map.begin(), in.input_map.end()), |
111 | skip_mapped_nodes(in.skip_mapped_nodes), |
112 | control_dependencies(in.control_dependencies), |
113 | return_tensors(in.return_tensors.begin(), in.return_tensors.end()), |
114 | return_nodes(in.return_nodes), |
115 | importing(true), |
116 | validate_nodes(true), |
117 | validate_colocation_constraints(in.validate_colocation_constraints), |
118 | validate_shape(in.validate_shape), |
119 | default_device(in.default_device) {} |
120 | |
121 | bool allow_internal_ops; |
122 | bool expect_device_spec; |
123 | |
124 | string prefix; |
125 | bool uniquify_names; |
126 | bool uniquify_prefix; |
127 | std::map<TensorId, TensorId> input_map; |
128 | bool skip_mapped_nodes; |
129 | std::vector<string> control_dependencies; |
130 | std::vector<TensorId> return_tensors; |
131 | std::vector<string> return_nodes; |
132 | |
133 | // TODO(ashankar): This bool exists to separate out functionality required |
134 | // to make ImportGraphDef a close equivalent of Python's import_graph_def |
135 | // without affecting the behavior of ConvertGraphDefToGraph at the time |
136 | // ImportGraphDef was added. |
137 | // |
138 | // That said, the functionality here (shape and op validation) seems |
139 | // applicable to ConvertGraphDefToGraph as well, so make an attempt to |
140 | // remove this. |
141 | bool importing; |
142 | // If true, validates that nodes being converted have all expected attrs |
143 | // set and no unknown attrs set by calling ValidateNodeDef(). |
144 | // `validate_nodes` is always true when `importing` is set. |
145 | bool validate_nodes; |
146 | bool validate_colocation_constraints; |
147 | bool validate_shape = true; |
148 | |
149 | // If true, GraphConstructor will add attributes with their default |
150 | // value to the Node when they are missing from the NodeDef. |
151 | bool add_default_attributes = true; |
152 | |
153 | string default_device; |
154 | }; |
155 | |
156 | typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice; |
157 | |
158 | // versions and library may be nullptr |
159 | static Status Construct( |
160 | const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, |
161 | const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, |
162 | std::vector<std::pair<Node*, int>>* return_tensors, |
163 | std::vector<Node*>* return_nodes, |
164 | std::vector<SafeTensorId>* missing_unused_input_map_keys); |
165 | |
166 | static Status Construct( |
167 | const Options& opts, GraphDef&& graph_def, Graph* g, |
168 | ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors, |
169 | std::vector<Node*>* return_nodes, |
170 | std::vector<SafeTensorId>* missing_unused_input_map_keys); |
171 | |
172 | protected: |
173 | GraphConstructor(const Options& opts, Graph* g, ShapeRefiner* refiner, |
174 | std::vector<std::pair<Node*, int>>* return_tensors, |
175 | std::vector<Node*>* return_nodes, |
176 | std::vector<SafeTensorId>* missing_unused_input_map_keys) |
177 | : opts_(opts), |
178 | g_(g), |
179 | original_versions_(g->versions()), |
180 | prefix_(opts.prefix), |
181 | refiner_(refiner), |
182 | return_tensors_(return_tensors), |
183 | return_nodes_(return_nodes), |
184 | missing_unused_input_map_keys_(missing_unused_input_map_keys) {} |
185 | |
186 | virtual ~GraphConstructor() {} |
187 | |
188 | Status TryImport() { |
189 | TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); |
190 | TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies()); |
191 | TF_RETURN_IF_ERROR(BuildNodeIndex()); |
192 | TF_RETURN_IF_ERROR(InitFromEdges()); |
193 | |
194 | // NOTE: Convert() invokes `consume_node_def()` on each node in the input |
195 | // graph, so `get_node_def()` is no longer usable once it is called. |
196 | TF_RETURN_IF_ERROR(Convert()); |
197 | |
198 | TF_RETURN_IF_ERROR(AddBackEdges()); |
199 | TF_RETURN_IF_ERROR(UpdateVersionDef()); |
200 | TF_RETURN_IF_ERROR(PopulateReturnTensors()); |
201 | TF_RETURN_IF_ERROR(PopulateReturnNodes()); |
202 | TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys()); |
203 | UpdateUniquifiedColocationNames(); |
204 | FixupSourceAndSinkEdges(g_); |
205 | return OkStatus(); |
206 | } |
207 | |
208 | private: |
209 | Status EnsureNoNameCollisions(); |
210 | Status ValidateInputMapAndControlDependencies(); |
211 | Status BuildNodeIndex(); |
212 | Status InitFromEdges(); |
213 | Status Convert(); |
214 | Status AddBackEdges(); |
215 | Status UpdateVersionDef(); |
216 | Status PopulateReturnTensors(); |
217 | Status PopulateReturnNodes(); |
218 | Status PopulateMissingUnusedInputMapKeys(); |
219 | |
220 | void Undo(); |
221 | |
222 | // Prints cycles in the graph. |
223 | void PrintCycles(); |
224 | // Performs DFS starting at `cur_node` and prints any cycles found. |
225 | void DFS(int cur_node, std::vector<int>* cur_branch, |
226 | std::vector<bool>* is_on_cur_branch, |
227 | absl::flat_hash_set<int>* unvisited, |
228 | const std::vector<absl::string_view>& node_names); |
229 | Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped); |
230 | Status ValidateColocationConstraints(const NodeDef& node_def); |
231 | Status MakeNode(NodeDef&& node_def, Node** node); |
232 | Status MakeEdge(Node* src, int output_index, Node* dst, int input_index); |
233 | Status ValidateShape(Node* node); |
234 | Status ModifyNodeDefForImport(NodeDef* node_def); |
235 | // Modifies node_def's inputs according to opts_.input_map. |
236 | // input_already_exists is a pre-initialized vector of length |
237 | // node_def->input_size(). This function will mark inputs that are remapped to |
238 | // true. |
239 | void RemapNodeDefInputs(NodeDef* node_def, |
240 | std::vector<bool>* input_already_exists); |
241 | // input_already_exists is a pre-initialized vector of length |
242 | // node_def->input_size(). This function will add and mark control inputs as |
243 | // true. |
244 | void AddControlDependencies(NodeDef* node_def, |
245 | std::vector<bool>* input_already_exists); |
246 | void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists, |
247 | NodeDef* node_def); |
248 | |
249 | // Modifies `node_def` if its name isn't unique, or if any of its inputs' |
250 | // names have been uniquified. This must be called in topological order on all |
251 | // nodes. |
252 | void UniquifyNames(const std::vector<bool>& input_already_exists, |
253 | NodeDef* node_def); |
254 | |
255 | // Updates any constructed nodes' colocation group names if the name has been |
256 | // updated by UniquifyNames. This is called after all the nodes have been |
257 | // constructed so all the names have been uniquified if necessary. |
258 | void UpdateUniquifiedColocationNames(); |
259 | |
260 | // Returns true if `name` already exists in `g_` (either as a node name or |
261 | // prefix). |
262 | bool NameExistsInGraph(StringPiece name); |
263 | |
264 | // Returns true if `name` already exists in the GraphDef being imported |
265 | // (either as a node name or prefix). |
266 | bool NameExistsInGraphDef(StringPiece name); |
267 | |
268 | // Returns a unique version of `original_name`, or `original_name` if it's |
269 | // already unique in the graph. |
270 | string FindUniqueName(StringPiece original_name); |
271 | |
272 | // Decrement pending count for users of `processed` and add the ones that now |
273 | // have all of their pending inputs satisfied to `ready_`. |
274 | void UpdatePendingCountAndReady(int processed, bool is_next_iteration); |
275 | |
276 | // Subclasses override the following virtual methods to provide efficient |
277 | // access to the original protocol buffer-based graph. |
278 | |
279 | // Returns the number of nodes in the graph. |
280 | virtual size_t node_def_count() const = 0; |
281 | // Returns the i^th node in the graph. Must not be called after |
282 | // consume_node_def(i). |
283 | virtual const NodeDef& get_node_def(int i) const = 0; |
284 | // Destructively reads the i^th node in the graph, avoiding a copy if |
285 | // possible. After calling this method, the result of get_node_def(i) is |
286 | // undefined. |
287 | virtual NodeDef consume_node_def(int i) = 0; |
288 | // Returns the version information for the graph, or nullptr if none is |
289 | // available. |
290 | virtual const VersionDef* versions() const = 0; |
291 | // Returns the function information for the graph, or nullptr if none is |
292 | // available. |
293 | virtual const FunctionDefLibrary* library() const = 0; |
294 | |
295 | // From constructor |
296 | const Options opts_; |
297 | Graph* g_; |
298 | const VersionDef original_versions_; |
299 | |
300 | // A copy of opts_.prefix, possibly uniquified. |
301 | string prefix_; |
302 | |
303 | ShapeRefiner* refiner_; |
304 | |
305 | // May be null. Not owned. |
306 | std::vector<std::pair<Node*, int>>* return_tensors_; |
307 | |
308 | // May be null. Not owned. |
309 | std::vector<Node*>* return_nodes_; |
310 | |
311 | // May be null. Not owned. |
312 | std::vector<SafeTensorId>* missing_unused_input_map_keys_; |
313 | |
314 | // Intermediate datastructure used to populate |
315 | // `missing_unused_input_map_keys_`. |
316 | std::set<TensorId> used_input_map_keys_; |
317 | |
318 | // Intermediate datastructure used to track the destinations of back edges. |
319 | absl::flat_hash_set<int> merge_node_indices_; |
320 | |
321 | // Mapping from node name to the index within node_defs_. |
322 | struct NodeInfo { |
323 | explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {} |
324 | // Containers require that we have a default constructor. |
325 | NodeInfo() : NodeInfo(-1) {} |
326 | int gdef_index; |
327 | Node* node; // nullptr until the NodeDef is converted to a Node. |
328 | }; |
329 | absl::flat_hash_map<std::string, NodeInfo> gdef_nodes_; |
330 | |
331 | // Prefixes already used in the GraphDef being imported. |
332 | absl::flat_hash_set<StringPiece> gdef_prefixes_; |
333 | |
334 | // Mapping from node name to the existing node in g_. |
335 | absl::flat_hash_map<StringPiece, Node*> existing_nodes_; |
336 | |
337 | // Prefixes already used in the graph. |
338 | absl::flat_hash_set<StringPiece> existing_prefixes_; |
339 | |
340 | // Imported node names that have been uniquified. The key is the original |
341 | // name, the value is the new unique name. |
342 | gtl::FlatMap<string, string> uniquified_names_; |
343 | |
344 | // Index of NodeDefs in node_defs_ with all inputs already converted. We use a |
345 | // (sorted) set so nodes are created in the order defined in the GraphDef. |
346 | std::set<int> ready_; |
347 | |
348 | // Mapping between index within node_defs_ and the number of inputs that |
349 | // still need to be converted. |
350 | std::vector<int> pending_count_; |
351 | |
352 | // Mapping between index within node_defs_ and the index within node_defs_ of |
353 | // all nodes it outputs to. |
354 | std::vector<gtl::InlinedVector<int, 4>> outputs_; |
355 | |
356 | // Used in the conversion from node_defs_ to g_ to represent the ith input |
357 | // of a node. |
358 | struct InputInfo { |
359 | explicit InputInfo(const string& node_name, Node* n, int i) |
360 | : name(node_name), node(n), index(i) {} |
361 | // Use string instead of StringPiece so we don't have to manage lifetime |
362 | string name; |
363 | Node* node; |
364 | int index; |
365 | |
366 | static bool IsControlInput(const InputInfo& input) { |
367 | return input.index == Graph::kControlSlot; |
368 | } |
369 | static int CompareName(const InputInfo& lhs, const InputInfo& rhs) { |
370 | return lhs.name < rhs.name; |
371 | } |
372 | static bool IsSameName(const InputInfo& lhs, const InputInfo& rhs) { |
373 | return lhs.name == rhs.name; |
374 | } |
375 | }; |
376 | |
377 | // Used in the conversion from node_defs_ to g_ to represent an edge from |
378 | // the node named 'name' to node 'n'. |
379 | struct EdgeInfo { |
380 | explicit EdgeInfo(const string& name, int i1, Node* n, int i2) |
381 | : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {} |
382 | // Use string instead of StringPiece so we don't have to manage lifetime |
383 | string src_name; |
384 | int src_index; |
385 | Node* dst_node; |
386 | int dst_index; |
387 | }; |
388 | std::vector<EdgeInfo> back_edges_; |
389 | |
390 | TF_DISALLOW_COPY_AND_ASSIGN(GraphConstructor); |
391 | }; |
392 | |
393 | // Implementation of GraphConstructor that does not take ownership of the |
394 | // input NodeDef messages and thus copies the nodes into the constructed Graph*. |
395 | // |
396 | // NOTE(mrry): Whenever possible, use NodeDefMovingGraphConstructor, which |
397 | // avoids copying each NodeDef into the constructed Graph*. |
398 | class NodeDefCopyingGraphConstructor : public GraphConstructor { |
399 | public: |
400 | NodeDefCopyingGraphConstructor( |
401 | const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, |
402 | const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, |
403 | std::vector<std::pair<Node*, int>>* return_tensors, |
404 | std::vector<Node*>* return_nodes, |
405 | std::vector<SafeTensorId>* missing_unused_input_map_keys) |
406 | : GraphConstructor(opts, g, refiner, return_tensors, return_nodes, |
407 | missing_unused_input_map_keys), |
408 | node_defs_(node_defs), |
409 | versions_(versions), |
410 | library_(library) {} |
411 | |
412 | private: |
413 | size_t node_def_count() const override { return node_defs_.size(); } |
414 | const NodeDef& get_node_def(int i) const override { return *node_defs_[i]; } |
415 | NodeDef consume_node_def(int i) override { return *node_defs_[i]; } |
416 | const VersionDef* versions() const override { return versions_; } |
417 | const FunctionDefLibrary* library() const override { return library_; } |
418 | |
419 | const NodeDefSlice node_defs_; |
420 | const VersionDef* const versions_; |
421 | const FunctionDefLibrary* const library_; |
422 | }; |
423 | |
424 | // Implementation of GraphConstructor that takes ownership of the input |
425 | // GraphDef, and can perform destructive reads. |
426 | class NodeDefMovingGraphConstructor : public GraphConstructor { |
427 | public: |
428 | NodeDefMovingGraphConstructor( |
429 | const Options& opts, GraphDef&& graph_def, Graph* g, |
430 | ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors, |
431 | std::vector<Node*>* return_nodes, |
432 | std::vector<SafeTensorId>* missing_unused_input_map_keys) |
433 | : GraphConstructor(opts, g, refiner, return_tensors, return_nodes, |
434 | missing_unused_input_map_keys), |
435 | graph_def_(std::move(graph_def)), |
436 | is_consumed_(graph_def_.node_size(), false) {} |
437 | |
438 | private: |
439 | size_t node_def_count() const override { return graph_def_.node().size(); } |
440 | const NodeDef& get_node_def(int i) const override { |
441 | CHECK(!is_consumed_[i]) |
442 | << "NodeDef " << i << " accessed after it was consumed." ; |
443 | return graph_def_.node(i); |
444 | } |
445 | NodeDef consume_node_def(int i) override { |
446 | CHECK(!is_consumed_[i]) << "NodeDef " << i << " consumed twice." ; |
447 | is_consumed_[i] = true; |
448 | return std::move(*graph_def_.mutable_node(i)); |
449 | } |
450 | const VersionDef* versions() const override { return &graph_def_.versions(); } |
451 | const FunctionDefLibrary* library() const override { |
452 | return &graph_def_.library(); |
453 | } |
454 | |
455 | GraphDef graph_def_; |
456 | std::vector<bool> is_consumed_; |
457 | }; |
458 | |
459 | bool ForwardCompatibilityWindowPassed(const VersionDef& versions) { |
460 | // TF_GRAPH_DEF_VERSION is incremented daily. |
461 | // TF has a 3 week forward compatibility guarantee. |
462 | return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21; |
463 | } |
464 | |
465 | Status MaybeAppendVersionWarning(const VersionDef* versions, |
466 | const Status& import_status) { |
467 | if (versions && ForwardCompatibilityWindowPassed(*versions)) { |
468 | return Status( |
469 | import_status.code(), |
470 | absl::StrCat( |
471 | "Converting GraphDef to Graph has failed with an error: '" , |
472 | import_status.error_message(), |
473 | "' The binary trying to import the GraphDef was built when " |
474 | "GraphDef version was " , |
475 | TF_GRAPH_DEF_VERSION, |
476 | ". The GraphDef was produced by a binary built when GraphDef " |
477 | "version was " , |
478 | versions->producer(), |
479 | ". The difference between these versions is larger than " |
480 | "TensorFlow's forward compatibility guarantee, and might be the " |
481 | "root cause for failing to import the GraphDef." )); |
482 | } |
483 | return import_status; |
484 | } |
485 | |
486 | /* static */ Status GraphConstructor::Construct( |
487 | const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, |
488 | const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, |
489 | std::vector<std::pair<Node*, int>>* return_tensors, |
490 | std::vector<Node*>* return_nodes, |
491 | std::vector<SafeTensorId>* missing_unused_input_map_keys) { |
492 | if (versions) { |
493 | TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, |
494 | TF_GRAPH_DEF_VERSION_MIN_PRODUCER, |
495 | "GraphDef" , "graph" )); |
496 | } |
497 | NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g, |
498 | refiner, return_tensors, return_nodes, |
499 | missing_unused_input_map_keys); |
500 | Status s = c.TryImport(); |
501 | if (!s.ok()) { |
502 | c.Undo(); |
503 | s = MaybeAppendVersionWarning(versions, s); |
504 | } |
505 | return s; |
506 | } |
507 | |
508 | /* static */ Status GraphConstructor::Construct( |
509 | const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner, |
510 | std::vector<std::pair<Node*, int>>* return_tensors, |
511 | std::vector<Node*>* return_nodes, |
512 | std::vector<SafeTensorId>* missing_unused_input_map_keys) { |
513 | TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION, |
514 | TF_GRAPH_DEF_VERSION_MIN_PRODUCER, |
515 | "GraphDef" , "graph" )); |
516 | VersionDef version_def = graph_def.versions(); |
517 | NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner, |
518 | return_tensors, return_nodes, |
519 | missing_unused_input_map_keys); |
520 | Status s = c.TryImport(); |
521 | if (!s.ok()) { |
522 | c.Undo(); |
523 | s = MaybeAppendVersionWarning(&version_def, s); |
524 | } |
525 | return s; |
526 | } |
527 | |
528 | void GraphConstructor::UpdatePendingCountAndReady(int processed, |
529 | bool is_next_iteration) { |
530 | for (size_t i = 0; i < outputs_[processed].size(); ++i) { |
531 | const int output = outputs_[processed][i]; |
532 | // We didn't consider NextIteration->Merge edges when computing |
533 | // pending_counts_ so we should not have to consider it here either. |
534 | bool is_next_iteration_to_merge_edge = |
535 | is_next_iteration && merge_node_indices_.count(output) == 1; |
536 | if (!is_next_iteration_to_merge_edge) { |
537 | int* current_pending_count = &pending_count_[output]; |
538 | CHECK_GT(*current_pending_count, 0); |
539 | (*current_pending_count)--; |
540 | if (*current_pending_count == 0) { |
541 | ready_.insert(output); |
542 | } |
543 | } |
544 | } |
545 | } |
546 | |
547 | // This could be expensive but we don't expect to call it often, if at all (only |
548 | // if there are multiple nodes in g_ with the same name) |
549 | bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map, |
550 | const StringPiece& node_name) { |
551 | for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) { |
552 | if (iter->second.first == node_name) return true; |
553 | } |
554 | return false; |
555 | } |
556 | |
557 | bool NodeNameInValues(const std::vector<string>& control_dependencies, |
558 | const StringPiece& node_name) { |
559 | return std::find(control_dependencies.begin(), control_dependencies.end(), |
560 | node_name) != control_dependencies.end(); |
561 | } |
562 | |
563 | // Adds any prefixes of `node_name` (not including the full name itself) to |
564 | // `prefixes`. |
565 | void AddPrefixes(StringPiece node_name, |
566 | absl::flat_hash_set<StringPiece>* prefixes) { |
567 | size_t idx = -1; |
568 | while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) { |
569 | prefixes->insert(node_name.substr(0, idx)); |
570 | } |
571 | } |
572 | |
573 | Status GraphConstructor::EnsureNoNameCollisions() { |
574 | existing_nodes_.reserve(g_->num_nodes()); |
575 | // Populate existing_nodes_ and existing_prefixes_. |
576 | for (Node* n : g_->nodes()) { |
577 | bool already_exists = !existing_nodes_.insert({n->name(), n}).second; |
578 | if (already_exists) { |
579 | if (NodeNameInValues(opts_.input_map, n->name())) { |
580 | return errors::InvalidArgument( |
581 | "cannot resolve input_map because multiple nodes exist with name '" , |
582 | n->name(), "'" ); |
583 | } |
584 | if (NodeNameInValues(opts_.control_dependencies, n->name())) { |
585 | return errors::InvalidArgument( |
586 | "cannot resolve control_dependencies because multiple nodes exist " |
587 | "with name '" , |
588 | n->name(), "'" ); |
589 | } |
590 | } |
591 | AddPrefixes(n->name(), &existing_prefixes_); |
592 | } |
593 | if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) { |
594 | for (size_t i = 0; i < node_def_count(); ++i) { |
595 | const string& name = get_node_def(i).name(); |
596 | if (NameExistsInGraph(name)) { |
597 | return errors::InvalidArgument("Node name '" , name, |
598 | "' already exists in the Graph" ); |
599 | } |
600 | } |
601 | } else if (!prefix_.empty()) { |
602 | StringPiece prefix_no_slash(prefix_); |
603 | prefix_no_slash.remove_suffix(1); |
604 | if (!IsValidNodeName(prefix_no_slash, false)) { |
605 | return errors::InvalidArgument("Imported node name prefix '" , prefix_, |
606 | "' would lead to invalid node names" ); |
607 | } |
608 | if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) { |
609 | prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/" ); |
610 | } |
611 | } |
612 | return OkStatus(); |
613 | } |
614 | |
615 | Status GraphConstructor::ValidateInputMapAndControlDependencies() { |
616 | for (const auto& mapping : opts_.input_map) { |
617 | TensorId src = mapping.first; |
618 | TensorId dst = mapping.second; |
619 | if (existing_nodes_.count(dst.first) == 0) { |
620 | return errors::InvalidArgument( |
621 | "node '" , dst.first, "' in input_map does not exist in graph " , |
622 | "(input_map entry: " , src.ToString(), "->" , dst.ToString(), ")" ); |
623 | } |
624 | if ((src.second == Graph::kControlSlot) != |
625 | (dst.second == Graph::kControlSlot)) { |
626 | return errors::InvalidArgument("input_map entry " , src.ToString(), "->" , |
627 | dst.ToString(), " between " , |
628 | "control edge and non-control edge" ); |
629 | } |
630 | } |
631 | for (const string& node : opts_.control_dependencies) { |
632 | if (existing_nodes_.count(node) == 0) { |
633 | return errors::InvalidArgument( |
634 | "node '" , node, |
635 | "' in control_dependencies does not exist in " |
636 | "graph" ); |
637 | } |
638 | } |
639 | return OkStatus(); |
640 | } |
641 | |
642 | Status GraphConstructor::BuildNodeIndex() { |
643 | // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_. |
644 | for (int n = 0; n < node_def_count(); ++n) { |
645 | const NodeDef& node_def = get_node_def(n); |
646 | if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) { |
647 | return errors::InvalidArgument( |
648 | "Node '" , node_def.name(), |
649 | "': Node name contains invalid characters" ); |
650 | } |
651 | if (!gdef_nodes_.insert(std::make_pair(node_def.name(), NodeInfo(n))) |
652 | .second) { |
653 | return errors::InvalidArgument("Node '" , node_def.name(), |
654 | "' is not unique" ); |
655 | } |
656 | // Validate the operation's type. |
657 | if (node_def.op().empty()) { |
658 | return errors::InvalidArgument("Node '" , node_def.name(), |
659 | "' does not specify an operation" ); |
660 | } |
661 | if (opts_.expect_device_spec && node_def.device().empty()) { |
662 | return errors::InvalidArgument("Node '" , node_def.name(), |
663 | "' is missing a device specification" ); |
664 | } |
665 | if (IsMerge(node_def)) { |
666 | merge_node_indices_.insert(n); |
667 | } |
668 | // Validate control edges at end |
669 | bool in_control_dependence = false; |
670 | for (int i = 0; i < node_def.input_size(); ++i) { |
671 | StringPiece input_name = node_def.input(i); |
672 | if (!input_name.empty() && absl::StartsWith(input_name, "^" )) { |
673 | in_control_dependence = true; |
674 | } else if (in_control_dependence) { |
675 | return errors::InvalidArgument( |
676 | "Node '" , node_def.name(), |
677 | "': Control dependencies must come after regular dependencies" ); |
678 | } |
679 | } |
680 | // Update gdef_prefixes_. |
681 | AddPrefixes(node_def.name(), &gdef_prefixes_); |
682 | } |
683 | return OkStatus(); |
684 | } |
685 | |
686 | Status GraphConstructor::InitFromEdges() { |
687 | const int num_nodes = node_def_count(); |
688 | pending_count_.reserve(num_nodes); |
689 | outputs_.resize(num_nodes); |
690 | gtl::FlatSet<string> next_iteration_nodes; |
691 | for (int n = 0; n < node_def_count(); ++n) { |
692 | const NodeDef& node_def = get_node_def(n); |
693 | if (IsNextIteration(node_def)) { |
694 | next_iteration_nodes.insert(node_def.name()); |
695 | } |
696 | } |
697 | |
698 | // Parse the inputs for each node. |
699 | for (int n = 0; n < num_nodes; ++n) { |
700 | const NodeDef& node_def = get_node_def(n); |
701 | int pending_count = node_def.input_size(); |
702 | if (IsMerge(node_def)) { |
703 | // Cycles in the graph are only allowed for while loops. A while loop is |
704 | // identified by an edge from a NextIteration node to a Merge node. For |
705 | // such Merge nodes, only wait for one non-control input before |
706 | // considering the node ready to process in Convert(). |
707 | int32_t num_control_edges = 0; |
708 | bool has_loop_back_edge = false; |
709 | for (int i = 0; i < node_def.input_size(); ++i) { |
710 | StringPiece input_name(node_def.input(i)); |
711 | if (absl::StartsWith(input_name, "^" )) { |
712 | num_control_edges++; |
713 | } else { |
714 | TensorId id(ParseTensorName(input_name)); |
715 | if (next_iteration_nodes.find(string(id.first)) != |
716 | next_iteration_nodes.end()) { |
717 | has_loop_back_edge = true; |
718 | } |
719 | } |
720 | } |
721 | if (has_loop_back_edge) { |
722 | pending_count = num_control_edges + 1; |
723 | } |
724 | } |
725 | for (int i = 0; i < node_def.input_size(); ++i) { |
726 | StringPiece input_name = node_def.input(i); |
727 | TensorId id(ParseTensorName(input_name)); |
728 | if (opts_.input_map.count(id) == 0) { |
729 | // If an input is not mapped, then the input should appear in the graph |
730 | // being imported. |
731 | auto iter = gdef_nodes_.find(id.first); |
732 | if (iter == gdef_nodes_.end()) { |
733 | return errors::InvalidArgument("Node '" , node_def.name(), |
734 | "': Unknown input node '" , |
735 | node_def.input(i), "'" ); |
736 | } |
737 | outputs_[iter->second.gdef_index].push_back(n); |
738 | } else { |
739 | // This input is mapped to an existing edge. Therefore this input is |
740 | // as good as being already processed. |
741 | --pending_count; |
742 | DCHECK_GE(pending_count, 0); |
743 | } |
744 | } |
745 | if (pending_count == 0) { |
746 | ready_.insert(n); |
747 | } |
748 | pending_count_.push_back(pending_count); |
749 | } |
750 | return OkStatus(); |
751 | } |
752 | |
753 | Status GraphConstructor::ValidateColocationConstraints( |
754 | const NodeDef& node_def) { |
755 | if (!opts_.validate_colocation_constraints || !opts_.importing) |
756 | return OkStatus(); |
757 | const auto iter = node_def.attr().find(kColocationAttrName); |
758 | if (iter == node_def.attr().end()) return OkStatus(); |
759 | for (const string& c : iter->second.list().s()) { |
760 | StringPiece s(c); |
761 | if (absl::ConsumePrefix(&s, kColocationGroupPrefix) && |
762 | gdef_nodes_.find(s) == gdef_nodes_.end()) { |
763 | return errors::InvalidArgument( |
764 | "Node '" , node_def.name(), |
765 | "' expects to be colocated with unknown node '" , s, "'" ); |
766 | } |
767 | } |
768 | return OkStatus(); |
769 | } |
770 | |
771 | Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) { |
772 | // Add the node to the graph. |
773 | Status status; |
774 | *node = g_->AddNode(std::move(node_def), &status); |
775 | if (!status.ok()) return status; |
776 | if (opts_.expect_device_spec) { |
777 | (*node)->set_assigned_device_name((*node)->def().device()); |
778 | } |
779 | return OkStatus(); |
780 | } |
781 | |
782 | Status GraphConstructor::ValidateShape(Node* node) { |
783 | if (!opts_.importing || !opts_.validate_shape) return OkStatus(); |
784 | TF_RETURN_IF_ERROR(refiner_->AddNode(node)); |
785 | // For nodes with the _output_shapes attribute, override the shape. |
786 | std::vector<const TensorShapeProto*> shape_attrs; |
787 | const char* kAttrName = "_output_shapes" ; |
788 | if (!TryGetNodeAttr(node->attrs(), kAttrName, &shape_attrs)) { |
789 | // No _output_shapes attribute, the AddNode call above was sufficient. |
790 | return OkStatus(); |
791 | } |
792 | auto* ic = refiner_->GetContext(node); |
793 | DCHECK(ic != nullptr) |
794 | << "ShapeRefiner::AddNode() should have created the InferenceContext" ; |
795 | if (shape_attrs.size() < node->num_outputs()) { |
796 | return errors::InvalidArgument( |
797 | "Node '" , node->name(), "' has " , node->num_outputs(), |
798 | " outputs but the " , kAttrName, " attribute specifies shapes for " , |
799 | shape_attrs.size(), " outputs" ); |
800 | } |
801 | // NOTE(skyewm): we don't raise an error here because some users depend on |
802 | // this behavior, even though it's unsafe. |
803 | // TODO(b/74619486): raise an error. |
804 | if (shape_attrs.size() > node->num_outputs()) { |
805 | LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs() |
806 | << " outputs but the " << kAttrName |
807 | << " attribute specifies shapes for " << shape_attrs.size() |
808 | << " outputs. Output shapes may be inaccurate." ; |
809 | } |
810 | for (int i = 0; i < node->num_outputs(); ++i) { |
811 | const TensorShapeProto& p = *shape_attrs[i]; |
812 | shape_inference::ShapeHandle h; |
813 | Status s = ic->MakeShapeFromShapeProto(p, &h); |
814 | if (!s.ok()) { |
815 | return errors::InvalidArgument("Node '" , node->name(), " has an invalid " , |
816 | kAttrName, " attribute (shape #" , i, |
817 | " error:'" , s.error_message(), "'" ); |
818 | } |
819 | s = refiner_->SetShape(node, i, h); |
820 | if (!s.ok()) { |
821 | return errors::InvalidArgument( |
822 | "Node '" , node->name(), "' has an " , kAttrName, |
823 | " attribute inconsistent with the GraphDef for output #" , i, ": " , |
824 | s.error_message()); |
825 | } |
826 | } |
827 | node->ClearAttr(kAttrName); |
828 | return OkStatus(); |
829 | } |
830 | |
831 | Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { |
832 | const OpDef* op_def; |
833 | TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); |
834 | AddDefaultsToNodeDef(*op_def, node_def); |
835 | TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def)); |
836 | if (versions()) { |
837 | TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions()->producer())); |
838 | } |
839 | return OkStatus(); |
840 | } |
841 | |
842 | void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def, |
843 | std::vector<bool>* input_already_exists) { |
844 | // Remove 'inputs_to_remove' from 'node_def' |
845 | NodeDef copy; |
846 | copy.mutable_input()->Reserve(node_def->input_size() - |
847 | inputs_to_remove.size()); |
848 | for (int i = 0, j = 0; i < node_def->input_size(); ++i) { |
849 | if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) { |
850 | ++j; |
851 | } else { |
852 | copy.add_input()->swap(*node_def->mutable_input(i)); |
853 | } |
854 | } |
855 | node_def->mutable_input()->Swap(copy.mutable_input()); |
856 | // Remove 'inputs_to_remove' from 'input_already_exists' |
857 | for (int idx : inputs_to_remove) { |
858 | input_already_exists->erase(input_already_exists->begin() + idx); |
859 | } |
860 | DCHECK_EQ(input_already_exists->size(), node_def->input_size()); |
861 | } |
862 | |
863 | void GraphConstructor::RemapNodeDefInputs( |
864 | NodeDef* node_def, std::vector<bool>* input_already_exists) { |
865 | DCHECK_EQ(input_already_exists->size(), node_def->input_size()); |
866 | std::set<TensorId> control_inputs; |
867 | std::vector<int> inputs_to_remove; |
868 | |
869 | for (int i = 0; i < node_def->input_size(); ++i) { |
870 | auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i))); |
871 | if (iter == opts_.input_map.end()) continue; |
872 | used_input_map_keys_.insert(iter->first); |
873 | |
874 | TensorId new_input = iter->second; |
875 | if (new_input.second == Graph::kControlSlot) { |
876 | // Check if we've already remapped a different input to new_input, and if |
877 | // so remove this input. |
878 | if (control_inputs.count(new_input) > 0) { |
879 | inputs_to_remove.push_back(i); |
880 | continue; |
881 | } |
882 | control_inputs.insert(new_input); |
883 | } |
884 | node_def->set_input(i, new_input.ToString()); |
885 | (*input_already_exists)[i] = true; |
886 | } |
887 | if (!inputs_to_remove.empty()) { |
888 | RemoveInputs(inputs_to_remove, node_def, input_already_exists); |
889 | } |
890 | } |
891 | |
892 | void GraphConstructor::AddControlDependencies( |
893 | NodeDef* node_def, std::vector<bool>* input_already_exists) { |
894 | // To avoid adding redundant control dependencies to every imported node, skip |
895 | // nodes that will inherit the dependencies from another imported node. |
896 | bool inherits_deps = false; |
897 | for (int i = 0; i < node_def->input_size(); ++i) { |
898 | // Assume we won't inherit dependencies from remapped inputs that already |
899 | // exist in the graph. Even if we're wrong, we'll only add redundant |
900 | // dependencies. |
901 | if ((*input_already_exists)[i]) continue; |
902 | |
903 | // If this input is a backedge, assume we won't inherit the dependencies. |
904 | // TODO(skyewm): we have many redundant ParseTensorName calls. It could be |
905 | // worth optimizing these. |
906 | TensorId id(ParseTensorName(node_def->input(i))); |
907 | auto iter = gdef_nodes_.find(id.first); |
908 | DCHECK(iter != gdef_nodes_.end()) << id.first; |
909 | if (iter->second.node == nullptr) { |
910 | // Input hasn't been created yet, indicating it's a backedge. |
911 | continue; |
912 | } |
913 | inherits_deps = true; |
914 | } |
915 | if (inherits_deps) return; |
916 | |
917 | // node_def either has no inputs or all remapped inputs, add the control |
918 | // dependencies |
919 | for (const string& control_dep : opts_.control_dependencies) { |
920 | string input = TensorId(control_dep, Graph::kControlSlot).ToString(); |
921 | bool found = false; |
922 | for (int i = node_def->input_size() - 1; i >= 0; --i) { |
923 | const string& node_input = node_def->input(i); |
924 | if (node_input[0] != '^') { |
925 | // Control inputs are at the end. Break when we reach the non-control |
926 | // inputs. |
927 | break; |
928 | } |
929 | if (node_input == input) { |
930 | // Control dependency already exists |
931 | found = true; |
932 | break; |
933 | } |
934 | } |
935 | if (found) { |
936 | continue; |
937 | } |
938 | node_def->add_input(input); |
939 | input_already_exists->push_back(true); |
940 | } |
941 | } |
942 | |
943 | void GraphConstructor::AddPrefixToNodeDef( |
944 | const std::vector<bool>& input_already_exists, NodeDef* node_def) { |
945 | if (prefix_.empty()) return; |
946 | node_def->set_name(strings::StrCat(prefix_, node_def->name())); |
947 | // Update names of input nodes |
948 | for (int i = 0; i < node_def->input_size(); ++i) { |
949 | // Skip remapped inputs (which already exist in g_ and are not being |
950 | // imported). |
951 | if (input_already_exists[i]) continue; |
952 | StringPiece input(node_def->input(i)); |
953 | if (absl::ConsumePrefix(&input, "^" )) { |
954 | node_def->set_input(i, strings::StrCat("^" , prefix_, input)); |
955 | } else { |
956 | node_def->set_input(i, strings::StrCat(prefix_, input)); |
957 | } |
958 | } |
959 | // Update names of colocation groups |
960 | if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) { |
961 | auto* list = |
962 | node_def->mutable_attr()->at(kColocationAttrName).mutable_list(); |
963 | for (int i = 0; i < list->s_size(); ++i) { |
964 | StringPiece v(list->s(i)); |
965 | if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) { |
966 | list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v)); |
967 | } |
968 | } |
969 | } |
970 | } |
971 | |
972 | void GraphConstructor::UniquifyNames( |
973 | const std::vector<bool>& input_already_exists, NodeDef* node_def) { |
974 | if (NameExistsInGraph(node_def->name())) { |
975 | string old_name = node_def->name(); |
976 | node_def->set_name(FindUniqueName(node_def->name())); |
977 | uniquified_names_[old_name] = node_def->name(); |
978 | // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with |
979 | // `name` because we guarantee the original NodeDef names are unique, |
980 | // meaning we won't generate this name again. |
981 | } |
982 | for (int i = 0; i < node_def->input_size(); ++i) { |
983 | // Skip remapped inputs (which already exist in g_ and are not being |
984 | // imported). |
985 | if (input_already_exists[i]) continue; |
986 | TensorId id = ParseTensorName(node_def->input(i)); |
987 | // We require that UniquifyNames() is called on all NodeDefs in topological |
988 | // order. This guarantees that node_def's inputs will already be uniquified |
989 | // if necessary. |
990 | auto iter = uniquified_names_.find(string(id.first)); |
991 | if (iter == uniquified_names_.end()) continue; |
992 | id.first = iter->second; |
993 | node_def->set_input(i, id.ToString()); |
994 | } |
995 | } |
996 | |
997 | void GraphConstructor::UpdateUniquifiedColocationNames() { |
998 | for (const auto& pair : gdef_nodes_) { |
999 | Node* node = pair.second.node; |
1000 | if (node == nullptr) continue; |
1001 | std::vector<string> coloc_values; |
1002 | if (!TryGetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values)) |
1003 | continue; |
1004 | bool updated = false; |
1005 | for (size_t i = 0; i < coloc_values.size(); ++i) { |
1006 | StringPiece val(coloc_values[i]); |
1007 | if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) { |
1008 | auto name_pair = uniquified_names_.find(string(val)); |
1009 | if (name_pair == uniquified_names_.end()) continue; |
1010 | updated = true; |
1011 | coloc_values[i] = |
1012 | strings::StrCat(kColocationGroupPrefix, name_pair->second); |
1013 | } |
1014 | } |
1015 | if (updated) { |
1016 | node->AddAttr(kColocationAttrName, std::move(coloc_values)); |
1017 | } |
1018 | } |
1019 | } |
1020 | |
1021 | bool GraphConstructor::NameExistsInGraph(StringPiece name) { |
1022 | if (existing_nodes_.find(name) != existing_nodes_.end()) return true; |
1023 | if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true; |
1024 | return false; |
1025 | } |
1026 | |
1027 | bool GraphConstructor::NameExistsInGraphDef(StringPiece name) { |
1028 | if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true; |
1029 | if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true; |
1030 | return false; |
1031 | } |
1032 | |
1033 | string GraphConstructor::FindUniqueName(StringPiece original_name) { |
1034 | string name(original_name); |
1035 | int count = 0; |
1036 | // Check that any generated names don't collide with imported NodeDefs (as |
1037 | // well as nodes in g_). |
1038 | while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) { |
1039 | name = strings::StrCat(original_name, "_" , ++count); |
1040 | } |
1041 | return name; |
1042 | } |
1043 | |
1044 | Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def, |
1045 | bool* is_node_mapped) { |
1046 | const OpDef* op_def; |
1047 | TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def)); |
1048 | for (int i = 0; i < op_def->output_arg_size(); ++i) { |
1049 | if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) { |
1050 | *is_node_mapped = false; |
1051 | return OkStatus(); |
1052 | } |
1053 | } |
1054 | *is_node_mapped = true; |
1055 | return OkStatus(); |
1056 | } |
1057 | |
1058 | void GraphConstructor::DFS(int cur_node, std::vector<int>* cur_branch, |
1059 | std::vector<bool>* is_on_cur_branch, |
1060 | absl::flat_hash_set<int>* unvisited, |
1061 | const std::vector<absl::string_view>& node_names) { |
1062 | cur_branch->push_back(cur_node); |
1063 | is_on_cur_branch->at(cur_node) = true; |
1064 | for (auto next_node : outputs_[cur_node]) { |
1065 | if (unvisited->find(next_node) != unvisited->end()) { |
1066 | if (is_on_cur_branch->at(next_node)) { |
1067 | auto iter = |
1068 | std::find(cur_branch->begin(), cur_branch->end(), next_node); |
1069 | LOG(WARNING) << "Cycle detected:" ; |
1070 | while (iter != cur_branch->end()) { |
1071 | const absl::string_view name = node_names[*iter]; |
1072 | DCHECK(!name.empty()); |
1073 | LOG(WARNING) << "node id=" << *iter << ", name=" << name; |
1074 | ++iter; |
1075 | } |
1076 | LOG(WARNING) << "End of cycle" ; |
1077 | } else { |
1078 | DFS(next_node, cur_branch, is_on_cur_branch, unvisited, node_names); |
1079 | } |
1080 | } |
1081 | } |
1082 | cur_branch->pop_back(); |
1083 | is_on_cur_branch->at(cur_node) = false; |
1084 | unvisited->erase(cur_node); |
1085 | } |
1086 | |
1087 | void GraphConstructor::PrintCycles() { |
1088 | int num_nodes = outputs_.size(); |
1089 | |
1090 | std::vector<absl::string_view> node_names; |
1091 | node_names.resize(num_nodes); |
1092 | for (const auto& named_node : gdef_nodes_) { |
1093 | DCHECK_GE(named_node.second.gdef_index, 0); |
1094 | DCHECK_LT(named_node.second.gdef_index, num_nodes); |
1095 | node_names[named_node.second.gdef_index] = named_node.first; |
1096 | } |
1097 | |
1098 | absl::flat_hash_set<int> unvisited; |
1099 | for (int i = 0; i < num_nodes; i++) { |
1100 | unvisited.insert(i); |
1101 | } |
1102 | |
1103 | while (!unvisited.empty()) { |
1104 | int cur_node = *unvisited.begin(); |
1105 | // Nodes on the current branch of DFS in traversal order. This is used for |
1106 | // printing the nodes in the cycle. |
1107 | std::vector<int> cur_branch; |
1108 | // This is just to make lookups O(1). |
1109 | // is_on_cur_branch[i] == |
1110 | // (std::find(cur_branch.start(), |
1111 | // cur_branch.end(), i) != cur_branch.end()) |
1112 | std::vector<bool> is_on_cur_branch(num_nodes, false); |
1113 | DFS(cur_node, &cur_branch, &is_on_cur_branch, &unvisited, node_names); |
1114 | } |
1115 | } |
1116 | |
// Creates g_ nodes and edges from the NodeDefs, in the topological order set
// up by InitFromEdges() (ready_, pending_count_, outputs_). The function
// library is added first because imported nodes may refer to its functions.
//
// For each NodeDef taken off ready_, this:
//   * optionally skips it (fully-mapped nodes under opts_.skip_mapped_nodes);
//   * applies the import rewrites (input remapping, extra control
//     dependencies, default device, name prefixing/uniquifying);
//   * resolves each input to a source Node and output index;
//   * creates the Node and adds its incoming edges, deferring edges whose
//     source does not exist yet to back_edges_;
//   * validates shapes and releases its consumers.
//
// Returns InvalidArgument for a reference to a nonexistent output, for a back
// edge into a non-Merge node, or when some NodeDefs never became ready (a
// cycle). Errors from the helper calls are propagated.
Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library()) {
    // TODO(b/135705010): Add rvalue overloads into the function library, to
    // avoid unnecessarily copying `*library()` here.
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library()));
  }

  // Resolved inputs of the node currently being converted; reused across
  // iterations.
  std::vector<InputInfo> inputs;
  // Number of NodeDefs taken off ready_ so far, including skipped ones.
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // inputs, pending_count_ with the number of inputs for each node and
  // outputs_ with the outputs of each node).
  while (!ready_.empty()) {
    int o = *ready_.begin();
    ready_.erase(ready_.begin());
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    NodeDef node_def = consume_node_def(o);

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_). Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(node_def.input_size(), false);

    // Keep the GraphDef name: node_def may be renamed below (prefix /
    // uniquify) and is later moved into MakeNode(), but gdef_nodes_ is keyed
    // by the original name.
    std::string node_name = node_def.name();

    if (opts_.importing) {
      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(IsNodeFullyMapped(node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(o, IsNextIteration(node_def));
          continue;
        }
      }

      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&node_def, &input_already_exists);
      }
      if (!opts_.default_device.empty() && node_def.device().empty()) {
        node_def.set_device(opts_.default_device);
      }
    }

    DCHECK_EQ(node_def.input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(node_def));
    // Resolve every input to its source Node and output index.
    for (int i = 0; i < node_def.input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node_def.input(i));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(tensor_id.node());
        DCHECK(iter != gdef_nodes_.end()) << tensor_id.node();
        src_node = iter->second.node;
        src_index = tensor_id.index();
        // The source Node has not been created yet, so this input is a back
        // edge. It is recorded in back_edges_ below.
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to preexisting node in graph
        auto iter = existing_nodes_.find(tensor_id.node());
        DCHECK(iter != existing_nodes_.end()) << tensor_id.node();
        src_node = iter->second;
        src_index = tensor_id.index();
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        std::ostringstream out;
        out << "Node '" << node_def.name() << "': Connecting to invalid output "
            << tensor_id.index() << " of source node " << tensor_id.node()
            << " which has " << src_node->num_outputs() << " outputs.";

        // For functional control-flow sources, point the user at
        // output_all_intermediates (presumably the referenced output is an
        // intermediate that was not exported).
        if (src_node->type_string() == "If" ||
            src_node->type_string() == "StatelessIf" ||
            src_node->type_string() == "While" ||
            src_node->type_string() == "StatelessWhile") {
          out << " Try using "
              << "tf.compat.v1.experimental.output_all_intermediates(True).";
        }
        return errors::InvalidArgument(out.str());
      }

      inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
    }

    if (has_data_back_edge && !IsMerge(node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &node_def);
      }
    }

    if (opts_.importing) {
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
    } else {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
      if (opts_.add_default_attributes) {
        AddDefaultsToNodeDef(*op_def, &node_def);
      }
      if (opts_.validate_nodes) {
        TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
      }
    }

    TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));

    gdef_nodes_[node_name].node = node;

    // Remove duplicate control inputs before adding edges to the graph. It
    // will allow us to skip expensive duplicates check in 'AddControlEdge'.
    // This sorts only the range starting at the first control input, so it
    // relies on all control inputs coming after the data inputs.
    auto first_control = absl::c_find_if(inputs, &InputInfo::IsControlInput);
    auto first_control_copy = first_control;
    std::sort(first_control, inputs.end(), &InputInfo::CompareName);
    inputs.erase(
        std::unique(first_control_copy, inputs.end(), &InputInfo::IsSameName),
        inputs.end());

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node, kDoNotCheckDuplicates);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    TF_RETURN_IF_ERROR(ValidateShape(node));

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(o, node->IsNextIteration());
  }

  // A NodeDef that was never taken off ready_ still has unmet inputs, which
  // means it is in, or depends on, a cycle. Log the stuck nodes and the
  // cycles, then fail.
  if (processed < node_def_count()) {
    LOG(WARNING) << "IN " << __func__ << " " << (node_def_count() - processed)
                 << " NODES IN A CYCLE";
    for (int64_t i = 0; i < node_def_count(); i++) {
      if (pending_count_[i] != 0) {
        LOG(WARNING) << "PENDING: " << SummarizeNodeDef(get_node_def(i))
                     << " WITH PENDING COUNT = " << pending_count_[i];
      }
    }
    PrintCycles();
    return errors::InvalidArgument(node_def_count() - processed,
                                   " nodes in a cycle");
  }

  return OkStatus();
}
1298 | |
1299 | Status GraphConstructor::AddBackEdges() { |
1300 | // Add the back edges after all nodes are created. |
1301 | for (const auto& e : back_edges_) { |
1302 | Node* src_node = gdef_nodes_[e.src_name].node; |
1303 | if (e.src_index == Graph::kControlSlot) { |
1304 | g_->AddControlEdge(src_node, e.dst_node, kDoNotCheckDuplicates); |
1305 | } else { |
1306 | TF_RETURN_IF_ERROR( |
1307 | MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index)); |
1308 | } |
1309 | |
1310 | VLOG(2) << "Add back edge: " << src_node->name() << " -> " |
1311 | << e.dst_node->name(); |
1312 | } |
1313 | return OkStatus(); |
1314 | } |
1315 | |
1316 | Status GraphConstructor::UpdateVersionDef() { |
1317 | if (versions() == nullptr) return OkStatus(); |
1318 | |
1319 | if (!opts_.importing) { |
1320 | g_->set_versions(*versions()); |
1321 | return OkStatus(); |
1322 | } |
1323 | VersionDef g_versions = g_->versions(); |
1324 | g_versions.set_producer( |
1325 | std::min(g_versions.producer(), versions()->producer())); |
1326 | g_versions.set_min_consumer( |
1327 | std::max(g_versions.min_consumer(), versions()->min_consumer())); |
1328 | if (versions()->bad_consumers_size() > 0) { |
1329 | std::set<int> bad(g_versions.bad_consumers().begin(), |
1330 | g_versions.bad_consumers().end()); |
1331 | bad.insert(versions()->bad_consumers().begin(), |
1332 | versions()->bad_consumers().end()); |
1333 | g_versions.clear_bad_consumers(); |
1334 | for (int v : bad) { |
1335 | g_versions.add_bad_consumers(v); |
1336 | } |
1337 | } |
1338 | g_->set_versions(g_versions); |
1339 | return OkStatus(); |
1340 | } |
1341 | |
1342 | Status GraphConstructor::PopulateReturnTensors() { |
1343 | if (opts_.return_tensors.empty()) return OkStatus(); |
1344 | for (const TensorId& id : opts_.return_tensors) { |
1345 | auto iter = opts_.input_map.find(id); |
1346 | if (iter == opts_.input_map.end()) { |
1347 | // Locate id in imported nodes |
1348 | auto iter = gdef_nodes_.find(id.first); |
1349 | if (iter == gdef_nodes_.end()) { |
1350 | return errors::InvalidArgument("Requested return tensor '" , |
1351 | id.ToString(), |
1352 | "' not found in graph def" ); |
1353 | } |
1354 | int num_outputs = iter->second.node->num_outputs(); |
1355 | if ((id.second < 0 || id.second >= num_outputs) && |
1356 | id.second != Graph::kControlSlot) { |
1357 | return errors::InvalidArgument("Invalid return output " , id.second, |
1358 | " of node '" , id.first, "', which has " , |
1359 | num_outputs, " output(s)" ); |
1360 | } |
1361 | return_tensors_->push_back({iter->second.node, id.second}); |
1362 | } else { |
1363 | // id was remapped to existing node |
1364 | TensorId remapped_id = iter->second; |
1365 | DCHECK_GT(existing_nodes_.count(remapped_id.first), 0); |
1366 | Node* node = existing_nodes_[remapped_id.first]; |
1367 | return_tensors_->push_back({node, remapped_id.second}); |
1368 | } |
1369 | } |
1370 | return OkStatus(); |
1371 | } |
1372 | |
1373 | Status GraphConstructor::PopulateReturnNodes() { |
1374 | if (opts_.return_nodes.empty()) return OkStatus(); |
1375 | for (StringPiece name : opts_.return_nodes) { |
1376 | auto iter = gdef_nodes_.find(name); |
1377 | if (iter == gdef_nodes_.end()) { |
1378 | return errors::InvalidArgument("Requested return node '" , name, |
1379 | "' not found in graph def" ); |
1380 | } |
1381 | return_nodes_->push_back(iter->second.node); |
1382 | } |
1383 | return OkStatus(); |
1384 | } |
1385 | |
1386 | Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { |
1387 | if (missing_unused_input_map_keys_ == nullptr) return OkStatus(); |
1388 | for (const auto& input_map_pair : opts_.input_map) { |
1389 | TensorId key = input_map_pair.first; |
1390 | if (used_input_map_keys_.count(key) > 0) continue; |
1391 | |
1392 | auto pair = gdef_nodes_.find(key.first); |
1393 | if (pair == gdef_nodes_.end()) { |
1394 | // key's node doesn't exist in GraphDef |
1395 | missing_unused_input_map_keys_->push_back(key); |
1396 | continue; |
1397 | } |
1398 | |
1399 | // Check that key's index is in bounds. Get the number of outputs from the |
1400 | // NodeDef, rather than the imported Node, since the Node may not exist if |
1401 | // opts_.skip_mapped_nodes is true. |
1402 | const NodeDef& node_def = get_node_def(pair->second.gdef_index); |
1403 | const OpDef* op_def; |
1404 | TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def)); |
1405 | int num_outputs; |
1406 | TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs)); |
1407 | if (key.second >= num_outputs) { |
1408 | // key's index out of bounds |
1409 | missing_unused_input_map_keys_->push_back(key); |
1410 | } |
1411 | } |
1412 | return OkStatus(); |
1413 | } |
1414 | |
1415 | void GraphConstructor::Undo() { |
1416 | for (const auto& iter : gdef_nodes_) { |
1417 | if (iter.second.node != nullptr) { |
1418 | g_->RemoveNode(iter.second.node); |
1419 | } |
1420 | } |
1421 | g_->set_versions(original_versions_); |
1422 | } |
1423 | |
1424 | Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, |
1425 | int input_index) { |
1426 | if (output_index >= src->num_outputs()) { |
1427 | return errors::InvalidArgument( |
1428 | "Output " , output_index, " of node " , src->name(), |
1429 | " does not exist. Node only has " , src->num_outputs(), " outputs." ); |
1430 | } |
1431 | if (input_index >= dst->num_inputs()) { |
1432 | return errors::InvalidArgument( |
1433 | "Input " , input_index, " of node " , dst->name(), |
1434 | " does not exist. Node only has " , dst->num_inputs(), " inputs." ); |
1435 | } |
1436 | |
1437 | DataType src_out = src->output_type(output_index); |
1438 | DataType dst_in = dst->input_type(input_index); |
1439 | if (!TypesCompatible(dst_in, src_out)) { |
1440 | return errors::InvalidArgument( |
1441 | "Input " , input_index, " of node " , dst->name(), " was passed " , |
1442 | DataTypeString(src_out), " from " , src->name(), ":" , output_index, |
1443 | " incompatible with expected " , DataTypeString(dst_in), "." ); |
1444 | } |
1445 | g_->AddEdge(src, output_index, dst, input_index); |
1446 | return OkStatus(); |
1447 | } |
1448 | |
1449 | } // namespace |
1450 | |
1451 | Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, |
1452 | const GraphDef& gdef, Graph* g) { |
1453 | ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); |
1454 | return GraphConstructor::Construct( |
1455 | opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner, |
1456 | /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, |
1457 | /*missing_unused_input_map_keys=*/nullptr); |
1458 | } |
1459 | |
1460 | Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, |
1461 | GraphDef&& gdef, Graph* g) { |
1462 | ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); |
1463 | return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner, |
1464 | /*return_tensors=*/nullptr, |
1465 | /*return_nodes=*/nullptr, |
1466 | /*missing_unused_input_map_keys=*/nullptr); |
1467 | } |
1468 | |
1469 | Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, |
1470 | gtl::ArraySlice<NodeDef> nodes, Graph* g) { |
1471 | ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry()); |
1472 | // TODO(irving): Copy will go away once NodeInfo exists |
1473 | std::vector<const NodeDef*> node_defs; |
1474 | node_defs.reserve(nodes.size()); |
1475 | for (const auto& n : nodes) { |
1476 | node_defs.push_back(&n); |
1477 | } |
1478 | return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g, |
1479 | &refiner, /*return_tensors=*/nullptr, |
1480 | /*return_nodes=*/nullptr, |
1481 | /*missing_unused_input_map_keys=*/nullptr); |
1482 | } |
1483 | |
1484 | Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, |
1485 | Graph* g, ShapeRefiner* refiner, |
1486 | ImportGraphDefResults* results) { |
1487 | if (!opts.return_tensors.empty()) { |
1488 | if (results == nullptr) { |
1489 | return errors::InvalidArgument( |
1490 | "results argument to ImportGraphDef() must be non-null if " |
1491 | "opts.return_tensors is non-empty" ); |
1492 | } |
1493 | } |
1494 | |
1495 | if (!opts.return_nodes.empty()) { |
1496 | if (opts.skip_mapped_nodes) { |
1497 | return errors::InvalidArgument( |
1498 | "Requesting return_nodes with skip_mapped_nodes set is not currently " |
1499 | "supported" ); |
1500 | } |
1501 | if (results == nullptr) { |
1502 | return errors::InvalidArgument( |
1503 | "results argument to ImportGraphDef() must be non-null if " |
1504 | "opts.return_nodes is non-empty" ); |
1505 | } |
1506 | } |
1507 | |
1508 | if (results != nullptr) { |
1509 | if (!results->return_tensors.empty() || !results->return_nodes.empty() || |
1510 | !results->missing_unused_input_map_keys.empty()) { |
1511 | return errors::InvalidArgument( |
1512 | "All fields in results argument to ImportGraphDef() must be empty." ); |
1513 | } |
1514 | } |
1515 | |
1516 | ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry()); |
1517 | if (refiner == nullptr) { |
1518 | refiner = &default_refiner; |
1519 | } else { |
1520 | // Log a warning if we are importing a GraphDef at an older |
1521 | // producer version after already having added non-source/sink |
1522 | // nodes to the graph in the past. |
1523 | if (gdef.versions().producer() > 0 && |
1524 | gdef.versions().producer() < refiner->graph_def_version() && |
1525 | g->num_nodes() > 2) { |
1526 | LOG(WARNING) << "Importing a graph with a lower producer version " |
1527 | << gdef.versions().producer() |
1528 | << " into an existing graph with producer version " |
1529 | << refiner->graph_def_version() << ". Shape inference will " |
1530 | << "have run different parts of the graph with different " |
1531 | << "producer versions." ; |
1532 | } |
1533 | } |
1534 | |
1535 | // Set the graph def version of the refiner as the min of the |
1536 | // current value and the version from the graph we are about to |
1537 | // import. |
1538 | // |
1539 | // Note: to match Run() semantics, we should re-run shape inference |
1540 | // on the entire graph if the producer version has changed. For now |
1541 | // we log the warning above. |
1542 | refiner->set_graph_def_version( |
1543 | std::min(refiner->graph_def_version(), gdef.versions().producer())); |
1544 | |
1545 | if (results == nullptr) { |
1546 | return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), |
1547 | &gdef.library(), g, refiner, nullptr, |
1548 | nullptr, nullptr); |
1549 | } else { |
1550 | return GraphConstructor::Construct( |
1551 | opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, |
1552 | &results->return_tensors, &results->return_nodes, |
1553 | &results->missing_unused_input_map_keys); |
1554 | } |
1555 | } |
1556 | |
1557 | void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); } |
1558 | |
1559 | } // namespace tensorflow |
1560 | |