1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
17
#include <algorithm>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
26
27#include "tensorflow/core/common_runtime/constant_folding.h"
28#include "tensorflow/core/common_runtime/graph_constructor.h"
29#include "tensorflow/core/common_runtime/shape_refiner.h"
30#include "tensorflow/core/graph/node_builder.h"
31#include "tensorflow/core/graph/subgraph.h"
32#include "tensorflow/core/lib/core/stringpiece.h"
33#include "tensorflow/core/lib/strings/numbers.h"
34#include "tensorflow/core/platform/init_main.h"
35#include "tensorflow/core/public/session.h"
36#include "tensorflow/tools/graph_transforms/transform_utils.h"
37
38namespace tensorflow {
39namespace graph_transforms {
40namespace {
41using StringPieceSet = std::unordered_set<StringPiece, StringPieceHasher>;
42template <typename T>
43using StringPieceMap = std::unordered_map<StringPiece, T, StringPieceHasher>;
44} // namespace
45
46Status ReplaceSendRecvs(const GraphDef& original_graph_def,
47 const GraphDef& rewritten_graph_def,
48 const std::vector<string>& inputs,
49 const std::vector<string>& outputs,
50 GraphDef* output_graph_def) {
51 // recv_node_names serves as a string storage for recv node names.
52 std::vector<string> recv_node_names(inputs.size());
53 StringPieceMap<TensorId> recv_node_map;
54 StringPieceSet input_nodes;
55 for (int i = 0; i < inputs.size(); ++i) {
56 // RewriteGraphForExecution adds a recv node for each input edge. We assume
57 // here that adding such recv node did not fail. For example, the original
58 // graph did not already have a node with the name for the new added recv
59 // node.
60 TensorId id = ParseTensorName(inputs[i]);
61 input_nodes.insert(id.first);
62 string& recv_node_name = recv_node_names[i];
63 recv_node_name = strings::StrCat("_recv_", id.first, "_", id.second);
64 recv_node_map.emplace(recv_node_name, id);
65 }
66
67 StringPieceMap<const NodeDef*> original_map;
68 for (const NodeDef& node : original_graph_def.node()) {
69 original_map.emplace(node.name(), &node);
70 }
71
72 for (const NodeDef& node : rewritten_graph_def.node()) {
73 if ((node.op() == "_Send") || (node.op() == "_Recv")) {
74 // If the op is a Send or Recv that wasn't in the original, skip it.
75 if (original_map.count(node.name()) == 0) {
76 continue;
77 }
78 }
79
80 NodeDef* new_node = output_graph_def->add_node();
81 new_node->MergeFrom(node);
82 for (int i = 0; i < new_node->input_size(); ++i) {
83 string& input = *new_node->mutable_input(i);
84 TensorId id = ParseTensorName(input);
85 const auto iter = recv_node_map.find(id.first);
86 if (iter != recv_node_map.end()) {
87 // The node being substituted is a Recv node, and it has only one
88 // output. If this input is not a control input, then replace the input
89 // with the mapped value. Otherwise, replace the node name only.
90 if (id.second != Graph::kControlSlot) {
91 CHECK_EQ(id.second, 0);
92 input = iter->second.ToString();
93 } else {
94 id.first = iter->second.first;
95 input = id.ToString();
96 }
97 }
98 }
99
100 // RewriteGraphForExecution() did not remove this input node. Remove this
101 // node name from input_nodes so that a duplicate does not get added to the
102 // output_graph_def.
103 auto iter = input_nodes.find(new_node->name());
104 if (iter != input_nodes.end()) {
105 input_nodes.erase(iter);
106 }
107 }
108
109 // Some input nodes are removed in rewrite_graph_def. Add those nodes to
110 // output_graph_def.
111 for (StringPiece name : input_nodes) {
112 const NodeDef& removed_node = *CHECK_NOTNULL(original_map[name]);
113 output_graph_def->add_node()->MergeFrom(removed_node);
114 }
115
116 return OkStatus();
117}
118
119Status RewriteInputsAsPlaceholders(const TransformFuncContext& context,
120 GraphDef* graph_def) {
121 std::unordered_set<string> input_names;
122 for (const string& input_name : context.input_names) {
123 input_names.emplace(ParseTensorName(input_name).first);
124 }
125
126 for (NodeDef& node : *graph_def->mutable_node()) {
127 if (input_names.find(node.name()) == input_names.end()) {
128 continue;
129 }
130 if (node.op() == "PlaceholderWithDefault") {
131 node.set_op("Placeholder");
132 node.clear_input();
133 } else if (node.op() != "Placeholder") {
134 return errors::InvalidArgument(
135 "Input '", node.name(),
136 "' was expected to be a Placeholder or PlaceholderWithDefault op, "
137 "but was ",
138 node.op());
139 }
140 }
141 return OkStatus();
142}
143
144Status RemoveUnusedNodes(const GraphDef& input_graph_def,
145 const TransformFuncContext& context,
146 GraphDef* output_graph_def) {
147 StringPieceMap<const NodeDef*> node_map;
148 for (const NodeDef& node : input_graph_def.node()) {
149 node_map.emplace(node.name(), &node);
150 }
151
152 std::unordered_set<TensorId, TensorId::Hasher> input_names;
153 for (const string& input : context.input_names) {
154 input_names.insert(ParseTensorName(input));
155 }
156 StringPieceSet used_nodes;
157 StringPieceSet current_nodes;
158 for (const string& name : context.output_names) {
159 TensorId id = ParseTensorName(name);
160 used_nodes.insert(id.first);
161 current_nodes.insert(id.first);
162 }
163 while (!current_nodes.empty()) {
164 StringPieceSet next_nodes;
165 for (StringPiece node_name : current_nodes) {
166 if (node_map.count(node_name) == 0) {
167 LOG(ERROR) << "Bad graph structure, no node named '" << node_name
168 << "' found for input lookup";
169 return errors::InvalidArgument("Bad graph structure, no node named '",
170 node_name, "' found for input lookup");
171 }
172 const NodeDef& node = *(node_map[node_name]);
173 for (const string& input : node.input()) {
174 TensorId id = ParseTensorName(input);
175 if (input_names.count(id) > 0) {
176 continue;
177 }
178 if (used_nodes.insert(id.first).second) {
179 next_nodes.insert(id.first);
180 }
181 }
182 }
183 current_nodes.swap(next_nodes);
184 }
185 for (const TensorId& id : input_names) {
186 used_nodes.insert(id.first);
187 }
188 FilterGraphDef(
189 input_graph_def,
190 [&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; },
191 output_graph_def);
192 TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def));
193
194 return OkStatus();
195}
196
197// Converts a shape inference handle to a PartialTensorShape.
198Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle,
199 shape_inference::InferenceContext* context,
200 PartialTensorShape* shape) {
201 // The default is already unknown.
202 if (!context->RankKnown(handle)) return OkStatus();
203
204 std::vector<int64_t> dims(context->Rank(handle));
205 for (int32_t i = 0; i < dims.size(); ++i) {
206 dims[i] = context->Value(context->Dim(handle, i));
207 }
208 return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
209}
210
211// Converts any sub-graphs that can be resolved into constant expressions into
212// single Const ops.
213Status FoldConstants(const GraphDef& input_graph_def,
214 const TransformFuncContext& context,
215 GraphDef* output_graph_def) {
216 Graph input_graph(OpRegistry::Global());
217 TF_RETURN_IF_ERROR(input_graph.AddFunctionLibrary(input_graph_def.library()));
218
219 ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry());
220 shape_refiner.set_require_shape_inference_fns(false);
221 shape_refiner.set_disable_constant_propagation(false);
222 shape_refiner.set_function_library_for_shape_inference(
223 &input_graph.flib_def());
224
225 bool clear_output_shapes;
226 TF_RETURN_IF_ERROR(context.GetOneBoolParameter("clear_output_shapes", true,
227 &clear_output_shapes));
228 if (clear_output_shapes) {
229 // Some older GraphDefs have saved _output_shapes attributes which are out
230 // of date and cause import errors, so clean them up first.
231 GraphDef cleaned_graph_def;
232 RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def);
233
234 TF_RETURN_IF_ERROR(
235 ImportGraphDef({}, cleaned_graph_def, &input_graph, &shape_refiner));
236 } else {
237 TF_RETURN_IF_ERROR(
238 ImportGraphDef({}, input_graph_def, &input_graph, &shape_refiner));
239 }
240
241 // Sorted array of input names as lookup table.
242 std::vector<TensorId> input_names;
243 input_names.reserve(context.input_names.size());
244 std::transform(context.input_names.begin(), context.input_names.end(),
245 std::back_inserter(input_names),
246 [](const string& name) { return ParseTensorName(name); });
247
248 const auto compare = [](TensorId lhs, TensorId rhs) {
249 return lhs.first < rhs.first;
250 };
251
252 std::sort(input_names.begin(), input_names.end(), compare);
253
254 // Set statically inferred shapes.
255 std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
256 for (const Node* const node : input_graph.nodes()) {
257 auto ctx = shape_refiner.GetContext(node);
258 if (ctx == nullptr) {
259 continue;
260 }
261
262 std::vector<PartialTensorShape>& partial_shapes = shape_map[node->name()];
263 if (ctx->num_outputs() <= 0) continue;
264 partial_shapes.resize(ctx->num_outputs());
265
266 // Check all outputs.
267 for (const Edge* out_edge : node->out_edges()) {
268 if (out_edge->IsControlEdge()) continue;
269
270 const int output_idx = out_edge->src_output();
271 TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(ctx->output(output_idx), ctx,
272 &partial_shapes[output_idx]));
273 }
274
275 // RewriteGraphForExecution() will add a Recv node for each input. Shape
276 // refiner does not include shape information of these Recv nodes. Therefore
277 // we add entries for Recv nodes here.
278 const auto pair = std::equal_range(input_names.begin(), input_names.end(),
279 TensorId{node->name(), 0}, compare);
280 for (auto it = pair.first; it != pair.second; ++it) {
281 const string recv_name =
282 strings::StrCat("_recv_", it->first, "_", it->second);
283 auto& recv_partial_shapes = shape_map[recv_name];
284 // For whatever reason (for example, name collision) if the map entry was
285 // already there, then do nothing.
286 if (recv_partial_shapes.empty()) {
287 recv_partial_shapes.push_back(partial_shapes[it->second]);
288 }
289 }
290 }
291
292 subgraph::RewriteGraphMetadata unused_metadata;
293 TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
294 &input_graph, context.input_names, context.output_names, {}, {},
295 false /* use_function_convention */, &unused_metadata));
296
297 ConstantFoldingOptions cf_opts;
298 cf_opts.shape_map = &shape_map;
299
300 // Exclude specified nodes from constant folding.
301 std::set<string> excluded_ops, excluded_nodes;
302 if (context.params.count("exclude_op") > 0) {
303 const auto& ops = context.params.at("exclude_op");
304 excluded_ops = std::set<string>(ops.begin(), ops.end());
305 }
306 if (context.params.count("exclude_node") > 0) {
307 const auto& nodes = context.params.at("exclude_node");
308 excluded_nodes = std::set<string>(nodes.begin(), nodes.end());
309 }
310 if (!excluded_ops.empty() || !excluded_nodes.empty()) {
311 cf_opts.consider = [excluded_ops, excluded_nodes](const Node* n) {
312 return excluded_ops.find(n->op_def().name()) == excluded_ops.end() &&
313 excluded_nodes.find(n->name()) == excluded_nodes.end();
314 };
315 }
316
317 TF_RETURN_IF_ERROR(context.GetOneInt64Parameter(
318 "max_constant_size_in_bytes", cf_opts.max_constant_size_in_bytes,
319 &cf_opts.max_constant_size_in_bytes));
320
321 // Constant folding.
322 bool was_mutated;
323 TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr,
324 &input_graph, &was_mutated));
325 GraphDef folded_graph_def;
326 input_graph.ToGraphDef(&folded_graph_def);
327 GraphDef send_recvs_replaced;
328 TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def,
329 context.input_names, context.output_names,
330 &send_recvs_replaced));
331 TF_RETURN_IF_ERROR(
332 RemoveUnusedNodes(send_recvs_replaced, context, output_graph_def));
333 return OkStatus();
334}
335
336REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
337
338} // namespace graph_transforms
339} // namespace tensorflow
340