1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/graph_to_functiondef.h"
17
18#include <unordered_map>
19#include <unordered_set>
20
21#include "tensorflow/core/framework/attr_value_util.h"
22#include "tensorflow/core/framework/function.pb.h"
23#include "tensorflow/core/framework/node_def.pb.h"
24#include "tensorflow/core/framework/node_def_util.h"
25#include "tensorflow/core/framework/tensor.pb.h"
26#include "tensorflow/core/framework/types.h"
27#include "tensorflow/core/graph/graph.h"
28#include "tensorflow/core/graph/graph_node_util.h"
29#include "tensorflow/core/graph/tensor_id.h"
30#include "tensorflow/core/lib/core/errors.h"
31#include "tensorflow/core/lib/core/status.h"
32#include "tensorflow/core/lib/strings/base64.h"
33#include "tensorflow/core/lib/strings/str_util.h"
34#include "tensorflow/core/lib/strings/strcat.h"
35
36namespace tensorflow {
37namespace {
38
// Class that maintains a one-to-one original node name -> new node name
// mapping. We normalize the names used as input and output arguments to match
// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
// Once we rename them, we risk creating a name collision with the other
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
//
// All names handed out by one instance share a single pool of used names, so
// the order in which the methods below are called decides which name keeps
// its unsuffixed form.
class NodeNameMapping {
 public:
  NodeNameMapping() = default;

  // Normalize the input name and make it unique. This is the same as the
  // function for output, except that it also records the
  // original name -> returned name mapping so that Lookup() can find it.
  string GetInputName(const string& name);

  // Normalize the output name and make it unique. The result is marked as
  // used but is not recorded for Lookup().
  string GetOutputName(const string& name);

  // Make the node name unique (without normalizing it) and record the
  // original name -> returned name mapping for Lookup().
  string Uniquify(const string& name);

  // Records name, exactly as given, as a used name. If this name is already
  // used, returns an InvalidArgument error status and records nothing.
  Status UseOutputName(const string& name);

  // Look up how a node name was previously normalized/uniquified by
  // GetInputName() or Uniquify().
  // Returns empty if name was never seen.
  string Lookup(const string& name) const;

 private:
  // Returns `name` if it is still unused, otherwise `name` followed by "_" and
  // the first numeric suffix that yields an unused name. The returned name is
  // marked as used.
  string UniquifyHelper(const string& name);

  // Lower-cases letters, replaces non-alphanumeric characters with '_' and
  // drops everything before the first letter. Returns "unknown" if nothing
  // usable is left.
  static string Normalize(string name);

  // The normalized/uniquified names already used as
  // input names (in signature), output names (in signature), and node names
  // (in node_def).
  // This is a superset of values in name_mapping_.
  // The mapped value is the next numeric suffix UniquifyHelper() will try when
  // the key is requested again.
  std::unordered_map<string, uint64> used_names_;
  // Mapping from original node name from the graph to the normalized
  // and uniquified version of it.
  std::unordered_map<string, string> name_mapping_;
};
81
82string NodeNameMapping::Normalize(string name) {
83 // Convert letters to lowercase and non-alphanumeric characters to '_'.
84 if (name.empty()) return "unknown";
85 const int n = name.size();
86 for (int i = 0; i < n; ++i) {
87 char c = name[i];
88 if (isalnum(c)) {
89 if (isupper(c)) {
90 name[i] = tolower(c);
91 }
92 } else {
93 name[i] = '_';
94 }
95 }
96
97 // Find the first letter and start with it.
98 int i = 0;
99 for (; i < n; ++i) {
100 if (isalpha(name[i])) break;
101 }
102
103 // Return "unknown" if none of the name's chars were letters.
104 return i == n ? "unknown" : name.substr(i);
105}
106
107string NodeNameMapping::UniquifyHelper(const string& name) {
108 auto it = used_names_.emplace(name, 0);
109 // If the name hasn't been used yet, use it as-is.
110 if (it.second) return name;
111
112 // Add a suffix to name to make it unique.
113 while (true) {
114 const string candidate = strings::StrCat(name, "_", it.first->second);
115 it.first->second++;
116 if (used_names_.emplace(candidate, 0).second) return candidate;
117 }
118}
119
120string NodeNameMapping::GetInputName(const string& name) {
121 const string& input_name = UniquifyHelper(Normalize(name));
122 name_mapping_[name] = input_name;
123 return input_name;
124}
125
126string NodeNameMapping::GetOutputName(const string& name) {
127 const string& input_name = UniquifyHelper(Normalize(name));
128 // Don't add it to name_mapping_ since this name is not for a node.
129 return input_name;
130}
131
132string NodeNameMapping::Uniquify(const string& name) {
133 const string uniqued = UniquifyHelper(name);
134 name_mapping_[name] = uniqued;
135 return uniqued;
136}
137
138Status NodeNameMapping::UseOutputName(const string& name) {
139 const auto& iter = used_names_.find(name);
140 if (iter != used_names_.end()) {
141 return errors::InvalidArgument(
142 "Cannot have duplicate output names. Name '", name,
143 "' appears more than once in 'output_names' array.");
144 }
145 used_names_.emplace(name, 0);
146 return OkStatus();
147}
148
149string NodeNameMapping::Lookup(const string& name) const {
150 const auto iter = name_mapping_.find(name);
151 if (iter == name_mapping_.end()) return string();
152 return iter->second;
153}
154
155Status FillFunctionBody(
156 const string& fn_name, const NodeNameMapping& node_names,
157 const std::vector<const Node*>& body_nodes,
158 const std::unordered_map<string, string>& tensor_renaming,
159 bool set_stateful_from_nodes, bool copy_placeholder_attrs_from_nodes,
160 FunctionDef* fdef) {
161 std::unordered_set<string> func_attr_names;
162 for (const auto& func_attr : fdef->signature().attr()) {
163 func_attr_names.insert(func_attr.name());
164 }
165
166 std::vector<const Edge*> in_edges;
167 std::vector<const Edge*> control_edges;
168 for (const Node* node : body_nodes) {
169 NodeDef* node_def = fdef->add_node_def();
170 // First, copy the node_def as is. We will patch it next.
171 *node_def = node->def();
172 if (!node->assigned_device_name().empty()) {
173 node_def->set_device(node->assigned_device_name());
174 }
175 node_def->set_name(node_names.Lookup(node->name()));
176 MergeDebugInfo(NodeDebugInfo(node->def()), node_def);
177
178 // Input names must be set based on nested names in tensor_renaming.
179 // Clear the flat input names we got from the original node_def
180 // from the graph.
181 node_def->clear_input();
182
183 // Collect regular and control inputs. Regular inputs are indexed
184 // by the index at which they come into the `node`. Control inputs
185 // don't follow any order, and we sort control inputs to make sure generated
186 // NodeDef is deterministic.
187 in_edges.clear();
188 in_edges.resize(node->num_inputs(), nullptr);
189 control_edges.clear();
190 for (const Edge* edge : node->in_edges()) {
191 if (edge->src()->IsSource()) continue;
192 if (edge->IsControlEdge()) {
193 control_edges.push_back(edge);
194 } else {
195 in_edges[edge->dst_input()] = edge;
196 }
197 }
198 std::sort(control_edges.begin(), control_edges.end(),
199 [](const Edge* a, const Edge* b) {
200 return a->src()->name() < b->src()->name();
201 });
202
203 // Add regular inputs.
204 for (size_t i = 0; i < in_edges.size(); ++i) {
205 const Edge* edge = in_edges[i];
206 string original_input_name;
207 if (edge == nullptr) {
208 // A backedge might not appear as a regular Edge, but be only present
209 // in the node_def. Such edges are referred to as requested_inputs().
210 if (i >= node->requested_inputs().size()) {
211 return errors::InvalidArgument(
212 "Graph to be converted to function appears to be malformed. ",
213 "Node ", node->name(), " is missing input edge ", i);
214 }
215 original_input_name =
216 ParseTensorName(node->requested_inputs()[i]).ToString();
217 } else {
218 original_input_name =
219 strings::StrCat(edge->src()->name(), ":", edge->src_output());
220 }
221
222 const auto iter = tensor_renaming.find(original_input_name);
223 if (iter == tensor_renaming.end()) {
224 return errors::InvalidArgument(
225 "Input ", i, ", '", original_input_name, "', of node '",
226 node->name(), "' in function '", fn_name,
227 "' is not available. You might need to include it in inputs "
228 "or include its source node in the body");
229 }
230 node_def->add_input(iter->second);
231 }
232
233 // Add control inputs.
234 for (const Edge* edge : control_edges) {
235 // Add this control input only if the src node is in the body or a part of
236 // the inputs.
237 const string normalized = node_names.Lookup(edge->src()->name());
238 // If we did not find a name for the source of control edge, this
239 // source must be outside of the body, and not an input. Raise an error.
240 if (normalized.empty()) {
241 return errors::InvalidArgument(
242 "The source of control edge ", edge->DebugString(),
243 " is not in the body. Encountered while creating function '",
244 fn_name, "'");
245 }
246 node_def->add_input(strings::StrCat("^", normalized));
247 }
248
249 // A function is stateful if any of its nodes are stateful.
250 if (set_stateful_from_nodes && node->op_def().is_stateful()) {
251 fdef->mutable_signature()->set_is_stateful(true);
252 }
253
254 // If this node has any attributes with placeholder value, add the
255 // attribute to FunctionDef signature.
256 if (!copy_placeholder_attrs_from_nodes) {
257 continue;
258 }
259 for (const auto& iter : node->attrs()) {
260 if (iter.second.placeholder().empty()) {
261 continue;
262 }
263
264 // If we already added the attribute, skip it.
265 string func_attr_name = iter.second.placeholder();
266 if (func_attr_names.find(func_attr_name) != func_attr_names.end()) {
267 continue;
268 }
269
270 // This node's attribute is a placeholder value, so it does not have type
271 // information. We check node's OpDef for attribute type.
272 string node_attr_name = iter.first;
273 const OpDef::AttrDef* node_attr_def = nullptr;
274 for (const auto& node_attr : node->op_def().attr()) {
275 if (node_attr.name() == node_attr_name) {
276 node_attr_def = &node_attr;
277 }
278 }
279 if (!node_attr_def) {
280 return errors::Unimplemented(
281 "Placeholder value is not supported for attributes not in OpDef. "
282 "Attribute: ",
283 node_attr_name, ", OpDef: ", node->op_def().DebugString());
284 }
285 OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr();
286 attr_def->set_name(func_attr_name);
287 attr_def->set_type(node_attr_def->type());
288
289 func_attr_names.insert(func_attr_name);
290 }
291 }
292 return OkStatus();
293}
294
295Status GraphToFunctionDefHelper(
296 const Graph& graph, const string& name,
297 const std::function<absl::optional<string>(const Node*)>& control_ret,
298 const std::vector<string>& output_names, FunctionDef* fdef) {
299 auto add_arg_or_retval = [](Node* node,
300 std::vector<OutputTensor>* args_or_retvals) {
301 int index;
302 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
303 if (index >= args_or_retvals->size()) {
304 args_or_retvals->resize(index + 1);
305 }
306 if ((*args_or_retvals)[index].node == nullptr) {
307 (*args_or_retvals)[index].node = node;
308 } else {
309 return errors::InvalidArgument("Multiple '", node->type_string(),
310 "' nodes found with index ", index);
311 }
312 return OkStatus();
313 };
314
315 std::vector<const Node*> body_nodes;
316 std::vector<OutputTensor> inputs;
317 std::vector<OutputTensor> outputs;
318 std::vector<const Node*> control_outputs;
319 std::vector<string> control_output_names;
320 for (Node* node : graph.op_nodes()) {
321 if (node->IsArg()) {
322 TF_RETURN_IF_ERROR(add_arg_or_retval(node, &inputs));
323 continue;
324 }
325
326 if (node->IsRetval()) {
327 TF_RETURN_IF_ERROR(add_arg_or_retval(node, &outputs));
328 continue;
329 }
330
331 if (control_ret) {
332 auto control_ret_name = control_ret(node);
333 if (control_ret_name.has_value()) {
334 control_outputs.push_back(node);
335 control_output_names.push_back(control_ret_name.value());
336 }
337 }
338
339 body_nodes.push_back(node);
340 }
341
342 auto validate_args_retvals =
343 [](const std::vector<OutputTensor>& args_or_retvals,
344 const string& op_type) {
345 for (int i = 0, e = args_or_retvals.size(); i < e; ++i) {
346 if (args_or_retvals[i].node == nullptr) {
347 return errors::InvalidArgument("Missing '", op_type,
348 "' node at index ", i);
349 }
350 }
351 return OkStatus();
352 };
353
354 TF_RETURN_IF_ERROR(validate_args_retvals(inputs, "_Arg"));
355 TF_RETURN_IF_ERROR(validate_args_retvals(outputs, "_Retval"));
356
357 return GraphToFunctionDef(graph, name, /*append_hash_to_fn_name=*/false,
358 /*set_stateful_from_nodes=*/false,
359 /*copy_placeholder_attrs_from_nodes=*/false,
360 body_nodes, inputs, outputs, output_names,
361 control_outputs, control_output_names,
362 /*description=*/nullptr, fdef);
363}
364
365} // anonymous namespace
366
367Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
368 bool append_hash_to_fn_name,
369 bool set_stateful_from_nodes,
370 bool copy_placeholder_attrs_from_nodes,
371 const std::vector<const Node*>& body_nodes,
372 const std::vector<OutputTensor>& inputs,
373 const std::vector<OutputTensor>& outputs,
374 const std::vector<string>& output_names,
375 const std::vector<const Node*>& control_outputs,
376 const std::vector<string>& control_output_names,
377 const char* description, FunctionDef* fdef) {
378 if (!output_names.empty()) {
379 DCHECK_EQ(output_names.size(), outputs.size());
380 }
381
382 if (description != nullptr) {
383 fdef->mutable_signature()->set_description(description);
384 }
385
386 // Keep track of names we used and how we normalized them.
387 NodeNameMapping node_names;
388
389 // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
390 // name we used in the function:
391 // - For input tensors:
392 // {flat_tensor_name -> normalized_name_of_src_node}
393 // e.g. {In:3 -> in}
394 // - For tensors produced by nodes in function's body:
395 // {flat_tensor_name -> nested_tensor_name}
396 // e.g. {Add:3 -> add_0:z:1}
397 std::unordered_map<string, string> tensor_renaming;
398
399 // Fill outputs in function's signature.
400 // We fill the outputs first to prevent output_names from colliding
401 // with the input names we pick below. With this order, no names are used in
402 // node_names yet, and output_names won't collide with anything (except
403 // potentially with themselves).
404 for (size_t i = 0; i < outputs.size(); ++i) {
405 const Node* node = outputs[i].node;
406 int idx = outputs[i].index;
407 OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
408 if (node->IsRetval()) {
409 argdef->set_type(node->input_type(idx));
410 } else {
411 argdef->set_type(node->output_type(idx));
412 }
413 if (!output_names.empty()) {
414 TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i]));
415 argdef->set_name(output_names[i]);
416 } else {
417 argdef->set_name(node_names.GetOutputName(node->name()));
418 }
419 }
420
421 // Fill inputs in function's signature.
422 for (size_t i = 0; i < inputs.size(); ++i) {
423 const Node* node = inputs[i].node;
424 int idx = inputs[i].index;
425 OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
426 argdef->set_type(node->output_type(idx));
427 const string& input_name = node_names.GetInputName(node->name());
428 argdef->set_name(input_name);
429 FunctionDef::ArgAttrs arg_attrs;
430 int64_t resource_arg_unique_id = -1;
431 for (const auto& attr : node->attrs()) {
432 // Only copy internal attributes. These attributes will be applied to
433 // _Arg/Placeholder nodes when this FunctionDef is converted to graph,
434 // and normal attributes for nodes cannot be applied to those
435 // _Arg/Placeholder nodes.
436 if (absl::StartsWith(attr.first, "_")) {
437 arg_attrs.mutable_attr()->insert(attr);
438 } else if (attr.first == "shape" && argdef->type() != DT_RESOURCE) {
439 // Preserve known shapes by moving them to the _output_shapes list.
440 // The _Arg shape function knows how to extract them from there.
441 // Don't preserve the shape of a resource arg node, which is a scalar
442 // resource handle.
443 AttrValue value;
444 *(value.mutable_list()->add_shape()) = attr.second.shape();
445 arg_attrs.mutable_attr()->insert({"_output_shapes", value});
446 } else if (attr.first == "value" && node->type_string() == "Const") {
447 // Small eager tensors are captured as const ops rather than
448 // Placeholders. Add a _output_shapes arg_attr with the shape of the
449 // const tensor.
450 AttrValue value;
451 *(value.mutable_list()->add_shape()) =
452 attr.second.tensor().tensor_shape();
453 arg_attrs.mutable_attr()->insert({"_output_shapes", value});
454 }
455 if (attr.first == "_resource_arg_unique_id") {
456 resource_arg_unique_id = attr.second.i();
457 }
458 }
459 if (arg_attrs.attr_size() > 0) {
460 (*fdef->mutable_arg_attr())[i] = std::move(arg_attrs);
461 }
462 if (resource_arg_unique_id >= 0) {
463 (*fdef->mutable_resource_arg_unique_id())[idx] = resource_arg_unique_id;
464 }
465 tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
466 }
467
468 // Populate tensor_renaming and node_names.
469 // Generate the new output names for every node in the function.
470 // The NodeDefs in FunctionDefs use a different naming scheme for
471 // their inputs than the NodeDefs in a graph (see the comment for
472 // FunctionDef.node_def in function.proto). We do the
473 // graph tensor name -> function tensor name conversion for every
474 // possible input (i.e. every node's outputs) and store the result
475 // in tensor_renaming.
476 for (const Node* node : body_nodes) {
477 // Make sure node_name does not collide with an input or output name.
478 const string& node_name = node_names.Uniquify(node->name());
479 // For each output_arg in the op_def, the output_ranges
480 // map will have [start, end] range of indices that this arg produces
481 // among all the output tensors of this op.
482 NameRangeMap output_ranges;
483 TF_RETURN_IF_ERROR(
484 NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
485 for (const auto& output : output_ranges) {
486 const StringPiece& output_name = output.first;
487 int index_start = output.second.first;
488 int index_end = output.second.second;
489 for (int i = index_start; i < index_end; ++i) {
490 const string& original_name = strings::StrCat(node->name(), ":", i);
491 const string& new_name =
492 strings::StrCat(node_name, ":", output_name, ":", i - index_start);
493 // Record the mapping if this tensor is not already mapped.
494 // Tensor can be already mapped if it is used as an input.
495 if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
496 tensor_renaming[original_name] = new_name;
497 }
498 }
499 }
500 }
501
502 TF_RETURN_IF_ERROR(FillFunctionBody(fn_name, node_names, body_nodes,
503 tensor_renaming, set_stateful_from_nodes,
504 copy_placeholder_attrs_from_nodes, fdef));
505
506 // Remap return values.
507 for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
508 const string& ret_name = fdef->signature().output_arg(r).name();
509 // We convert this flat tensor name to the nested value
510 // (e.g. `add:z:1`) that we stored in tensor_renaming.
511 string return_value;
512 if (outputs[r].node->IsRetval()) {
513 Edge const* edge;
514 TF_RETURN_IF_ERROR(outputs[r].node->input_edge(0, &edge));
515 return_value =
516 strings::StrCat(edge->src()->name(), ":", edge->src_output());
517 } else {
518 return_value =
519 strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
520 }
521 const auto iter = tensor_renaming.find(return_value);
522 if (iter == tensor_renaming.end()) {
523 return errors::InvalidArgument(
524 "TF_Output ", return_value, " is neither in the function body ",
525 "nor among function inputs. Encountered while creating function '",
526 fn_name, "'");
527 }
528 (*fdef->mutable_ret())[ret_name] = iter->second;
529 }
530
531 if (append_hash_to_fn_name) {
532 const uint64 hash = FunctionDefHash(*fdef);
533 string encoded;
534 TF_RETURN_IF_ERROR(Base64Encode(
535 StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
536 &encoded));
537 // Besides letters and digits our Base64 encoding uses '_' and '-'.
538 // Dash is invalid in operation names and multiple underscores in random
539 // places look strange. Since we never need to decode the hash back,
540 // replace these chars with 'a' and 'A'. Replacing with different letters
541 // keeps more entropy.
542 std::replace(encoded.begin(), encoded.end(), '-', 'a');
543 std::replace(encoded.begin(), encoded.end(), '_', 'A');
544 fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
545 } else {
546 fdef->mutable_signature()->set_name(fn_name);
547 }
548
549 if (!control_output_names.empty() &&
550 (control_outputs.size() != control_output_names.size())) {
551 return errors::InvalidArgument(
552 "Expected number of control outputs (", control_outputs.size(),
553 ") and the number of control output names (",
554 control_output_names.size(), ") to match but they do not.");
555 }
556 std::set<string> control_output_names_set;
557 for (int i = 0; i < control_outputs.size(); ++i) {
558 string signature_name;
559 if (!control_output_names.empty()) {
560 signature_name = control_output_names[i];
561 } else {
562 signature_name = control_outputs[i]->name();
563 }
564 if (signature_name.empty()) {
565 return errors::InvalidArgument("Control output name must be not empty");
566 }
567 if (!control_output_names_set.insert(signature_name).second) {
568 return errors::InvalidArgument("Repeated control output name: ",
569 signature_name);
570 }
571 const string control_output_node =
572 node_names.Lookup(control_outputs[i]->name());
573 if (control_output_node.empty()) {
574 return errors::InvalidArgument(
575 "Control output node name must be not empty");
576 }
577 (*fdef->mutable_control_ret())[signature_name] = control_output_node;
578 }
579 for (const string& control_output : control_output_names_set) {
580 fdef->mutable_signature()->add_control_output(control_output);
581 }
582
583 return OkStatus();
584}
585
586Status GraphToFunctionDef(
587 const Graph& graph, const string& name,
588 const std::function<absl::optional<string>(const Node*)>& control_ret,
589 FunctionDef* fdef) {
590 return GraphToFunctionDefHelper(graph, name, control_ret,
591 /*output_names=*/{}, fdef);
592}
593
594Status GraphToFunctionDef(const Graph& graph, const string& name,
595 FunctionDef* fdef) {
596 return GraphToFunctionDef(graph, name, /*control_ret=*/nullptr, fdef);
597}
598
599Status GraphToFunctionDef(const Graph& graph, const string& name,
600 const std::vector<std::string>& output_names,
601 FunctionDef* fdef) {
602 return GraphToFunctionDefHelper(graph, name, /*control_ret=*/nullptr,
603 output_names, fdef);
604}
605
606} // namespace tensorflow
607