1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" |
17 | |
18 | #include <algorithm> |
19 | #include <iterator> |
20 | #include <map> |
21 | #include <string> |
22 | #include <unordered_map> |
23 | #include <unordered_set> |
24 | #include <utility> |
25 | #include <vector> |
26 | |
27 | #include "tensorflow/core/common_runtime/constant_folding.h" |
28 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
29 | #include "tensorflow/core/common_runtime/shape_refiner.h" |
30 | #include "tensorflow/core/graph/node_builder.h" |
31 | #include "tensorflow/core/graph/subgraph.h" |
32 | #include "tensorflow/core/lib/core/stringpiece.h" |
33 | #include "tensorflow/core/lib/strings/numbers.h" |
34 | #include "tensorflow/core/platform/init_main.h" |
35 | #include "tensorflow/core/public/session.h" |
36 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
37 | |
38 | namespace tensorflow { |
39 | namespace graph_transforms { |
namespace {
// File-local alias: hash set of StringPiece keys. StringPiece is a non-owning
// view, so the backing string storage must outlive the set.
using StringPieceSet = std::unordered_set<StringPiece, StringPieceHasher>;
// File-local alias: hash map keyed by StringPiece; same lifetime caveat as
// StringPieceSet — the strings the keys point into must stay alive.
template <typename T>
using StringPieceMap = std::unordered_map<StringPiece, T, StringPieceHasher>;
}  // namespace
45 | |
// Post-processes a graph produced by subgraph::RewriteGraphForExecution():
// that pass replaces each requested input edge with a synthesized
// "_recv_<node>_<slot>" node. This function rewrites references to those recv
// nodes back to the original input tensors, drops _Send/_Recv nodes that the
// rewrite added, and re-adds any input nodes the rewrite removed, writing the
// result to *output_graph_def.
//
// NOTE(review): the `outputs` parameter is not read in this body; it appears
// to be kept for interface symmetry with the caller.
Status ReplaceSendRecvs(const GraphDef& original_graph_def,
                        const GraphDef& rewritten_graph_def,
                        const std::vector<string>& inputs,
                        const std::vector<string>& outputs,
                        GraphDef* output_graph_def) {
  // recv_node_names serves as a string storage for recv node names.
  // recv_node_map below keys on StringPieces that point into these strings,
  // so this vector must outlive the map (see StringPieceMap alias above).
  std::vector<string> recv_node_names(inputs.size());
  StringPieceMap<TensorId> recv_node_map;
  StringPieceSet input_nodes;
  for (int i = 0; i < inputs.size(); ++i) {
    // RewriteGraphForExecution adds a recv node for each input edge. We assume
    // here that adding such recv node did not fail. For example, the original
    // graph did not already have a node with the name for the new added recv
    // node.
    TensorId id = ParseTensorName(inputs[i]);
    input_nodes.insert(id.first);
    string& recv_node_name = recv_node_names[i];
    recv_node_name = strings::StrCat("_recv_" , id.first, "_" , id.second);
    recv_node_map.emplace(recv_node_name, id);
  }

  // Index the pre-rewrite graph by node name so that nodes removed by the
  // rewrite can be restored at the end.
  StringPieceMap<const NodeDef*> original_map;
  for (const NodeDef& node : original_graph_def.node()) {
    original_map.emplace(node.name(), &node);
  }

  for (const NodeDef& node : rewritten_graph_def.node()) {
    if ((node.op() == "_Send" ) || (node.op() == "_Recv" )) {
      // If the op is a Send or Recv that wasn't in the original, skip it.
      if (original_map.count(node.name()) == 0) {
        continue;
      }
    }

    // Copy the node, then rewrite any of its inputs that refer to a
    // synthesized recv node so they point at the original input tensor again.
    NodeDef* new_node = output_graph_def->add_node();
    new_node->MergeFrom(node);
    for (int i = 0; i < new_node->input_size(); ++i) {
      string& input = *new_node->mutable_input(i);
      TensorId id = ParseTensorName(input);
      const auto iter = recv_node_map.find(id.first);
      if (iter != recv_node_map.end()) {
        // The node being substituted is a Recv node, and it has only one
        // output. If this input is not a control input, then replace the input
        // with the mapped value. Otherwise, replace the node name only.
        if (id.second != Graph::kControlSlot) {
          CHECK_EQ(id.second, 0);
          input = iter->second.ToString();
        } else {
          id.first = iter->second.first;
          input = id.ToString();
        }
      }
    }

    // RewriteGraphForExecution() did not remove this input node. Remove this
    // node name from input_nodes so that a duplicate does not get added to the
    // output_graph_def.
    auto iter = input_nodes.find(new_node->name());
    if (iter != input_nodes.end()) {
      input_nodes.erase(iter);
    }
  }

  // Some input nodes are removed in rewrite_graph_def. Add those nodes to
  // output_graph_def.
  for (StringPiece name : input_nodes) {
    const NodeDef& removed_node = *CHECK_NOTNULL(original_map[name]);
    output_graph_def->add_node()->MergeFrom(removed_node);
  }

  return OkStatus();
}
118 | |
119 | Status RewriteInputsAsPlaceholders(const TransformFuncContext& context, |
120 | GraphDef* graph_def) { |
121 | std::unordered_set<string> input_names; |
122 | for (const string& input_name : context.input_names) { |
123 | input_names.emplace(ParseTensorName(input_name).first); |
124 | } |
125 | |
126 | for (NodeDef& node : *graph_def->mutable_node()) { |
127 | if (input_names.find(node.name()) == input_names.end()) { |
128 | continue; |
129 | } |
130 | if (node.op() == "PlaceholderWithDefault" ) { |
131 | node.set_op("Placeholder" ); |
132 | node.clear_input(); |
133 | } else if (node.op() != "Placeholder" ) { |
134 | return errors::InvalidArgument( |
135 | "Input '" , node.name(), |
136 | "' was expected to be a Placeholder or PlaceholderWithDefault op, " |
137 | "but was " , |
138 | node.op()); |
139 | } |
140 | } |
141 | return OkStatus(); |
142 | } |
143 | |
144 | Status RemoveUnusedNodes(const GraphDef& input_graph_def, |
145 | const TransformFuncContext& context, |
146 | GraphDef* output_graph_def) { |
147 | StringPieceMap<const NodeDef*> node_map; |
148 | for (const NodeDef& node : input_graph_def.node()) { |
149 | node_map.emplace(node.name(), &node); |
150 | } |
151 | |
152 | std::unordered_set<TensorId, TensorId::Hasher> input_names; |
153 | for (const string& input : context.input_names) { |
154 | input_names.insert(ParseTensorName(input)); |
155 | } |
156 | StringPieceSet used_nodes; |
157 | StringPieceSet current_nodes; |
158 | for (const string& name : context.output_names) { |
159 | TensorId id = ParseTensorName(name); |
160 | used_nodes.insert(id.first); |
161 | current_nodes.insert(id.first); |
162 | } |
163 | while (!current_nodes.empty()) { |
164 | StringPieceSet next_nodes; |
165 | for (StringPiece node_name : current_nodes) { |
166 | if (node_map.count(node_name) == 0) { |
167 | LOG(ERROR) << "Bad graph structure, no node named '" << node_name |
168 | << "' found for input lookup" ; |
169 | return errors::InvalidArgument("Bad graph structure, no node named '" , |
170 | node_name, "' found for input lookup" ); |
171 | } |
172 | const NodeDef& node = *(node_map[node_name]); |
173 | for (const string& input : node.input()) { |
174 | TensorId id = ParseTensorName(input); |
175 | if (input_names.count(id) > 0) { |
176 | continue; |
177 | } |
178 | if (used_nodes.insert(id.first).second) { |
179 | next_nodes.insert(id.first); |
180 | } |
181 | } |
182 | } |
183 | current_nodes.swap(next_nodes); |
184 | } |
185 | for (const TensorId& id : input_names) { |
186 | used_nodes.insert(id.first); |
187 | } |
188 | FilterGraphDef( |
189 | input_graph_def, |
190 | [&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; }, |
191 | output_graph_def); |
192 | TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def)); |
193 | |
194 | return OkStatus(); |
195 | } |
196 | |
197 | // Converts a shape inference handle to a PartialTensorShape. |
198 | Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle, |
199 | shape_inference::InferenceContext* context, |
200 | PartialTensorShape* shape) { |
201 | // The default is already unknown. |
202 | if (!context->RankKnown(handle)) return OkStatus(); |
203 | |
204 | std::vector<int64_t> dims(context->Rank(handle)); |
205 | for (int32_t i = 0; i < dims.size(); ++i) { |
206 | dims[i] = context->Value(context->Dim(handle, i)); |
207 | } |
208 | return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); |
209 | } |
210 | |
// Converts any sub-graphs that can be resolved into constant expressions into
// single Const ops.
//
// Recognized context parameters:
//   clear_output_shapes (bool, default true): strip stale "_output_shapes"
//     attributes before importing the graph.
//   exclude_op / exclude_node (string lists): op types / node names to keep
//     out of constant folding.
//   max_constant_size_in_bytes (int64): size cap for folded constants.
Status FoldConstants(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
  Graph input_graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(input_graph.AddFunctionLibrary(input_graph_def.library()));

  // Run shape inference during import so statically known output shapes can
  // be fed to the constant folder below.
  ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  shape_refiner.set_disable_constant_propagation(false);
  shape_refiner.set_function_library_for_shape_inference(
      &input_graph.flib_def());

  bool clear_output_shapes;
  TF_RETURN_IF_ERROR(context.GetOneBoolParameter("clear_output_shapes" , true,
                                                 &clear_output_shapes));
  if (clear_output_shapes) {
    // Some older GraphDefs have saved _output_shapes attributes which are out
    // of date and cause import errors, so clean them up first.
    GraphDef cleaned_graph_def;
    RemoveAttributes(input_graph_def, {"_output_shapes" }, &cleaned_graph_def);

    TF_RETURN_IF_ERROR(
        ImportGraphDef({}, cleaned_graph_def, &input_graph, &shape_refiner));
  } else {
    TF_RETURN_IF_ERROR(
        ImportGraphDef({}, input_graph_def, &input_graph, &shape_refiner));
  }

  // Sorted array of input names as lookup table.
  std::vector<TensorId> input_names;
  input_names.reserve(context.input_names.size());
  std::transform(context.input_names.begin(), context.input_names.end(),
                 std::back_inserter(input_names),
                 [](const string& name) { return ParseTensorName(name); });

  // Order by node name only, so equal_range() below can find all input
  // tensors that come from the same node.
  const auto compare = [](TensorId lhs, TensorId rhs) {
    return lhs.first < rhs.first;
  };

  std::sort(input_names.begin(), input_names.end(), compare);

  // Set statically inferred shapes.
  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
  for (const Node* const node : input_graph.nodes()) {
    auto ctx = shape_refiner.GetContext(node);
    if (ctx == nullptr) {
      // No inference context means no shape information for this node.
      continue;
    }

    std::vector<PartialTensorShape>& partial_shapes = shape_map[node->name()];
    if (ctx->num_outputs() <= 0) continue;
    partial_shapes.resize(ctx->num_outputs());

    // Check all outputs.
    for (const Edge* out_edge : node->out_edges()) {
      if (out_edge->IsControlEdge()) continue;

      const int output_idx = out_edge->src_output();
      TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(ctx->output(output_idx), ctx,
                                                  &partial_shapes[output_idx]));
    }

    // RewriteGraphForExecution() will add a Recv node for each input. Shape
    // refiner does not include shape information of these Recv nodes. Therefore
    // we add entries for Recv nodes here.
    const auto pair = std::equal_range(input_names.begin(), input_names.end(),
                                       TensorId{node->name(), 0}, compare);
    for (auto it = pair.first; it != pair.second; ++it) {
      // Must match the recv naming scheme used by ReplaceSendRecvs() above.
      const string recv_name =
          strings::StrCat("_recv_" , it->first, "_" , it->second);
      auto& recv_partial_shapes = shape_map[recv_name];
      // For whatever reason (for example, name collision) if the map entry was
      // already there, then do nothing.
      if (recv_partial_shapes.empty()) {
        recv_partial_shapes.push_back(partial_shapes[it->second]);
      }
    }
  }

  // Prune the graph down to the requested inputs/outputs, inserting
  // _Send/_Recv nodes (undone later by ReplaceSendRecvs()).
  subgraph::RewriteGraphMetadata unused_metadata;
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      &input_graph, context.input_names, context.output_names, {}, {},
      false /* use_function_convention */, &unused_metadata));

  ConstantFoldingOptions cf_opts;
  cf_opts.shape_map = &shape_map;

  // Exclude specified nodes from constant folding.
  std::set<string> excluded_ops, excluded_nodes;
  if (context.params.count("exclude_op" ) > 0) {
    const auto& ops = context.params.at("exclude_op" );
    excluded_ops = std::set<string>(ops.begin(), ops.end());
  }
  if (context.params.count("exclude_node" ) > 0) {
    const auto& nodes = context.params.at("exclude_node" );
    excluded_nodes = std::set<string>(nodes.begin(), nodes.end());
  }
  if (!excluded_ops.empty() || !excluded_nodes.empty()) {
    // A node is considered for folding only if neither its op type nor its
    // name is excluded.
    cf_opts.consider = [excluded_ops, excluded_nodes](const Node* n) {
      return excluded_ops.find(n->op_def().name()) == excluded_ops.end() &&
             excluded_nodes.find(n->name()) == excluded_nodes.end();
    };
  }

  TF_RETURN_IF_ERROR(context.GetOneInt64Parameter(
      "max_constant_size_in_bytes" , cf_opts.max_constant_size_in_bytes,
      &cf_opts.max_constant_size_in_bytes));

  // Constant folding.
  bool was_mutated;
  TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr,
                                  &input_graph, &was_mutated));
  GraphDef folded_graph_def;
  input_graph.ToGraphDef(&folded_graph_def);
  // Undo the _Send/_Recv rewriting, then drop any nodes made unreachable by
  // the folding.
  GraphDef send_recvs_replaced;
  TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def,
                                      context.input_names, context.output_names,
                                      &send_recvs_replaced));
  TF_RETURN_IF_ERROR(
      RemoveUnusedNodes(send_recvs_replaced, context, output_graph_def));
  return OkStatus();
}
335 | |
// Exposes FoldConstants to the graph transform tool under the name
// "fold_constants".
REGISTER_GRAPH_TRANSFORM("fold_constants" , FoldConstants);
337 | |
338 | } // namespace graph_transforms |
339 | } // namespace tensorflow |
340 | |