/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/constant_folding.h"

#include <algorithm>
#include <atomic>
#include <set>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

namespace {

const char kScopedAllocatorAttrName[] = "_scoped_allocator";

// Test to see if the Op is one that turns into a constant when its
// inputs' shapes are known.
bool IsShapeOp(const Node* n) {
  const auto& ts = n->type_string();
  return ts == "Shape" || ts == "ShapeN" || ts == "Rank" || ts == "Size";
}
// Reads the partially-known shape of each of n's inputs from shape_map, and
// stores them in input_shapes. Returns false if any input does not have a
// shape in shape_map.
bool ReadPartialShapesFromShapeMap(
    const Node* n,
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map,
    std::vector<PartialTensorShape>* input_shapes) {
  CHECK(shape_map != nullptr);
  input_shapes->resize(n->num_inputs());
  for (const Edge* in : n->in_edges()) {
    // Don't need to check if incoming control edges have known shapes.
    if (in->IsControlEdge()) continue;
    const auto known_shape_iter = shape_map->find(in->src()->name());
    if (known_shape_iter == shape_map->end()) {
      // One of n's inputs doesn't have known shapes, so don't replace n.
      return false;
    }
    const auto& known_shape = known_shape_iter->second;
    CHECK_GT(known_shape.size(), in->src_output()) << known_shape_iter->first;
    DCHECK_GE(in->dst_input(), 0);
    DCHECK_LT(in->dst_input(), input_shapes->size());
    (*input_shapes)[in->dst_input()] = known_shape[in->src_output()];
  }
  return true;
}

// If all of n's inputs have fully-defined shapes, inserts those shapes as a
// vector of Tensors in the shape_replacement_map.
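// For example, a Shape op whose single input has fully-defined shape [2, 3]
// maps to a vector containing one rank-1 tensor holding {2, 3}; a ShapeN op
// with k fully-defined inputs maps to a vector of k such tensors.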
bool MaybeReplaceShapeOrShapeNOp(
    const Node* n, const std::vector<PartialTensorShape>& input_shapes,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  std::vector<Tensor> defined_shape;
  for (const auto& shape : input_shapes) {
    if (!shape.IsFullyDefined()) {
      return false;
    }
    const int rank = shape.dims();
    DataType op_type = n->output_type(0);
    Tensor t(op_type, TensorShape({rank}));
    if (op_type == DT_INT64) {
      auto vec = t.vec<int64_t>();
      for (int i = 0; i < rank; ++i) {
        vec(i) = shape.dim_size(i);
      }
    } else {
      CHECK(op_type == DT_INT32);
      auto vec = t.vec<int32>();
      for (int i = 0; i < rank; ++i) {
        if (shape.dim_size(i) > INT_MAX) {
          VLOG(1) << "Node " << n->name() << " has input shape dimension " << i
                  << " of " << shape.dim_size(i) << " but type INT32,"
                  << " so not replacing as constant: this will trigger a "
                     "runtime error later.";
          return false;
        }
        vec(i) = static_cast<int32>(shape.dim_size(i));
      }
    }
    defined_shape.push_back(t);
  }
  // All the inputs had known shapes so we can replace the node by constants
  // later in the rewrite.
  shape_replacement_map->insert({n, defined_shape});
  return true;
}

// If n's input has defined rank, inserts that rank as a Tensor in the
// shape_replacement_map.
bool MaybeReplaceRankOp(const Node* n,
                        const std::vector<PartialTensorShape>& input_shapes,
                        std::unordered_map<const Node*, std::vector<Tensor>>*
                            shape_replacement_map) {
  CHECK_EQ(input_shapes.size(), 1);
  if (input_shapes[0].unknown_rank()) {
    return false;
  }
  Tensor t(DT_INT32, TensorShape({}));
  t.scalar<int32>()() = input_shapes[0].dims();
  shape_replacement_map->insert({n, {t}});
  return true;
}

// If n's input has defined size, inserts that size as a Tensor in the
// shape_replacement_map.
bool MaybeReplaceSizeOp(const Node* n,
                        const std::vector<PartialTensorShape>& input_shapes,
                        std::unordered_map<const Node*, std::vector<Tensor>>*
                            shape_replacement_map) {
  CHECK_EQ(input_shapes.size(), 1);
  if (!input_shapes[0].IsFullyDefined()) {
    return false;
  }
  DataType op_type = n->output_type(0);
  Tensor t(op_type, TensorShape({}));
  int64_t size = input_shapes[0].num_elements();
  if (op_type == DT_INT64) {
    t.scalar<int64_t>()() = size;
  } else {
    CHECK(op_type == DT_INT32);
    if (size > INT_MAX) {
      VLOG(1) << "Node " << n->name() << " has input shape size " << size
              << " but type INT32,"
              << " so not replacing as constant: this will trigger a runtime "
                 "error later.";
      return false;
    }
    t.scalar<int32>()() = static_cast<int32>(size);
  }
  shape_replacement_map->insert({n, {t}});
  return true;
}

// If n is a shape Op (Shape, ShapeN, Rank, or Size) and its inputs have their
// shapes specified in shape_map, then adds to shape_replacement_map a mapping
// from n to a vector of Tensors, where Tensor k is the (statically known) value
// on n's kth output edge. shape_replacement_map has an entry for n iff
// MaybeReplaceShapeOp returns true, so it's valid to use
// shape_replacement_map->count(n) as a test to see if n is a shape op that can
// be replaced.
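// For instance, if n's single input has statically-known shape [2, 3], the
// resulting tensors are {2, 3} for Shape, 2 for Rank, and 6 for Size.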
bool MaybeReplaceShapeOp(
    const Node* n,
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  if (shape_map == nullptr || !IsShapeOp(n)) {
    return false;
  }
  // input_shapes will contain the shapes of each of n's inputs.
  std::vector<PartialTensorShape> input_shapes;
  if (!ReadPartialShapesFromShapeMap(n, shape_map, &input_shapes)) {
    return false;
  }
  const auto& ts = n->type_string();
  if (ts == "Shape" || ts == "ShapeN") {
    if (!MaybeReplaceShapeOrShapeNOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  } else if (ts == "Rank") {
    if (!MaybeReplaceRankOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  } else {
    CHECK_EQ(ts, "Size");
    if (!MaybeReplaceSizeOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  }
  return true;
}

// Returns true if n can be evaluated as constant. shape_map maps from
// nodes to the partially-known shapes of their outputs. consider, if
// non-null, is a predicate that returns whether a given (non-Const,
// non-Shape) node is eligible to be constant-propagated.
// shape_replacement_map is filled in with a vector of constant output
// tensors for constant-foldable shape nodes (Shape, ShapeN, Size, or Rank).
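// As an illustration only (not something this file sets up), a caller could
// restrict folding to particular ops with a predicate such as:
//   opts.consider = [](const Node* n) { return n->type_string() != "MatMul"; };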
bool IsConstantFoldable(
    const Node* n,
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map,
    const std::function<bool(const Node*)>& consider,
    int64_t max_constant_size_in_bytes,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  if (n->IsConstant()) {
    // Skip constant folding resources as they cannot be deep copied.
    return n->output_type(0) != DT_RESOURCE;
  }
  if (MaybeReplaceShapeOp(n, shape_map, shape_replacement_map)) {
    return true;
  }
  if (n->op_def().is_stateful()) {
    return false;
  }
  if (consider && !consider(n)) {
    return false;
  }
  if (shape_map != nullptr) {
    // We can skip the node if an output is known to be oversized.
    auto shape_it = shape_map->find(n->name());
    if (shape_it != shape_map->end()) {
      for (int64_t i = 0; i < shape_it->second.size(); ++i) {
        const auto& out_shape = shape_it->second[i];
        if (out_shape.IsFullyDefined() &&
            out_shape.num_elements() * DataTypeSize(n->output_type(i)) >
                max_constant_size_in_bytes) {
          return false;
        }
      }
    }
  }
  if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
    return false;
  }
  // TODO(yuanbyu): For now disable these session handle operations.
  if (n->IsGetSessionHandle() || n->IsGetSessionTensor() ||
      n->IsDeleteSessionTensor()) {
    return false;
  }
  if (n->IsSource()) {
    return false;
  }
  if (n->IsSink()) {
    return false;
  }
  if (n->IsFakeParam()) {
    return false;
  }
  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel. This also implies that we will not
  // constant-fold functions.
  // TODO(phawkins): allow constant-folding for functions; functions may
  // be arbitrarily expensive to execute.
  if (!KernelDefAvailable(DeviceType(DEVICE_CPU), n->def())) {
    return false;
  }
  // Do not constant fold nodes which will be allocated by ScopedAllocator.
  // This is because the constant-folding graph will not contain the
  // `_ScopedAllocator` node, and that is necessary to be able to run a node
  // that will use this allocator.
  if (n->attrs().Find(kScopedAllocatorAttrName) != nullptr) {
    VLOG(2) << "Skip node [" << n->DebugString()
            << "] for constant folding due to scoped allocator";
    return false;
  }
  return true;
}

// If n is eligible for constant-folding, adds it to nodes, and places its
// control dependencies and those transitively of its constant-foldable inputs
// into constant_control_deps. If n is a constant-foldable shape node (Shape,
// ShapeN, Rank, or Size), also puts its outputs into shape_replacement_map.
void ConsiderConstantFoldableNode(
    Node* n, const ConstantFoldingOptions& opts, std::vector<Node*>* nodes,
    std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
    std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
    bool* internal_node_inserted) {
  if (IsConstantFoldable(n, opts.shape_map, opts.consider,
                         opts.max_constant_size_in_bytes,
                         shape_replacement_map)) {
    // A node is constant provided all of its non-control incoming Tensors come
    // from constant nodes, or it's a shape Op with statically known inputs in
    // which case it is placed in shape_replacement_map.
    //
    // We allow control dependencies from non-constant nodes to constant nodes,
    // but to preserve the graph structure we must transfer the control
    // dependency onto any constant replacement.
    bool all_parents_constant = true;
    for (const Edge* in : n->in_edges()) {
      // Allows non-constant -> constant control edges.
      if (!in->IsControlEdge() &&
          constant_control_deps->count(in->src()) == 0) {
        all_parents_constant = false;
        break;
      }
    }
    if (all_parents_constant || shape_replacement_map->count(n) != 0) {
      gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
      for (const Edge* e : n->in_edges()) {
        if (constant_control_deps->count(e->src()) == 0) {
          // This branch is taken if the incoming edge is a control dependency,
          // in which case we want to add it to the dependencies being
          // accumulated for this node, or if the incoming edge comes from a
          // non-constant node. The latter can happen when n is a shape node
          // whose source has a known shape. In that case add a control
          // dependency from the source node, since there was previously a data
          // dependency and we want to preserve sequencing constraints.
          if (!e->src()->IsSource()) {
            control_deps.insert(e->src());
          }
        } else {
          // If the parent has been accumulating control dependencies, add all
          // of its transitive control deps.
          const gtl::FlatSet<Node*>& parent_deps =
              (*constant_control_deps)[e->src()];
          control_deps.insert(parent_deps.begin(), parent_deps.end());
        }
      }
      nodes->push_back(n);
      if (!n->IsConstant()) {
        *internal_node_inserted = true;
      }
    }
  }
}

// Returns the constant foldable nodes in `nodes` in topological order.
// Populates `constant_control_deps` with the non-constant control dependencies
// of each constant node.
void FindConstantFoldableNodes(
    const Graph* graph, const ConstantFoldingOptions& opts,
    std::vector<Node*>* nodes,
    std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  bool internal_node_inserted = false;
  // Walk the nodes in data flow order.
  ReverseDFS(
      *graph, nullptr,
      [nodes, constant_control_deps, shape_replacement_map,
       &internal_node_inserted, &opts](Node* n) {
        ConsiderConstantFoldableNode(n, opts, nodes, constant_control_deps,
                                     shape_replacement_map,
                                     &internal_node_inserted);
      },
      NodeComparatorName());
  // If we have inserted just leaf level nodes, then there is nothing to fold.
  if (!internal_node_inserted) {
    nodes->clear();
    constant_control_deps->clear();
  }
}

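// A node paired with the index of one of its outputs, identifying a single
// output tensor of that node.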
typedef std::pair<Node*, int> NodeAndOutput;

// Adds n to constant_graph which is being built up for subsequent evaluation of
// constant propagation. node_map is the mapping of nodes in the original graph
// to nodes in the constant graph. The value of an entry in node_map is a vector
// of nodes because a ShapeN node in the original graph is replaced by a vector
// of Constant nodes in the constant graph.
void AddNodeToConstantGraph(
    Node* n, std::unordered_map<Node*, std::vector<Node*>>* node_map,
    Graph* constant_graph) {
  std::vector<Node*>& added = (*node_map)[n];
  added.push_back(constant_graph->CopyNode(n));
  for (const Edge* in_edge : n->in_edges()) {
    // Don't copy control edges to the constant graph.
    if (!in_edge->IsControlEdge()) {
      Node* in = in_edge->src();
      auto it = node_map->find(in);
      CHECK(it != node_map->end())
          << n->DebugString() << " <-" << in->DebugString();
      if (it->second.size() == 1) {
        constant_graph->AddEdge(it->second[0], in_edge->src_output(), added[0],
                                in_edge->dst_input());
      } else {
        // The original source node had multiple outputs and was replaced by a
        // vector of constants, so the edge comes from the 0th output of the kth
        // added constant, rather than the kth output of the added node as in
        // the standard case above.
        constant_graph->AddEdge(it->second[in_edge->src_output()], 0, added[0],
                                in_edge->dst_input());
      }
    }
  }
}

// Replaces constant-foldable shape node n by a vector of constants in
// constant_graph, which is being built up for subsequent evaluation of constant
// propagation. node_map is the mapping of nodes in the original graph to nodes
// in the constant graph. The value of an entry in node_map is a vector of nodes
// because a ShapeN node in the original graph is replaced by a vector of
// Constant nodes in the constant graph.
void AddShapeNodeToConstantGraph(
    Node* n,
    const std::unordered_map<const Node*, std::vector<Tensor>>&
        shape_replacement_map,
    std::unordered_map<Node*, std::vector<Node*>>* node_map,
    const ConstantFoldNameGenerator& generate_new_name, Graph* constant_graph) {
  std::vector<Node*>& added = (*node_map)[n];
  const string& node_name = n->name();
  for (const Tensor& t : shape_replacement_map.at(n)) {
    auto builder =
        NodeDefBuilder(generate_new_name(constant_graph, node_name), "Const")
            .Attr("dtype", t.dtype())
            .Attr("value", t);
    NodeDef def;
    CHECK(builder.Finalize(&def).ok());
    Node* constant_node;
    CHECK(NodeBuilder(builder).Finalize(constant_graph, &constant_node).ok());
    added.push_back(constant_node);
  }
  // Don't copy incoming edges to shape nodes that are being replaced.
}

// Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
// will contain copies of the nodes in 'nodes'. In addition, if there is an edge
// going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
// 'nodes', then 'tensors_to_fetch' will map the copy of 'n' in 'g' and the
// corresponding output index to 'n' and that same output index in 'orig_graph'.
Graph* GetConstantGraph(
    const Graph* orig_graph, const std::vector<Node*>& nodes,
    const std::unordered_map<const Node*, std::vector<Tensor>>&
        shape_replacement_map,
    std::map<NodeAndOutput, NodeAndOutput>* tensors_to_fetch,
    const ConstantFoldNameGenerator& generate_new_name) {
  Graph* constant_graph = new Graph(orig_graph->op_registry());
  std::unordered_map<Node*, std::vector<Node*>> node_map;
  node_map[orig_graph->source_node()] = {constant_graph->source_node()};
  node_map[orig_graph->sink_node()] = {constant_graph->sink_node()};
  for (Node* n : nodes) {
    if (shape_replacement_map.count(n) == 0) {
      AddNodeToConstantGraph(n, &node_map, constant_graph);
    } else {
      AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map,
                                  generate_new_name, constant_graph);
    }
  }

  for (auto const& added_nodes : node_map) {
    for (const Edge* out_edge : added_nodes.first->out_edges()) {
      if (node_map.count(out_edge->dst()) == 0) {
        if (out_edge->IsControlEdge()) continue;
        if (added_nodes.second.size() == 1) {
          tensors_to_fetch->insert(
              {{added_nodes.second[0], out_edge->src_output()},
               {added_nodes.first, out_edge->src_output()}});
        } else {
          // The node had multiple outputs and was replaced by a
          // vector of constants, so the NodeAndOutput is the 0th
          // output of the kth added constant, rather than the kth
          // output of the added node as in the standard case above.
          tensors_to_fetch->insert(
              {{added_nodes.second[out_edge->src_output()], 0},
               {added_nodes.first, out_edge->src_output()}});
        }
      }
    }
  }

  return constant_graph;
}

// Replaces the identified Tensor in 'graph' by a 'Const' node with
// the value supplied in 'constant'. 'partition_device', if non-null,
// is the device where the graph executes. Returns true if the
// replacement was successful, false otherwise.
// 'control_deps' is the set of nodes that should be control predecessors of the
// new constant node.
bool ReplaceTensorWithConstant(
    Graph* graph, const Device* partition_device, NodeAndOutput tensor,
    const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
    int64_t max_constant_size_in_bytes,
    const ConstantFoldNameGenerator& generate_new_name) {
  // Be conservative when replacing a tensor with a constant, when not
  // running on CPU.
  // 1) Do not replace another constant.
  // 2) If the destination tensor or any other tensor from the same node is not
  // an int32 tensor, and has HOST_MEMORY constraint, do not replace it.
  // 3) If the destination tensor or any other tensor from the same node is an
  // int32 tensor, and has DEVICE_MEMORY constraint, do not replace it.
  // 4) If the size of the constant in bytes is too large (>
  // max_constant_size_in_bytes), do not replace it. This prevents the size of
  // the Graph from growing too large.
  // 5) If the constant op created does not have a kernel implementation
  // for the device, do not use it.
  // TODO(keveman): Consider adding a new constant op that has a kernel
  // implementation for all types, but with HostMemory constraint on its
  // output.
  if (tensor.first->IsConstant()) {
    return false;
  }
  DeviceType device_type = partition_device
                               ? DeviceType{partition_device->device_type()}
                               : DEVICE_CPU;
  if (partition_device && device_type != DEVICE_CPU) {
    MemoryTypeVector input_mvec;
    MemoryTypeVector output_mvec;
    if (!MemoryTypesForNode(graph->op_registry(), device_type,
                            tensor.first->def(), &input_mvec, &output_mvec)
             .ok()) {
      return false;
    }
    for (int i = 0; i < output_mvec.size(); i++) {
      MemoryType memory_type = output_mvec[i];
      bool is_int32 = tensor.first->output_type(i) == DT_INT32;
      if ((memory_type == HOST_MEMORY && !is_int32) ||
          (memory_type == DEVICE_MEMORY && is_int32)) {
        return false;
      }
    }
  }
  if (constant.TotalBytes() > max_constant_size_in_bytes) {
    return false;
  }

  Node* n = tensor.first;
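  // Collect all data edges that read output tensor.second of this node; they
  // will be rerouted to read from the new constant node instead.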
  std::vector<const Edge*> edges_to_remove;
  for (const Edge* out_edge : n->out_edges()) {
    if (out_edge->src_output() == tensor.second) {
      edges_to_remove.push_back(out_edge);
    }
  }
  const string& node_name = n->name();
  Node* constant_node;
  auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const")
                     .Attr("dtype", constant.dtype())
                     .Attr("value", constant);
  if (partition_device) {
    builder.Device(partition_device->name());
  }
  NodeDef def;
  if (!builder.Finalize(&def).ok()) {
    return false;
  }
  const KernelDef* kdef;
  if (!FindKernelDef(device_type, def, &kdef, nullptr).ok()) {
    return false;
  }

  VLOG(1) << "Replacing " << tensor.first->name() << " :: " << tensor.second
          << " with a constant";

  if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
    return false;
  }
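  // Reroute each collected consumer edge to read from the new constant node,
  // then remove the old edge.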
  for (auto edge : edges_to_remove) {
    graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
    graph->RemoveEdge(edge);
  }
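  // Hook the new constant into the graph: if there are no control dependencies
  // to transfer, attach it to the source node; otherwise make it depend on each
  // accumulated control predecessor to preserve the original ordering
  // constraints.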
  if (control_deps.empty()) {
    graph->AddControlEdge(graph->source_node(), constant_node);
  } else {
    for (Node* node : control_deps) {
      graph->AddControlEdge(node, constant_node);
    }
  }
  if (partition_device) {
    constant_node->set_assigned_device_name(partition_device->name());
  }
  return true;
}

}  // namespace

Status ConstantFold(const ConstantFoldingOptions& opts,
                    FunctionLibraryRuntime* function_library, Env* env,
                    const Device* partition_device, Graph* graph,
                    bool* was_mutated) {
  // TensorFlow flushes denormals to zero and rounds to nearest, so we do
  // the same here.
  port::ScopedFlushDenormal flush;
  port::ScopedSetRound round(FE_TONEAREST);

  DumpGraph("Before", graph);

  ConstantFoldNameGenerator generate_new_name = opts.generate_new_name;
  std::atomic_int_fast64_t constant_unique_id{0};
  if (generate_new_name == nullptr) {
    generate_new_name = [&constant_unique_id](Graph* graph, string old_name) {
      return strings::StrCat(graph->NewName(old_name), "__cf__",
                             constant_unique_id.fetch_add(1));
    };
  }
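  // With the default generator, the constant that replaces a node named "x" is
  // named roughly "x__cf__<N>" (Graph::NewName may further uniquify the "x"
  // part), where <N> is a counter local to this ConstantFold call.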

  std::vector<Node*> constant_foldable_nodes;
  std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
  std::unordered_map<const Node*, std::vector<Tensor>> shape_replacement_map;
  FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
                            &constant_control_deps, &shape_replacement_map);
  if (constant_foldable_nodes.empty()) {
    VLOG(1) << "No constant foldable nodes found";
    *was_mutated = false;
    // This is not an error, so return the status as OK.
    return OkStatus();
  }

  std::map<NodeAndOutput, NodeAndOutput> tensors_to_fetch;
  std::unique_ptr<Graph> constant_graph(
      GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
                       &tensors_to_fetch, generate_new_name));
  DumpGraph("Constant graph", constant_graph.get());

  if (tensors_to_fetch.empty()) {
    VLOG(1) << "No constant nodes found that feed into the original graph.";
    *was_mutated = false;
    // This is not an error, so return the status as OK.
    return OkStatus();
  }
  VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : "
          << graph->num_node_ids();

  std::vector<string> tensors_to_fetch_names;
  std::vector<NodeAndOutput> tensors_to_replace;
  // Sorting the nodes based on the name gives us a stable ordering between runs
  // for the same graph.
  std::vector<std::pair<NodeAndOutput, NodeAndOutput>> tensors_to_fetch_sorted(
      tensors_to_fetch.begin(), tensors_to_fetch.end());
  std::sort(tensors_to_fetch_sorted.begin(), tensors_to_fetch_sorted.end(),
            [](const std::pair<NodeAndOutput, NodeAndOutput>& n1,
               const std::pair<NodeAndOutput, NodeAndOutput>& n2) {
              return std::tie(n1.first.first->name(), n1.first.second) <
                     std::tie(n2.first.first->name(), n2.first.second);
            });
  for (auto n : tensors_to_fetch_sorted) {
    tensors_to_fetch_names.push_back(
        strings::StrCat(n.first.first->name(), ":", n.first.second));
    tensors_to_replace.push_back(n.second);
  }

  auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
  // Evaluate the constant foldable nodes.
  std::vector<Tensor> outputs;
  auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] {
    // Output tensors need to be cleared before the GraphRunner is deleted.
    outputs.clear();
    graph_runner.reset(nullptr);
  });

  Status s =
      graph_runner->Run(constant_graph.get(), function_library, {} /* inputs */,
                        tensors_to_fetch_names, &outputs);
  if (!s.ok()) {
    VLOG(1) << "Could not fetch constants: " << s;
    *was_mutated = false;
    return s;
  }

  // Fetch the constant tensors and replace the corresponding tensors in the
  // original graph with those constants.
  int32_t num_nodes_replaced = 0;
  for (size_t c = 0; c < outputs.size(); ++c) {
    const gtl::FlatSet<Node*>& control_deps =
        constant_control_deps[tensors_to_replace[c].first];
    if (ReplaceTensorWithConstant(
            graph, partition_device, tensors_to_replace[c], outputs[c],
            control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
      ++num_nodes_replaced;
    }
  }

  DumpGraph("After", graph);

  *was_mutated = (num_nodes_replaced > 0);
  return OkStatus();
}

}  // namespace tensorflow