1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
28 | |
29 | namespace tensorflow { |
30 | namespace graph_transforms { |
31 | |
// Op type of the call nodes that this transform replaces with their body.
static constexpr char kPartitionedCallOpName[] = "PartitionedCall";
// Attribute on a PartitionedCall node that names the called function
// (read through AttrValue::func()).
static constexpr char kFunctionAttrName[] = "f";
34 | |
35 | namespace { |
36 | absl::optional<FunctionDef> GetFunctionByNameFromLibrary( |
37 | const GraphDef& graph, absl::string_view function_name) { |
38 | for (const auto& fct : graph.library().function()) { |
39 | if (fct.signature().name() == function_name) { |
40 | return fct; |
41 | } |
42 | } |
43 | return {}; |
44 | } |
45 | |
46 | std::string NormalizeNodeDefInput(const std::string& input_name) { |
47 | std::vector<std::string> name_parts = |
48 | absl::StrSplit(input_name, absl::ByChar(':')); |
49 | if (name_parts.size() > 2) { |
50 | return absl::StrCat(name_parts[0], ":" , name_parts.back()); |
51 | } |
52 | return input_name; |
53 | } |
54 | |
55 | } // namespace |
56 | |
57 | Status InlinePartitionedCall(const GraphDef& input_graph_def, |
58 | const TransformFuncContext& context, |
59 | GraphDef* output_graph_def) { |
60 | output_graph_def->Clear(); |
61 | absl::flat_hash_map<std::string, std::string> remap_input; |
62 | |
63 | for (const NodeDef& node : input_graph_def.node()) { |
64 | if (node.op() == kPartitionedCallOpName) { |
65 | if (node.attr().count(kFunctionAttrName) == 0) { |
66 | return Status( |
67 | error::Code::NOT_FOUND, |
68 | "Node " + node.name() + " has no attribute: " + kFunctionAttrName); |
69 | } |
70 | |
71 | if (!node.attr().at(kFunctionAttrName).has_func()) { |
72 | return Status(error::Code::NOT_FOUND, |
73 | "Cannot figure out function name" ); |
74 | } |
75 | const std::string function_name = |
76 | node.attr().at(kFunctionAttrName).func().name(); |
77 | absl::optional<FunctionDef> function = |
78 | GetFunctionByNameFromLibrary(input_graph_def, function_name); |
79 | if (!function.has_value()) { |
80 | return Status(error::Code::NOT_FOUND, |
81 | "function " + function_name + " Not found" ); |
82 | } |
83 | |
84 | const std::string prefix = node.name(); |
85 | |
86 | const int kOutputArgumentCount = |
87 | function->signature().output_arg().size(); |
88 | for (int k = 0; k < kOutputArgumentCount; ++k) { |
89 | const std::string function_arg_output_name = |
90 | function->ret().at(function->signature().output_arg()[k].name()); |
91 | remap_input.insert_or_assign( |
92 | CanonicalInputName(absl::StrCat(node.name(), ":" , k)), |
93 | absl::StrCat(prefix, "/" , |
94 | NormalizeNodeDefInput(function_arg_output_name))); |
95 | } |
96 | |
97 | const int kInputArgumentCount = function->signature().input_arg().size(); |
98 | if (node.input().size() != kInputArgumentCount) { |
99 | return Status(error::Code::INVALID_ARGUMENT, |
100 | "Called function " + function_name + |
101 | " has invalid input signature." ); |
102 | } |
103 | absl::flat_hash_map<std::string, std::string> input_argument_map; |
104 | for (int k = 0; k < kInputArgumentCount; ++k) { |
105 | const std::string canonical_name = |
106 | CanonicalInputName(function->signature().input_arg()[k].name()); |
107 | input_argument_map.insert_or_assign(canonical_name, node.input()[k]); |
108 | } |
109 | |
110 | for (const NodeDef& function_node : function->node_def()) { |
111 | NodeDef* new_node = output_graph_def->mutable_node()->Add(); |
112 | *new_node = function_node; |
113 | new_node->set_name(absl::StrCat(prefix, "/" , function_node.name())); |
114 | absl::c_transform( |
115 | *new_node->mutable_input(), new_node->mutable_input()->begin(), |
116 | [prefix, input_argument_map](const std::string& input_name) { |
117 | const std::string canonical_input_name = |
118 | CanonicalInputName(input_name); |
119 | if (input_argument_map.find(canonical_input_name) != |
120 | input_argument_map.end()) { |
121 | return input_argument_map.at(canonical_input_name); |
122 | } |
123 | return absl::StrCat(prefix, "/" , |
124 | NormalizeNodeDefInput(input_name)); |
125 | }); |
126 | } |
127 | } else { |
128 | NodeDef* new_node = output_graph_def->mutable_node()->Add(); |
129 | *new_node = node; |
130 | } |
131 | } |
132 | |
133 | // Remap PartitionCall outputs to correct nodes. |
134 | for (NodeDef& node : *output_graph_def->mutable_node()) { |
135 | absl::c_transform( |
136 | *node.mutable_input(), node.mutable_input()->begin(), |
137 | [remap_input](const std::string& input_name) { |
138 | const std::string canonical_input_name = |
139 | CanonicalInputName(input_name); |
140 | if (remap_input.find(canonical_input_name) != remap_input.end()) { |
141 | return remap_input.at(canonical_input_name); |
142 | } |
143 | return input_name; |
144 | }); |
145 | } |
146 | return OkStatus(); |
147 | } |
148 | |
// Registers InlinePartitionedCall under the transform name
// "inline_partitionedcall" (presumably the name used to request it from the
// graph-transform tooling — confirm against REGISTER_GRAPH_TRANSFORM).
REGISTER_GRAPH_TRANSFORM("inline_partitionedcall", InlinePartitionedCall);
150 | } // namespace graph_transforms |
151 | } // namespace tensorflow |
152 | |