1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/graph_to_functiondef.h" |
17 | |
18 | #include <unordered_map> |
19 | #include <unordered_set> |
20 | |
21 | #include "tensorflow/core/framework/attr_value_util.h" |
22 | #include "tensorflow/core/framework/function.pb.h" |
23 | #include "tensorflow/core/framework/node_def.pb.h" |
24 | #include "tensorflow/core/framework/node_def_util.h" |
25 | #include "tensorflow/core/framework/tensor.pb.h" |
26 | #include "tensorflow/core/framework/types.h" |
27 | #include "tensorflow/core/graph/graph.h" |
28 | #include "tensorflow/core/graph/graph_node_util.h" |
29 | #include "tensorflow/core/graph/tensor_id.h" |
30 | #include "tensorflow/core/lib/core/errors.h" |
31 | #include "tensorflow/core/lib/core/status.h" |
32 | #include "tensorflow/core/lib/strings/base64.h" |
33 | #include "tensorflow/core/lib/strings/str_util.h" |
34 | #include "tensorflow/core/lib/strings/strcat.h" |
35 | |
36 | namespace tensorflow { |
37 | namespace { |
38 | |
39 | // Class that maintains a one-to-one original node name -> new node name |
40 | // mapping. We normalize the names used as input and output arguments to match |
41 | // regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name. |
42 | // Once we rename them, we risk creating a name collision with the other |
43 | // node names, so if necessary we add a suffix to make |
44 | // names unique. If we have an input named "A" and a node in the function |
45 | // body named "a", they will be renamed to "a" and "a_0". |
46 | class NodeNameMapping { |
47 | public: |
48 | NodeNameMapping() = default; |
49 | |
50 | // Normalize the input name and make it unique. This is the same as the |
51 | // function for output, expect that it adds a name mapping for the name. |
52 | string GetInputName(const string& name); |
53 | |
54 | // Normalize the output name and make it unique. |
55 | string GetOutputName(const string& name); |
56 | |
57 | // Make the node name unique. |
58 | string Uniquify(const string& name); |
59 | |
60 | // Records name as a used name. If this name is already used, |
61 | // returns an error status. |
62 | Status UseOutputName(const string& name); |
63 | |
64 | // Look up how a node name was previously normalized/uniquified. |
65 | // Returns empty if name was never seen. |
66 | string Lookup(const string& name) const; |
67 | |
68 | private: |
69 | string UniquifyHelper(const string& name); |
70 | static string Normalize(string name); |
71 | |
72 | // The normalized/uniquified names already used as |
73 | // input names (in signature), output names (in signature), and node names |
74 | // (in node_def). |
75 | // This is a superset of values in name_mapping_. |
76 | std::unordered_map<string, uint64> used_names_; |
77 | // Mapping from original node name from the graph to the normalized |
78 | // and uniquified version of it. |
79 | std::unordered_map<string, string> name_mapping_; |
80 | }; |
81 | |
82 | string NodeNameMapping::Normalize(string name) { |
83 | // Convert letters to lowercase and non-alphanumeric characters to '_'. |
84 | if (name.empty()) return "unknown" ; |
85 | const int n = name.size(); |
86 | for (int i = 0; i < n; ++i) { |
87 | char c = name[i]; |
88 | if (isalnum(c)) { |
89 | if (isupper(c)) { |
90 | name[i] = tolower(c); |
91 | } |
92 | } else { |
93 | name[i] = '_'; |
94 | } |
95 | } |
96 | |
97 | // Find the first letter and start with it. |
98 | int i = 0; |
99 | for (; i < n; ++i) { |
100 | if (isalpha(name[i])) break; |
101 | } |
102 | |
103 | // Return "unknown" if none of the name's chars were letters. |
104 | return i == n ? "unknown" : name.substr(i); |
105 | } |
106 | |
107 | string NodeNameMapping::UniquifyHelper(const string& name) { |
108 | auto it = used_names_.emplace(name, 0); |
109 | // If the name hasn't been used yet, use it as-is. |
110 | if (it.second) return name; |
111 | |
112 | // Add a suffix to name to make it unique. |
113 | while (true) { |
114 | const string candidate = strings::StrCat(name, "_" , it.first->second); |
115 | it.first->second++; |
116 | if (used_names_.emplace(candidate, 0).second) return candidate; |
117 | } |
118 | } |
119 | |
120 | string NodeNameMapping::GetInputName(const string& name) { |
121 | const string& input_name = UniquifyHelper(Normalize(name)); |
122 | name_mapping_[name] = input_name; |
123 | return input_name; |
124 | } |
125 | |
126 | string NodeNameMapping::GetOutputName(const string& name) { |
127 | const string& input_name = UniquifyHelper(Normalize(name)); |
128 | // Don't add it to name_mapping_ since this name is not for a node. |
129 | return input_name; |
130 | } |
131 | |
132 | string NodeNameMapping::Uniquify(const string& name) { |
133 | const string uniqued = UniquifyHelper(name); |
134 | name_mapping_[name] = uniqued; |
135 | return uniqued; |
136 | } |
137 | |
138 | Status NodeNameMapping::UseOutputName(const string& name) { |
139 | const auto& iter = used_names_.find(name); |
140 | if (iter != used_names_.end()) { |
141 | return errors::InvalidArgument( |
142 | "Cannot have duplicate output names. Name '" , name, |
143 | "' appears more than once in 'output_names' array." ); |
144 | } |
145 | used_names_.emplace(name, 0); |
146 | return OkStatus(); |
147 | } |
148 | |
149 | string NodeNameMapping::Lookup(const string& name) const { |
150 | const auto iter = name_mapping_.find(name); |
151 | if (iter == name_mapping_.end()) return string(); |
152 | return iter->second; |
153 | } |
154 | |
// Appends one NodeDef to `fdef` for every node in `body_nodes`, rewriting each
// node's name and inputs into the FunctionDef body naming scheme.
//
// Args:
//   fn_name: function name; used only in error messages.
//   node_names: supplies each node's new name via Lookup(). A node whose name
//     was never recorded gets an empty name; a control-input source that was
//     never recorded is reported as an error.
//   body_nodes: the nodes to emit, in order.
//   tensor_renaming: flat graph tensor name ("node:idx") -> function tensor
//     name. Every regular input of every body node must be present here.
//   set_stateful_from_nodes: if true, the signature is marked stateful as soon
//     as one body node's op is stateful.
//   copy_placeholder_attrs_from_nodes: if true, every node attribute whose
//     value is a placeholder is added (once per placeholder name) to the
//     signature's attrs, typed from the node's OpDef.
//   fdef: output; its existing signature attrs are respected (not duplicated).
//
// Returns InvalidArgument for malformed/unresolvable inputs and Unimplemented
// for placeholder attributes that are not declared in the node's OpDef.
Status FillFunctionBody(
    const string& fn_name, const NodeNameMapping& node_names,
    const std::vector<const Node*>& body_nodes,
    const std::unordered_map<string, string>& tensor_renaming,
    bool set_stateful_from_nodes, bool copy_placeholder_attrs_from_nodes,
    FunctionDef* fdef) {
  // Names of attrs already present in the signature, so placeholder attrs are
  // added at most once.
  std::unordered_set<string> func_attr_names;
  for (const auto& func_attr : fdef->signature().attr()) {
    func_attr_names.insert(func_attr.name());
  }

  // Scratch buffers, hoisted out of the loop and cleared per node to reuse
  // their storage.
  std::vector<const Edge*> in_edges;
  std::vector<const Edge*> control_edges;
  for (const Node* node : body_nodes) {
    NodeDef* node_def = fdef->add_node_def();
    // First, copy the node_def as is. We will patch it next.
    *node_def = node->def();
    // Prefer the placement actually assigned to the node, when there is one.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }
    node_def->set_name(node_names.Lookup(node->name()));
    MergeDebugInfo(NodeDebugInfo(node->def()), node_def);

    // Input names must be set based on nested names in tensor_renaming.
    // Clear the flat input names we got from the original node_def
    // from the graph.
    node_def->clear_input();

    // Collect regular and control inputs. Regular inputs are indexed
    // by the index at which they come into the `node`. Control inputs
    // don't follow any order, and we sort control inputs to make sure generated
    // NodeDef is deterministic.
    in_edges.clear();
    // Slots left as nullptr have no regular Edge; they are resolved from
    // requested_inputs() below.
    in_edges.resize(node->num_inputs(), nullptr);
    control_edges.clear();
    for (const Edge* edge : node->in_edges()) {
      // Edges from the graph's source node carry no data for the function.
      if (edge->src()->IsSource()) continue;
      if (edge->IsControlEdge()) {
        control_edges.push_back(edge);
      } else {
        in_edges[edge->dst_input()] = edge;
      }
    }
    std::sort(control_edges.begin(), control_edges.end(),
              [](const Edge* a, const Edge* b) {
                return a->src()->name() < b->src()->name();
              });

    // Add regular inputs.
    for (size_t i = 0; i < in_edges.size(); ++i) {
      const Edge* edge = in_edges[i];
      string original_input_name;
      if (edge == nullptr) {
        // A backedge might not appear as a regular Edge, but be only present
        // in the node_def. Such edges are referred to as requested_inputs().
        if (i >= node->requested_inputs().size()) {
          return errors::InvalidArgument(
              "Graph to be converted to function appears to be malformed. ",
              "Node ", node->name(), " is missing input edge ", i);
        }
        original_input_name =
            ParseTensorName(node->requested_inputs()[i]).ToString();
      } else {
        original_input_name =
            strings::StrCat(edge->src()->name(), ":", edge->src_output());
      }

      // Translate the flat "node:idx" name into the function-body name.
      const auto iter = tensor_renaming.find(original_input_name);
      if (iter == tensor_renaming.end()) {
        return errors::InvalidArgument(
            "Input ", i, ", '", original_input_name, "', of node '",
            node->name(), "' in function '", fn_name,
            "' is not available. You might need to include it in inputs "
            "or include its source node in the body");
      }
      node_def->add_input(iter->second);
    }

    // Add control inputs.
    for (const Edge* edge : control_edges) {
      // Add this control input only if the src node is in the body or a part of
      // the inputs.
      const string normalized = node_names.Lookup(edge->src()->name());
      // If we did not find a name for the source of control edge, this
      // source must be outside of the body, and not an input. Raise an error.
      if (normalized.empty()) {
        return errors::InvalidArgument(
            "The source of control edge ", edge->DebugString(),
            " is not in the body. Encountered while creating function '",
            fn_name, "'");
      }
      // Control inputs are written as "^<node name>".
      node_def->add_input(strings::StrCat("^", normalized));
    }

    // A function is stateful if any of its nodes are stateful.
    if (set_stateful_from_nodes && node->op_def().is_stateful()) {
      fdef->mutable_signature()->set_is_stateful(true);
    }

    // If this node has any attributes with placeholder value, add the
    // attribute to FunctionDef signature.
    if (!copy_placeholder_attrs_from_nodes) {
      continue;
    }
    for (const auto& iter : node->attrs()) {
      if (iter.second.placeholder().empty()) {
        continue;
      }

      // If we already added the attribute, skip it.
      string func_attr_name = iter.second.placeholder();
      if (func_attr_names.find(func_attr_name) != func_attr_names.end()) {
        continue;
      }

      // This node's attribute is a placeholder value, so it does not have type
      // information. We check node's OpDef for attribute type.
      string node_attr_name = iter.first;
      const OpDef::AttrDef* node_attr_def = nullptr;
      for (const auto& node_attr : node->op_def().attr()) {
        if (node_attr.name() == node_attr_name) {
          node_attr_def = &node_attr;
        }
      }
      if (!node_attr_def) {
        return errors::Unimplemented(
            "Placeholder value is not supported for attributes not in OpDef. "
            "Attribute: ",
            node_attr_name, ", OpDef: ", node->op_def().DebugString());
      }
      OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr();
      attr_def->set_name(func_attr_name);
      attr_def->set_type(node_attr_def->type());

      // Remember it so later nodes using the same placeholder don't re-add it.
      func_attr_names.insert(func_attr_name);
    }
  }
  return OkStatus();
}
294 | |
295 | Status GraphToFunctionDefHelper( |
296 | const Graph& graph, const string& name, |
297 | const std::function<absl::optional<string>(const Node*)>& control_ret, |
298 | const std::vector<string>& output_names, FunctionDef* fdef) { |
299 | auto add_arg_or_retval = [](Node* node, |
300 | std::vector<OutputTensor>* args_or_retvals) { |
301 | int index; |
302 | TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index" , &index)); |
303 | if (index >= args_or_retvals->size()) { |
304 | args_or_retvals->resize(index + 1); |
305 | } |
306 | if ((*args_or_retvals)[index].node == nullptr) { |
307 | (*args_or_retvals)[index].node = node; |
308 | } else { |
309 | return errors::InvalidArgument("Multiple '" , node->type_string(), |
310 | "' nodes found with index " , index); |
311 | } |
312 | return OkStatus(); |
313 | }; |
314 | |
315 | std::vector<const Node*> body_nodes; |
316 | std::vector<OutputTensor> inputs; |
317 | std::vector<OutputTensor> outputs; |
318 | std::vector<const Node*> control_outputs; |
319 | std::vector<string> control_output_names; |
320 | for (Node* node : graph.op_nodes()) { |
321 | if (node->IsArg()) { |
322 | TF_RETURN_IF_ERROR(add_arg_or_retval(node, &inputs)); |
323 | continue; |
324 | } |
325 | |
326 | if (node->IsRetval()) { |
327 | TF_RETURN_IF_ERROR(add_arg_or_retval(node, &outputs)); |
328 | continue; |
329 | } |
330 | |
331 | if (control_ret) { |
332 | auto control_ret_name = control_ret(node); |
333 | if (control_ret_name.has_value()) { |
334 | control_outputs.push_back(node); |
335 | control_output_names.push_back(control_ret_name.value()); |
336 | } |
337 | } |
338 | |
339 | body_nodes.push_back(node); |
340 | } |
341 | |
342 | auto validate_args_retvals = |
343 | [](const std::vector<OutputTensor>& args_or_retvals, |
344 | const string& op_type) { |
345 | for (int i = 0, e = args_or_retvals.size(); i < e; ++i) { |
346 | if (args_or_retvals[i].node == nullptr) { |
347 | return errors::InvalidArgument("Missing '" , op_type, |
348 | "' node at index " , i); |
349 | } |
350 | } |
351 | return OkStatus(); |
352 | }; |
353 | |
354 | TF_RETURN_IF_ERROR(validate_args_retvals(inputs, "_Arg" )); |
355 | TF_RETURN_IF_ERROR(validate_args_retvals(outputs, "_Retval" )); |
356 | |
357 | return GraphToFunctionDef(graph, name, /*append_hash_to_fn_name=*/false, |
358 | /*set_stateful_from_nodes=*/false, |
359 | /*copy_placeholder_attrs_from_nodes=*/false, |
360 | body_nodes, inputs, outputs, output_names, |
361 | control_outputs, control_output_names, |
362 | /*description=*/nullptr, fdef); |
363 | } |
364 | |
365 | } // anonymous namespace |
366 | |
367 | Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, |
368 | bool append_hash_to_fn_name, |
369 | bool set_stateful_from_nodes, |
370 | bool copy_placeholder_attrs_from_nodes, |
371 | const std::vector<const Node*>& body_nodes, |
372 | const std::vector<OutputTensor>& inputs, |
373 | const std::vector<OutputTensor>& outputs, |
374 | const std::vector<string>& output_names, |
375 | const std::vector<const Node*>& control_outputs, |
376 | const std::vector<string>& control_output_names, |
377 | const char* description, FunctionDef* fdef) { |
378 | if (!output_names.empty()) { |
379 | DCHECK_EQ(output_names.size(), outputs.size()); |
380 | } |
381 | |
382 | if (description != nullptr) { |
383 | fdef->mutable_signature()->set_description(description); |
384 | } |
385 | |
386 | // Keep track of names we used and how we normalized them. |
387 | NodeNameMapping node_names; |
388 | |
389 | // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the |
390 | // name we used in the function: |
391 | // - For input tensors: |
392 | // {flat_tensor_name -> normalized_name_of_src_node} |
393 | // e.g. {In:3 -> in} |
394 | // - For tensors produced by nodes in function's body: |
395 | // {flat_tensor_name -> nested_tensor_name} |
396 | // e.g. {Add:3 -> add_0:z:1} |
397 | std::unordered_map<string, string> tensor_renaming; |
398 | |
399 | // Fill outputs in function's signature. |
400 | // We fill the outputs first to prevent output_names from colliding |
401 | // with the input names we pick below. With this order, no names are used in |
402 | // node_names yet, and output_names won't collide with anything (except |
403 | // potentially with themselves). |
404 | for (size_t i = 0; i < outputs.size(); ++i) { |
405 | const Node* node = outputs[i].node; |
406 | int idx = outputs[i].index; |
407 | OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg(); |
408 | if (node->IsRetval()) { |
409 | argdef->set_type(node->input_type(idx)); |
410 | } else { |
411 | argdef->set_type(node->output_type(idx)); |
412 | } |
413 | if (!output_names.empty()) { |
414 | TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i])); |
415 | argdef->set_name(output_names[i]); |
416 | } else { |
417 | argdef->set_name(node_names.GetOutputName(node->name())); |
418 | } |
419 | } |
420 | |
421 | // Fill inputs in function's signature. |
422 | for (size_t i = 0; i < inputs.size(); ++i) { |
423 | const Node* node = inputs[i].node; |
424 | int idx = inputs[i].index; |
425 | OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg(); |
426 | argdef->set_type(node->output_type(idx)); |
427 | const string& input_name = node_names.GetInputName(node->name()); |
428 | argdef->set_name(input_name); |
429 | FunctionDef::ArgAttrs arg_attrs; |
430 | int64_t resource_arg_unique_id = -1; |
431 | for (const auto& attr : node->attrs()) { |
432 | // Only copy internal attributes. These attributes will be applied to |
433 | // _Arg/Placeholder nodes when this FunctionDef is converted to graph, |
434 | // and normal attributes for nodes cannot be applied to those |
435 | // _Arg/Placeholder nodes. |
436 | if (absl::StartsWith(attr.first, "_" )) { |
437 | arg_attrs.mutable_attr()->insert(attr); |
438 | } else if (attr.first == "shape" && argdef->type() != DT_RESOURCE) { |
439 | // Preserve known shapes by moving them to the _output_shapes list. |
440 | // The _Arg shape function knows how to extract them from there. |
441 | // Don't preserve the shape of a resource arg node, which is a scalar |
442 | // resource handle. |
443 | AttrValue value; |
444 | *(value.mutable_list()->add_shape()) = attr.second.shape(); |
445 | arg_attrs.mutable_attr()->insert({"_output_shapes" , value}); |
446 | } else if (attr.first == "value" && node->type_string() == "Const" ) { |
447 | // Small eager tensors are captured as const ops rather than |
448 | // Placeholders. Add a _output_shapes arg_attr with the shape of the |
449 | // const tensor. |
450 | AttrValue value; |
451 | *(value.mutable_list()->add_shape()) = |
452 | attr.second.tensor().tensor_shape(); |
453 | arg_attrs.mutable_attr()->insert({"_output_shapes" , value}); |
454 | } |
455 | if (attr.first == "_resource_arg_unique_id" ) { |
456 | resource_arg_unique_id = attr.second.i(); |
457 | } |
458 | } |
459 | if (arg_attrs.attr_size() > 0) { |
460 | (*fdef->mutable_arg_attr())[i] = std::move(arg_attrs); |
461 | } |
462 | if (resource_arg_unique_id >= 0) { |
463 | (*fdef->mutable_resource_arg_unique_id())[idx] = resource_arg_unique_id; |
464 | } |
465 | tensor_renaming[strings::StrCat(node->name(), ":" , idx)] = input_name; |
466 | } |
467 | |
468 | // Populate tensor_renaming and node_names. |
469 | // Generate the new output names for every node in the function. |
470 | // The NodeDefs in FunctionDefs use a different naming scheme for |
471 | // their inputs than the NodeDefs in a graph (see the comment for |
472 | // FunctionDef.node_def in function.proto). We do the |
473 | // graph tensor name -> function tensor name conversion for every |
474 | // possible input (i.e. every node's outputs) and store the result |
475 | // in tensor_renaming. |
476 | for (const Node* node : body_nodes) { |
477 | // Make sure node_name does not collide with an input or output name. |
478 | const string& node_name = node_names.Uniquify(node->name()); |
479 | // For each output_arg in the op_def, the output_ranges |
480 | // map will have [start, end] range of indices that this arg produces |
481 | // among all the output tensors of this op. |
482 | NameRangeMap output_ranges; |
483 | TF_RETURN_IF_ERROR( |
484 | NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); |
485 | for (const auto& output : output_ranges) { |
486 | const StringPiece& output_name = output.first; |
487 | int index_start = output.second.first; |
488 | int index_end = output.second.second; |
489 | for (int i = index_start; i < index_end; ++i) { |
490 | const string& original_name = strings::StrCat(node->name(), ":" , i); |
491 | const string& new_name = |
492 | strings::StrCat(node_name, ":" , output_name, ":" , i - index_start); |
493 | // Record the mapping if this tensor is not already mapped. |
494 | // Tensor can be already mapped if it is used as an input. |
495 | if (tensor_renaming.find(original_name) == tensor_renaming.end()) { |
496 | tensor_renaming[original_name] = new_name; |
497 | } |
498 | } |
499 | } |
500 | } |
501 | |
502 | TF_RETURN_IF_ERROR(FillFunctionBody(fn_name, node_names, body_nodes, |
503 | tensor_renaming, set_stateful_from_nodes, |
504 | copy_placeholder_attrs_from_nodes, fdef)); |
505 | |
506 | // Remap return values. |
507 | for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { |
508 | const string& ret_name = fdef->signature().output_arg(r).name(); |
509 | // We convert this flat tensor name to the nested value |
510 | // (e.g. `add:z:1`) that we stored in tensor_renaming. |
511 | string return_value; |
512 | if (outputs[r].node->IsRetval()) { |
513 | Edge const* edge; |
514 | TF_RETURN_IF_ERROR(outputs[r].node->input_edge(0, &edge)); |
515 | return_value = |
516 | strings::StrCat(edge->src()->name(), ":" , edge->src_output()); |
517 | } else { |
518 | return_value = |
519 | strings::StrCat(outputs[r].node->name(), ":" , outputs[r].index); |
520 | } |
521 | const auto iter = tensor_renaming.find(return_value); |
522 | if (iter == tensor_renaming.end()) { |
523 | return errors::InvalidArgument( |
524 | "TF_Output " , return_value, " is neither in the function body " , |
525 | "nor among function inputs. Encountered while creating function '" , |
526 | fn_name, "'" ); |
527 | } |
528 | (*fdef->mutable_ret())[ret_name] = iter->second; |
529 | } |
530 | |
531 | if (append_hash_to_fn_name) { |
532 | const uint64 hash = FunctionDefHash(*fdef); |
533 | string encoded; |
534 | TF_RETURN_IF_ERROR(Base64Encode( |
535 | StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)), |
536 | &encoded)); |
537 | // Besides letters and digits our Base64 encoding uses '_' and '-'. |
538 | // Dash is invalid in operation names and multiple underscores in random |
539 | // places look strange. Since we never need to decode the hash back, |
540 | // replace these chars with 'a' and 'A'. Replacing with different letters |
541 | // keeps more entropy. |
542 | std::replace(encoded.begin(), encoded.end(), '-', 'a'); |
543 | std::replace(encoded.begin(), encoded.end(), '_', 'A'); |
544 | fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_" , encoded)); |
545 | } else { |
546 | fdef->mutable_signature()->set_name(fn_name); |
547 | } |
548 | |
549 | if (!control_output_names.empty() && |
550 | (control_outputs.size() != control_output_names.size())) { |
551 | return errors::InvalidArgument( |
552 | "Expected number of control outputs (" , control_outputs.size(), |
553 | ") and the number of control output names (" , |
554 | control_output_names.size(), ") to match but they do not." ); |
555 | } |
556 | std::set<string> control_output_names_set; |
557 | for (int i = 0; i < control_outputs.size(); ++i) { |
558 | string signature_name; |
559 | if (!control_output_names.empty()) { |
560 | signature_name = control_output_names[i]; |
561 | } else { |
562 | signature_name = control_outputs[i]->name(); |
563 | } |
564 | if (signature_name.empty()) { |
565 | return errors::InvalidArgument("Control output name must be not empty" ); |
566 | } |
567 | if (!control_output_names_set.insert(signature_name).second) { |
568 | return errors::InvalidArgument("Repeated control output name: " , |
569 | signature_name); |
570 | } |
571 | const string control_output_node = |
572 | node_names.Lookup(control_outputs[i]->name()); |
573 | if (control_output_node.empty()) { |
574 | return errors::InvalidArgument( |
575 | "Control output node name must be not empty" ); |
576 | } |
577 | (*fdef->mutable_control_ret())[signature_name] = control_output_node; |
578 | } |
579 | for (const string& control_output : control_output_names_set) { |
580 | fdef->mutable_signature()->add_control_output(control_output); |
581 | } |
582 | |
583 | return OkStatus(); |
584 | } |
585 | |
586 | Status GraphToFunctionDef( |
587 | const Graph& graph, const string& name, |
588 | const std::function<absl::optional<string>(const Node*)>& control_ret, |
589 | FunctionDef* fdef) { |
590 | return GraphToFunctionDefHelper(graph, name, control_ret, |
591 | /*output_names=*/{}, fdef); |
592 | } |
593 | |
594 | Status GraphToFunctionDef(const Graph& graph, const string& name, |
595 | FunctionDef* fdef) { |
596 | return GraphToFunctionDef(graph, name, /*control_ret=*/nullptr, fdef); |
597 | } |
598 | |
599 | Status GraphToFunctionDef(const Graph& graph, const string& name, |
600 | const std::vector<std::string>& output_names, |
601 | FunctionDef* fdef) { |
602 | return GraphToFunctionDefHelper(graph, name, /*control_ret=*/nullptr, |
603 | output_names, fdef); |
604 | } |
605 | |
606 | } // namespace tensorflow |
607 | |