1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17
18#include "absl/algorithm/container.h"
19#include "absl/container/flat_hash_map.h"
20#include "absl/strings/str_cat.h"
21#include "absl/strings/string_view.h"
22#include "tensorflow/core/framework/attr_value.pb.h"
23#include "tensorflow/core/framework/function.pb.h"
24#include "tensorflow/core/framework/graph.pb.h"
25#include "tensorflow/core/framework/node_def.pb.h"
26#include "tensorflow/core/framework/op_def.pb.h"
27#include "tensorflow/tools/graph_transforms/transform_utils.h"
28
29namespace tensorflow {
30namespace graph_transforms {
31
// Op type of the call nodes this transform inlines.
constexpr char kPartitionedCallOpName[]{"PartitionedCall"};
// Attribute on a PartitionedCall node that names the function to call.
constexpr char kFunctionAttrName[]{"f"};
34
35namespace {
36absl::optional<FunctionDef> GetFunctionByNameFromLibrary(
37 const GraphDef& graph, absl::string_view function_name) {
38 for (const auto& fct : graph.library().function()) {
39 if (fct.signature().name() == function_name) {
40 return fct;
41 }
42 }
43 return {};
44}
45
// Collapses an input reference containing two or more ':' separators, such as
// the function-body form "<node>:<output_arg>:<index>", into
// "<first part>:<last part>" (e.g. "add:z:0" -> "add:0"). Inputs with fewer
// than two ':' characters (e.g. "x", "x:1", "^ctrl") are returned unchanged.
//
// Implemented with find/rfind so no intermediate vector of parts is built.
//
// NOTE(review): in FunctionDef syntax the trailing index is presumably
// relative to the named output argument rather than the node's flat output
// list, so this collapse appears exact only when the named argument is the
// node's first output arg — confirm for ops with several output args.
std::string NormalizeNodeDefInput(const std::string& input_name) {
  const std::string::size_type first_colon = input_name.find(':');
  const std::string::size_type last_colon = input_name.rfind(':');
  // Zero or one separator: nothing to collapse.
  if (first_colon == std::string::npos || first_colon == last_colon) {
    return input_name;
  }
  std::string normalized = input_name.substr(0, first_colon);
  normalized += ':';
  normalized.append(input_name, last_colon + 1, std::string::npos);
  return normalized;
}
54
55} // namespace
56
57Status InlinePartitionedCall(const GraphDef& input_graph_def,
58 const TransformFuncContext& context,
59 GraphDef* output_graph_def) {
60 output_graph_def->Clear();
61 absl::flat_hash_map<std::string, std::string> remap_input;
62
63 for (const NodeDef& node : input_graph_def.node()) {
64 if (node.op() == kPartitionedCallOpName) {
65 if (node.attr().count(kFunctionAttrName) == 0) {
66 return Status(
67 error::Code::NOT_FOUND,
68 "Node " + node.name() + " has no attribute: " + kFunctionAttrName);
69 }
70
71 if (!node.attr().at(kFunctionAttrName).has_func()) {
72 return Status(error::Code::NOT_FOUND,
73 "Cannot figure out function name");
74 }
75 const std::string function_name =
76 node.attr().at(kFunctionAttrName).func().name();
77 absl::optional<FunctionDef> function =
78 GetFunctionByNameFromLibrary(input_graph_def, function_name);
79 if (!function.has_value()) {
80 return Status(error::Code::NOT_FOUND,
81 "function " + function_name + " Not found");
82 }
83
84 const std::string prefix = node.name();
85
86 const int kOutputArgumentCount =
87 function->signature().output_arg().size();
88 for (int k = 0; k < kOutputArgumentCount; ++k) {
89 const std::string function_arg_output_name =
90 function->ret().at(function->signature().output_arg()[k].name());
91 remap_input.insert_or_assign(
92 CanonicalInputName(absl::StrCat(node.name(), ":", k)),
93 absl::StrCat(prefix, "/",
94 NormalizeNodeDefInput(function_arg_output_name)));
95 }
96
97 const int kInputArgumentCount = function->signature().input_arg().size();
98 if (node.input().size() != kInputArgumentCount) {
99 return Status(error::Code::INVALID_ARGUMENT,
100 "Called function " + function_name +
101 " has invalid input signature.");
102 }
103 absl::flat_hash_map<std::string, std::string> input_argument_map;
104 for (int k = 0; k < kInputArgumentCount; ++k) {
105 const std::string canonical_name =
106 CanonicalInputName(function->signature().input_arg()[k].name());
107 input_argument_map.insert_or_assign(canonical_name, node.input()[k]);
108 }
109
110 for (const NodeDef& function_node : function->node_def()) {
111 NodeDef* new_node = output_graph_def->mutable_node()->Add();
112 *new_node = function_node;
113 new_node->set_name(absl::StrCat(prefix, "/", function_node.name()));
114 absl::c_transform(
115 *new_node->mutable_input(), new_node->mutable_input()->begin(),
116 [prefix, input_argument_map](const std::string& input_name) {
117 const std::string canonical_input_name =
118 CanonicalInputName(input_name);
119 if (input_argument_map.find(canonical_input_name) !=
120 input_argument_map.end()) {
121 return input_argument_map.at(canonical_input_name);
122 }
123 return absl::StrCat(prefix, "/",
124 NormalizeNodeDefInput(input_name));
125 });
126 }
127 } else {
128 NodeDef* new_node = output_graph_def->mutable_node()->Add();
129 *new_node = node;
130 }
131 }
132
133 // Remap PartitionCall outputs to correct nodes.
134 for (NodeDef& node : *output_graph_def->mutable_node()) {
135 absl::c_transform(
136 *node.mutable_input(), node.mutable_input()->begin(),
137 [remap_input](const std::string& input_name) {
138 const std::string canonical_input_name =
139 CanonicalInputName(input_name);
140 if (remap_input.find(canonical_input_name) != remap_input.end()) {
141 return remap_input.at(canonical_input_name);
142 }
143 return input_name;
144 });
145 }
146 return OkStatus();
147}
148
// Registers InlinePartitionedCall under the transform name
// "inline_partitionedcall" (presumably the name used to select it from the
// graph_transforms tooling — confirm against transform_utils.h).
REGISTER_GRAPH_TRANSFORM("inline_partitionedcall", InlinePartitionedCall);
150} // namespace graph_transforms
151} // namespace tensorflow
152