1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
17 | |
18 | #include "tensorflow/core/framework/node_def_util.h" |
19 | #include "tensorflow/core/framework/op.h" |
20 | #include "tensorflow/core/lib/hash/hash.h" |
21 | #include "tensorflow/core/lib/strings/numbers.h" |
22 | #include "tensorflow/core/lib/strings/str_util.h" |
23 | |
24 | namespace tensorflow { |
25 | namespace graph_transforms { |
26 | |
27 | namespace { |
28 | inline bool IsMerge(const NodeDef& node_def) { |
29 | return node_def.op() == "Merge" || node_def.op() == "RefMerge" || |
30 | node_def.op() == "_XlaMerge" ; |
31 | } |
32 | |
33 | void RecordMatchedNodes(const NodeMatch& match, |
34 | std::set<string>* matched_nodes) { |
35 | matched_nodes->insert(match.node.name()); |
36 | for (const NodeMatch& input_match : match.inputs) { |
37 | RecordMatchedNodes(input_match, matched_nodes); |
38 | } |
39 | } |
40 | |
41 | inline uint64 Hash64String(const string& input) { |
42 | return Hash64(input.data(), input.size()); |
43 | } |
44 | } // namespace |
45 | |
46 | void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) { |
47 | std::set<string> found_nodes; |
48 | std::vector<NodeMatch> current_matches = {match}; |
49 | while (!current_matches.empty()) { |
50 | std::vector<NodeMatch> next_matches; |
51 | for (const NodeMatch& current_match : current_matches) { |
52 | if (found_nodes.count(current_match.node.name())) { |
53 | continue; |
54 | } |
55 | found_nodes.insert(current_match.node.name()); |
56 | result->push_back(current_match.node); |
57 | for (const NodeMatch& input_match : current_match.inputs) { |
58 | next_matches.push_back(input_match); |
59 | } |
60 | } |
61 | current_matches = next_matches; |
62 | } |
63 | } |
64 | |
65 | void MapNamesToNodes(const GraphDef& graph_def, |
66 | std::map<string, const NodeDef*>* result) { |
67 | for (const NodeDef& node : graph_def.node()) { |
68 | (*result)[node.name()] = &node; |
69 | } |
70 | } |
71 | |
72 | void MapNodesToOutputs(const GraphDef& graph_def, |
73 | std::map<string, std::vector<const NodeDef*>>* result) { |
74 | std::map<string, const NodeDef*> node_map; |
75 | MapNamesToNodes(graph_def, &node_map); |
76 | for (const NodeDef& node : graph_def.node()) { |
77 | for (const string& input : node.input()) { |
78 | string input_node_name = NodeNameFromInput(input); |
79 | (*result)[input_node_name].push_back(&node); |
80 | } |
81 | } |
82 | } |
83 | |
84 | void NodeNamePartsFromInput(const string& input_name, string* prefix, |
85 | string* node_name, string* suffix) { |
86 | std::vector<string> input_parts = str_util::Split(input_name, ':'); |
87 | if (input_parts.size() < 2) { |
88 | *suffix = "" ; |
89 | } else { |
90 | *suffix = ":" + input_parts[1]; |
91 | } |
92 | StringPiece node_name_piece(input_parts[0]); |
93 | if (absl::ConsumePrefix(&node_name_piece, "^" )) { |
94 | *prefix = "^" ; |
95 | } else { |
96 | *prefix = "" ; |
97 | } |
98 | *node_name = string(node_name_piece); |
99 | } |
100 | |
101 | string NodeNameFromInput(const string& input_name) { |
102 | string prefix; |
103 | string node_name; |
104 | string suffix; |
105 | NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix); |
106 | return node_name; |
107 | } |
108 | |
109 | string CanonicalInputName(const string& input_name) { |
110 | string prefix; |
111 | string node_name; |
112 | string suffix; |
113 | NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix); |
114 | if (suffix.empty()) { |
115 | suffix = ":0" ; |
116 | } |
117 | return prefix + node_name + suffix; |
118 | } |
119 | |
120 | uint64 HashNodeDef(const NodeDef& node) { |
121 | uint64 hash = Hash64String(node.op()); |
122 | hash = Hash64Combine(hash, Hash64String(node.name())); |
123 | for (const string& input : node.input()) { |
124 | hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input))); |
125 | } |
126 | hash = Hash64Combine(hash, Hash64String(node.device())); |
127 | std::vector<string> attr_names; |
128 | attr_names.reserve(node.attr().size()); |
129 | for (const auto& attr : node.attr()) { |
130 | attr_names.push_back(attr.first); |
131 | } |
132 | std::sort(attr_names.begin(), attr_names.end()); |
133 | string attr_serialized; |
134 | for (const string& attr_name : attr_names) { |
135 | auto attr = node.attr().at(attr_name); |
136 | attr.SerializeToString(&attr_serialized); |
137 | hash = Hash64Combine(hash, Hash64String(attr_serialized)); |
138 | } |
139 | return hash; |
140 | } |
141 | |
142 | void AddNodeInput(const string& input_name, NodeDef* node) { |
143 | *(node->mutable_input()->Add()) = input_name; |
144 | } |
145 | |
146 | void CopyNodeAttr(const NodeDef& source, const string& source_key, |
147 | const string& dest_key, NodeDef* dest) { |
148 | CHECK_NE(0, source.attr().count(source_key)) |
149 | << "No key '" << source_key << "' found in " << source.DebugString(); |
150 | (*(dest->mutable_attr()))[dest_key] = source.attr().at(source_key); |
151 | } |
152 | |
153 | Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) { |
154 | TensorProto tensor_proto = node.attr().at(key).tensor(); |
155 | Tensor tensor; |
156 | CHECK(tensor.FromProto(tensor_proto)); |
157 | return tensor; |
158 | } |
159 | |
160 | void FilterGraphDef(const GraphDef& input_graph_def, |
161 | std::function<bool(const NodeDef&)> selector, |
162 | GraphDef* output_graph_def) { |
163 | output_graph_def->mutable_node()->Clear(); |
164 | for (const NodeDef& node : input_graph_def.node()) { |
165 | if (selector(node)) { |
166 | *output_graph_def->mutable_node()->Add() = node; |
167 | } |
168 | } |
169 | } |
170 | |
171 | void RemoveAttributes(const GraphDef& input_graph_def, |
172 | const std::vector<string>& attributes, |
173 | GraphDef* output_graph_def) { |
174 | output_graph_def->mutable_node()->Clear(); |
175 | for (const NodeDef& node : input_graph_def.node()) { |
176 | NodeDef* new_node = output_graph_def->mutable_node()->Add(); |
177 | *new_node = node; |
178 | for (const string& attribute : attributes) { |
179 | new_node->mutable_attr()->erase(attribute); |
180 | } |
181 | } |
182 | } |
183 | |
184 | Status SortByExecutionOrder(const GraphDef& input_graph_def, |
185 | GraphDef* output_graph_def) { |
186 | const int num_nodes = input_graph_def.node_size(); |
187 | std::vector<int> ready; |
188 | std::vector<int> pending_count; |
189 | pending_count.reserve(num_nodes); |
190 | std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes); |
191 | |
192 | std::map<string, int> name_index; |
193 | for (int i = 0; i < input_graph_def.node_size(); ++i) { |
194 | const NodeDef& node(input_graph_def.node(i)); |
195 | name_index[node.name()] = i; |
196 | } |
197 | |
198 | // Parse the inputs for each node. |
199 | for (int n = 0; n < num_nodes; ++n) { |
200 | const NodeDef& node_def(input_graph_def.node(n)); |
201 | if (IsMerge(node_def)) { |
202 | // for merge only wait for one non-control input. |
203 | int32_t num_control_edges = 0; |
204 | for (int i = 0; i < node_def.input_size(); ++i) { |
205 | if (absl::StartsWith(node_def.input(i), "^" )) { |
206 | num_control_edges++; |
207 | } |
208 | } |
209 | pending_count.push_back(num_control_edges + 1); |
210 | } else { |
211 | pending_count.push_back(node_def.input_size()); |
212 | } |
213 | if (node_def.input_size() == 0) { |
214 | ready.push_back(n); |
215 | continue; |
216 | } |
217 | for (int i = 0; i < node_def.input_size(); ++i) { |
218 | const string& input_name = node_def.input(i); |
219 | const string& input_node_name = NodeNameFromInput(input_name); |
220 | if (!name_index.count(input_node_name)) { |
221 | return errors::InvalidArgument("Node '" , node_def.name(), |
222 | "': Unknown input node '" , |
223 | node_def.input(i), "'" ); |
224 | } |
225 | outputs[name_index[input_node_name]].push_back(n); |
226 | } |
227 | } |
228 | |
229 | int processed = 0; |
230 | output_graph_def->Clear(); |
231 | // Process the NodeDefs in topological order. |
232 | // Code above sets this up by filling in ready_ with nodes that have no |
233 | // inputs, pending_counts_ with the number of inputs for each node and |
234 | // outputs_ with the outputs of each node. |
235 | while (!ready.empty()) { |
236 | int o = ready.back(); |
237 | ready.pop_back(); |
238 | ++processed; |
239 | const NodeDef& node_def(input_graph_def.node(o)); |
240 | *output_graph_def->mutable_node()->Add() = node_def; |
241 | |
242 | // Update pending_count for outputs. |
243 | for (size_t i = 0; i < outputs[o].size(); ++i) { |
244 | const int output = outputs[o][i]; |
245 | pending_count[output]--; |
246 | if (pending_count[output] == 0) { |
247 | ready.push_back(output); |
248 | } |
249 | } |
250 | } |
251 | |
252 | if (processed < num_nodes) { |
253 | LOG(WARNING) << "IN " << __func__ << (num_nodes - processed) |
254 | << " NODES IN A CYCLE" ; |
255 | for (int64_t i = 0; i < num_nodes; i++) { |
256 | if (pending_count[i] != 0) { |
257 | LOG(WARNING) << "PENDING: " << SummarizeNodeDef(input_graph_def.node(i)) |
258 | << "WITH PENDING COUNT = " << pending_count[i]; |
259 | } |
260 | } |
261 | return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle" ); |
262 | } |
263 | return OkStatus(); |
264 | } |
265 | |
266 | string OpTypePattern::DebugString() const { |
267 | string result = "{" + op + ", {" ; |
268 | for (const OpTypePattern& input : inputs) { |
269 | result += input.DebugString() + "," ; |
270 | } |
271 | result += "}}" ; |
272 | return result; |
273 | } |
274 | |
275 | string NodeMatch::DebugString() const { |
276 | string result = "{" ; |
277 | result += node.DebugString(); |
278 | result += ", {" ; |
279 | for (const NodeMatch& input : inputs) { |
280 | result += input.DebugString() + "," ; |
281 | } |
282 | result += "}}" ; |
283 | return result; |
284 | } |
285 | |
286 | GraphMatcher::GraphMatcher(const GraphDef& graph_def) { |
287 | SortByExecutionOrder(graph_def, &graph_def_).IgnoreError(); |
288 | MapNamesToNodes(graph_def_, &node_map_); |
289 | } |
290 | |
291 | Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern, |
292 | std::vector<NodeMatch>* matches) { |
293 | std::set<string> matched_nodes; |
294 | for (const NodeDef& node : graph_def_.node()) { |
295 | // Skip any nodes that are already part of a match. |
296 | if (matched_nodes.count(node.name())) { |
297 | continue; |
298 | } |
299 | NodeMatch match; |
300 | if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) { |
301 | RecordMatchedNodes(match, &matched_nodes); |
302 | matches->push_back(match); |
303 | } |
304 | } |
305 | return OkStatus(); |
306 | } |
307 | |
308 | bool GraphMatcher::DoesOpTypeMatch( |
309 | const NodeDef& node, const OpTypePattern& pattern, |
310 | const std::set<string>& previously_matched_nodes, NodeMatch* match) { |
311 | VLOG(1) << "Looking at node " << node.DebugString(); |
312 | VLOG(1) << "pattern=" << pattern.DebugString(); |
313 | VLOG(1) << "match=" << match->DebugString(); |
314 | if (previously_matched_nodes.count(node.name())) { |
315 | VLOG(1) << "node " << node.name() << " has been previously matched" ; |
316 | return false; |
317 | } |
318 | bool pattern_matched = false; |
319 | if (pattern.op == "*" ) { |
320 | pattern_matched = true; |
321 | } else { |
322 | std::vector<string> pattern_ops = str_util::Split(pattern.op, '|'); |
323 | for (const string& pattern_op : pattern_ops) { |
324 | if (node.op() == pattern_op) { |
325 | pattern_matched = true; |
326 | } |
327 | } |
328 | } |
329 | if (!pattern_matched) { |
330 | VLOG(1) << "node.op() != pattern.op()" ; |
331 | return false; |
332 | } |
333 | match->node = node; |
334 | // Ignore any control inputs for pattern-matching purposes |
335 | std::vector<string> non_control_inputs; |
336 | for (const string& input : node.input()) { |
337 | if (!input.empty() && (input[0] != '^')) { |
338 | non_control_inputs.push_back(input); |
339 | } |
340 | } |
341 | if (pattern.inputs.empty()) { |
342 | // If there are no inputs, assume that's the end of the pattern. |
343 | return true; |
344 | } |
345 | if (non_control_inputs.size() != pattern.inputs.size()) { |
346 | VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()" ; |
347 | return false; |
348 | } |
349 | for (int i = 0; i < pattern.inputs.size(); ++i) { |
350 | const string& input_node_name = NodeNameFromInput(non_control_inputs[i]); |
351 | const NodeDef& input_node = *(node_map_[input_node_name]); |
352 | const OpTypePattern& input_pattern = pattern.inputs[i]; |
353 | match->inputs.push_back(NodeMatch()); |
354 | NodeMatch* input_match = &(match->inputs.back()); |
355 | if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes, |
356 | input_match)) { |
357 | return false; |
358 | } |
359 | } |
360 | return true; |
361 | } |
362 | |
363 | Status ReplaceMatchingOpTypes( |
364 | const GraphDef& input_graph_def, const OpTypePattern& pattern, |
365 | const std::function<Status(const NodeMatch&, const std::set<string>&, |
366 | const std::set<string>&, std::vector<NodeDef>*)>& |
367 | node_generator, |
368 | const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) { |
369 | // Start off by retrieving all the matching subgraphs. |
370 | GraphMatcher matcher(input_graph_def); |
371 | std::vector<NodeMatch> matches; |
372 | TF_RETURN_IF_ERROR(matcher.GetOpTypeMatches(pattern, &matches)); |
373 | |
374 | // Do some housekeeping so we can easily look up the resulting matches given |
375 | // a node name. |
376 | std::set<string> matched_nodes; |
377 | std::map<string, const NodeMatch*> matches_by_head_name; |
378 | for (const NodeMatch& match : matches) { |
379 | matches_by_head_name[match.node.name()] = &match; |
380 | RecordMatchedNodes(match, &matched_nodes); |
381 | } |
382 | std::map<string, std::vector<const NodeDef*>> outputs_map; |
383 | MapNodesToOutputs(input_graph_def, &outputs_map); |
384 | |
385 | // Go through all the nodes in the input graph, see if they are part of a |
386 | // match or if they can be left untouched. |
387 | output_graph_def->Clear(); |
388 | for (const NodeDef& input_node : input_graph_def.node()) { |
389 | if (matches_by_head_name.count(input_node.name())) { |
390 | // This node is the beginning of a match, so call the replacement function |
391 | // after setting up some information it will need. |
392 | const NodeMatch* match = matches_by_head_name[input_node.name()]; |
393 | std::vector<NodeDef> matched_nodes_array; |
394 | MatchedNodesAsArray(*match, &matched_nodes_array); |
395 | // This tells us whether a node is part of the current match. |
396 | std::set<string> matched_nodes_lookup; |
397 | for (const NodeDef& matched_node : matched_nodes_array) { |
398 | matched_nodes_lookup.insert(matched_node.name()); |
399 | } |
400 | // These are helper arrays that the replacement function can use to tell |
401 | // whether it can safely remove an internal node (because nothing outside |
402 | // of the match uses it) or whether external nodes depend on it. |
403 | std::set<string> input_nodes; |
404 | std::set<string> output_nodes; |
405 | for (const NodeDef& matched_node : matched_nodes_array) { |
406 | // Look through all of this node's inputs, and if any of them come from |
407 | // outside the match, then this should be noted as one of the external |
408 | // inputs of the subgraph. |
409 | for (const string& input_name : matched_node.input()) { |
410 | string input_node_name = NodeNameFromInput(input_name); |
411 | if (!matched_nodes_lookup.count(input_node_name)) { |
412 | input_nodes.insert(matched_node.name()); |
413 | } |
414 | } |
415 | // Do a reverse input lookup, to see which other nodes use the current |
416 | // one as an input. If any of those nodes are outside the match |
417 | // subgraph, then the current node is marked as an output node that |
418 | // shouldn't be removed. |
419 | if (outputs_map.count(matched_node.name())) { |
420 | for (const NodeDef* dependent_node : |
421 | outputs_map[matched_node.name()]) { |
422 | if (!matched_nodes_lookup.count(dependent_node->name())) { |
423 | output_nodes.insert(matched_node.name()); |
424 | } |
425 | } |
426 | } |
427 | } |
428 | // Call the generator function and add all the returned nodes to the |
429 | // graph. |
430 | std::vector<NodeDef> new_nodes; |
431 | TF_RETURN_IF_ERROR( |
432 | node_generator(*match, input_nodes, output_nodes, &new_nodes)); |
433 | std::set<string> new_node_names; |
434 | for (const NodeDef& new_node : new_nodes) { |
435 | new_node_names.insert(new_node.name()); |
436 | } |
437 | // Check to make sure the generator function preserved all of the nodes |
438 | // that are used elsewhere in the graph, and add them back in if not. |
439 | bool abort_replacement = false; |
440 | if (!options.allow_inconsistencies) { |
441 | for (const string& expected_output : output_nodes) { |
442 | if (!new_node_names.count(expected_output)) { |
443 | LOG(WARNING) << "Expected " << expected_output |
444 | << " to be preserved." ; |
445 | abort_replacement = true; |
446 | } |
447 | } |
448 | } |
449 | if (abort_replacement) { |
450 | LOG(WARNING) << "Generator function didn't preserve needed nodes, " |
451 | << "copying old replacements back in instead." ; |
452 | std::vector<NodeDef> old_nodes; |
453 | MatchedNodesAsArray(*match, &old_nodes); |
454 | for (const NodeDef& old_node : old_nodes) { |
455 | NodeDef* added_node = output_graph_def->mutable_node()->Add(); |
456 | *added_node = old_node; |
457 | } |
458 | } else { |
459 | for (const NodeDef& new_node : new_nodes) { |
460 | NodeDef* added_node = output_graph_def->mutable_node()->Add(); |
461 | *added_node = new_node; |
462 | } |
463 | } |
464 | } else if (!matched_nodes.count(input_node.name())) { |
465 | // This node isn't part of any match, so just copy it over. |
466 | NodeDef* added_node = output_graph_def->mutable_node()->Add(); |
467 | *added_node = input_node; |
468 | } else { |
469 | // Do nothing, because this is an internal part of a matching subgraph, |
470 | // and so will have been replaced by a new replacement subgraph. |
471 | } |
472 | } |
473 | |
474 | return OkStatus(); |
475 | } |
476 | |
477 | Status RenameNodeInputs(const GraphDef& input_graph_def, |
478 | const std::map<string, string>& inputs_to_rename, |
479 | const std::unordered_set<string>& nodes_to_ignore, |
480 | GraphDef* output_graph_def) { |
481 | std::map<string, std::vector<std::pair<string, string>>> |
482 | canonical_inputs_to_rename; |
483 | for (const auto& input_to_rename : inputs_to_rename) { |
484 | canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)] |
485 | .push_back({input_to_rename.first, input_to_rename.second}); |
486 | } |
487 | |
488 | output_graph_def->Clear(); |
489 | for (const NodeDef& node : input_graph_def.node()) { |
490 | NodeDef* new_node = output_graph_def->mutable_node()->Add(); |
491 | *new_node = node; |
492 | new_node->mutable_input()->Clear(); |
493 | for (const string& input_name : node.input()) { |
494 | std::set<string> already_visited; |
495 | string new_input_name = input_name; |
496 | while ( |
497 | canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) { |
498 | string input_node_name = NodeNameFromInput(new_input_name); |
499 | if (already_visited.count(input_node_name)) { |
500 | return errors::InvalidArgument( |
501 | "RenameNodeInputs argument contains a cycle for " , |
502 | input_node_name); |
503 | } |
504 | already_visited.insert(input_node_name); |
505 | if (nodes_to_ignore.count(node.name())) { |
506 | break; |
507 | } |
508 | bool any_match_found = false; |
509 | for (const std::pair<string, string>& input_to_rename : |
510 | canonical_inputs_to_rename.at(input_node_name)) { |
511 | const string& source_name = input_to_rename.first; |
512 | const string& dest_name = input_to_rename.second; |
513 | bool is_match; |
514 | string match_name; |
515 | if (str_util::EndsWith(source_name, ":*" )) { |
516 | is_match = true; |
517 | string prefix; |
518 | string unused_node_name; |
519 | string suffix; |
520 | NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name, |
521 | &suffix); |
522 | match_name = prefix + dest_name + suffix; |
523 | } else { |
524 | is_match = (CanonicalInputName(source_name) == |
525 | CanonicalInputName(new_input_name)); |
526 | match_name = dest_name; |
527 | } |
528 | if (is_match) { |
529 | new_input_name = match_name; |
530 | any_match_found = true; |
531 | } |
532 | } |
533 | if (!any_match_found) { |
534 | break; |
535 | } |
536 | } |
537 | *(new_node->mutable_input()->Add()) = new_input_name; |
538 | } |
539 | } |
540 | return OkStatus(); |
541 | } |
542 | |
543 | void CopyOriginalMatch(const NodeMatch& match, |
544 | std::vector<NodeDef>* new_nodes) { |
545 | std::vector<NodeDef> old_nodes; |
546 | MatchedNodesAsArray(match, &old_nodes); |
547 | for (const NodeDef& old_node : old_nodes) { |
548 | new_nodes->push_back(old_node); |
549 | } |
550 | } |
551 | |
552 | TransformRegistry* GetTransformRegistry() { |
553 | static TransformRegistry transform_registry; |
554 | return &transform_registry; |
555 | } |
556 | |
557 | void FindInvalidInputs(const GraphDef& graph_def, |
558 | std::vector<std::pair<string, string>>* invalid_inputs) { |
559 | std::map<string, const NodeDef*> node_map; |
560 | MapNamesToNodes(graph_def, &node_map); |
561 | |
562 | for (const NodeDef& node : graph_def.node()) { |
563 | for (const string& input : node.input()) { |
564 | string input_node = NodeNameFromInput(input); |
565 | if (!node_map.count(input_node)) { |
566 | invalid_inputs->push_back({node.name(), input_node}); |
567 | } |
568 | } |
569 | } |
570 | } |
571 | |
572 | Status IsGraphValid(const GraphDef& graph_def) { |
573 | std::vector<std::pair<string, string>> invalid_inputs; |
574 | FindInvalidInputs(graph_def, &invalid_inputs); |
575 | if (!invalid_inputs.empty()) { |
576 | std::map<string, const NodeDef*> node_map; |
577 | MapNamesToNodes(graph_def, &node_map); |
578 | for (const std::pair<string, string>& invalid_input : invalid_inputs) { |
579 | LOG(ERROR) << "Invalid input " << invalid_input.second << " for node " |
580 | << invalid_input.first << " - " |
581 | << node_map[invalid_input.first]->DebugString(); |
582 | } |
583 | return errors::Internal( |
584 | "Invalid graph with inputs referring to nonexistent nodes" ); |
585 | } |
586 | return OkStatus(); |
587 | } |
588 | |
589 | Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, |
590 | DataTypeVector* outputs) { |
591 | const OpDef* op_def; |
592 | TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); |
593 | TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs)); |
594 | return OkStatus(); |
595 | } |
596 | |
597 | Status TensorShapeFromString(const string& shape_string, TensorShape* result) { |
598 | if (shape_string.empty()) { |
599 | return errors::InvalidArgument("Specified shape is empty." ); |
600 | } |
601 | std::vector<string> dims_as_str = str_util::Split(shape_string, "," ); |
602 | std::vector<int64_t> dims; |
603 | for (const string& dim : dims_as_str) { |
604 | int64_t tmp; |
605 | if (strings::safe_strto64(dim, &tmp)) { |
606 | dims.push_back(tmp); |
607 | } else { |
608 | return errors::InvalidArgument("Could parse as shape: '" , shape_string, |
609 | "'" ); |
610 | } |
611 | } |
612 | *result = TensorShape(dims); |
613 | return OkStatus(); |
614 | } |
615 | |
616 | int TransformFuncContext::CountParameters(const string& name) const { |
617 | if (params.count(name)) { |
618 | return params.at(name).size(); |
619 | } else { |
620 | return 0; |
621 | } |
622 | } |
623 | |
624 | Status TransformFuncContext::GetOneStringParameter(const string& name, |
625 | const string& default_value, |
626 | string* result) const { |
627 | const int params_count = CountParameters(name); |
628 | if (params_count == 0) { |
629 | *result = default_value; |
630 | return OkStatus(); |
631 | } else if (params_count == 1) { |
632 | *result = params.at(name).at(0); |
633 | return OkStatus(); |
634 | } else { |
635 | return errors::InvalidArgument("Expected a single '" , name, |
636 | "' parameter, but found " , params_count, |
637 | " occurrences" ); |
638 | } |
639 | } |
640 | |
641 | Status TransformFuncContext::GetOneInt32Parameter(const string& name, |
642 | int32_t default_value, |
643 | int32* result) const { |
644 | const int params_count = CountParameters(name); |
645 | if (params_count == 0) { |
646 | *result = default_value; |
647 | return OkStatus(); |
648 | } |
649 | string string_value; |
650 | TF_RETURN_IF_ERROR(GetOneStringParameter(name, "" , &string_value)); |
651 | if (!strings::safe_strto32(StringPiece(string_value), result)) { |
652 | return errors::InvalidArgument("Couldn't interpret the " , name, |
653 | " argument as a number:" , string_value); |
654 | } |
655 | return OkStatus(); |
656 | } |
657 | |
658 | Status TransformFuncContext::GetOneInt64Parameter(const string& name, |
659 | int64_t default_value, |
660 | int64_t* result) const { |
661 | const int params_count = CountParameters(name); |
662 | if (params_count == 0) { |
663 | *result = default_value; |
664 | return OkStatus(); |
665 | } |
666 | string string_value; |
667 | TF_RETURN_IF_ERROR(GetOneStringParameter(name, "" , &string_value)); |
668 | if (!strings::safe_strto64(StringPiece(string_value), result)) { |
669 | return errors::InvalidArgument("Couldn't interpret the " , name, |
670 | " argument as a number:" , string_value); |
671 | } |
672 | return OkStatus(); |
673 | } |
674 | |
675 | Status TransformFuncContext::GetOneFloatParameter(const string& name, |
676 | float default_value, |
677 | float* result) const { |
678 | const int params_count = CountParameters(name); |
679 | if (params_count == 0) { |
680 | *result = default_value; |
681 | return OkStatus(); |
682 | } |
683 | string string_value; |
684 | TF_RETURN_IF_ERROR(GetOneStringParameter(name, "" , &string_value)); |
685 | if (!strings::safe_strtof(string_value.c_str(), result)) { |
686 | return errors::InvalidArgument( |
687 | "Couldn't interpret the " , name, |
688 | " argument as a float number:" , string_value); |
689 | } |
690 | return OkStatus(); |
691 | } |
692 | |
693 | Status TransformFuncContext::GetOneBoolParameter(const string& name, |
694 | bool default_value, |
695 | bool* result) const { |
696 | const int params_count = CountParameters(name); |
697 | if (params_count == 0) { |
698 | *result = default_value; |
699 | return OkStatus(); |
700 | } |
701 | string string_value; |
702 | TF_RETURN_IF_ERROR(GetOneStringParameter(name, "" , &string_value)); |
703 | if (string_value == "true" || string_value == "1" ) { |
704 | *result = true; |
705 | } else if (string_value == "false" || string_value == "0" ) { |
706 | *result = false; |
707 | } else { |
708 | return errors::InvalidArgument("Couldn't interpret the " , name, |
709 | " argument as a boolean:" , string_value, |
710 | " (expected true, false, 0 or 1)" ); |
711 | } |
712 | return OkStatus(); |
713 | } |
714 | |
715 | } // namespace graph_transforms |
716 | } // namespace tensorflow |
717 | |