/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eval_const_tensor.h"

#include <deque>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
namespace {

using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;

// Returns a Tensor containing the underlying constant value of a Node if the
// node holds a constant value.
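// A node that is not a constant (or whose value proto fails to parse) is not
// an error: *success is simply left false so callers can fall back to other
// evaluation strategies.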
Status EvaluateConstantNode(const Node& node, Tensor* output, bool* success) {
  *success = false;
  if (node.IsConstant()) {
    if (output->FromProto(node.def().attr().at("value").tensor())) {
      *success = true;
    }
  }
  return OkStatus();
}

// Returns the integer value of the source node feeding the 'input_idx'-th
// input edge of 'node', if that source holds a constant scalar int32/int64.
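// For example, if the edge at 'input_idx' comes from Const(7) (an int32
// scalar), *output is set to 7 and *success to true.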
Status EvaluateConstantIntFromScalarEdge(const Node& node, int input_idx,
                                         int64* output, bool* success) {
  *success = false;
  Tensor scalar;
  const Edge* edge;
  TF_RETURN_IF_ERROR(node.input_edge(input_idx, &edge));
  TF_RETURN_IF_ERROR(EvaluateConstantNode(*edge->src(), &scalar, success));
  if (*success && scalar.NumElements() == 1) {
    if (scalar.dtype() == DT_INT32) {
      *output = scalar.scalar<int32>()();
    } else if (scalar.dtype() == DT_INT64) {
      *output = scalar.scalar<int64_t>()();
    } else {
      *success = false;
    }
  }
  return OkStatus();
}

// Tries to infer the tensor output based on the input dims of a
// Shape node.
// [allow_partial = false]
// Can infer the Shape op's output tensor only when the
// input shapes to the Shape op are fully defined.
// [allow_partial = true]
// Can infer the Shape op's output tensor as long as the rank of the input
// shape to the Shape op is known. Uses kUnknownDim for unknown dims.
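// For example, with allow_partial = true an input of shape (5, ?, 3) yields
// the tensor [5, -1, 3] (kUnknownDim is -1); with allow_partial = false no
// tensor is produced for that input.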
Status TryToInferTensorOutputFromShapeNode(const Node& shape_node,
                                           InferenceContext* shape_c,
                                           Tensor* output, bool* success,
                                           bool allow_partial = false) {
  *success = false;
  if (shape_node.type_string() != "Shape") return OkStatus();
  if (shape_c == nullptr) return OkStatus();
  if (!shape_c->FullyDefined(shape_c->input(0)) && !allow_partial)
    return OkStatus();
  if (!shape_c->RankKnown(shape_c->input(0))) return OkStatus();

  int src_rank = shape_c->Rank(shape_c->input(0));
  Tensor t(shape_node.output_type(0), TensorShape({src_rank}));
  if (shape_node.output_type(0) == DT_INT32) {
    auto flat = t.flat<int>();
    for (int i = 0; i < src_rank; i++) {
      int64_t dimension;
      if (shape_c->ValueKnown(shape_c->Dim(shape_c->input(0), i))) {
        dimension = shape_c->Value(shape_c->Dim(shape_c->input(0), i));
        if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Shape has output type int32, but dimension exceeds maximum "
              "int32 value");
        }
      } else {
        dimension = shape_c->kUnknownDim;
      }
      flat(i) = static_cast<int32>(dimension);
    }
  } else if (shape_node.output_type(0) == DT_INT64) {
    auto flat = t.flat<int64_t>();
    for (int i = 0; i < src_rank; i++) {
      if (shape_c->ValueKnown(shape_c->Dim(shape_c->input(0), i))) {
        flat(i) = shape_c->Value(shape_c->Dim(shape_c->input(0), i));
      } else {
        flat(i) = shape_c->kUnknownDim;
      }
    }
  } else {
    return errors::FailedPrecondition(
        "Shape has output type that is not int32 or int64");
  }
  *output = t;
  *success = true;
  return OkStatus();
}

// Tries to infer the tensor output of a StridedSlice node. This can be done
// when taking a slice of a fully defined Shape node or when taking a slice
// of a partial Shape node along a known dimension.
// Examples:
// tf.shape(x)[0]; x.shape = (5, 10) - slicing fully defined shape
// tf.shape(x)[0]; x.shape = (5, ?)  - slicing partial shape along known dim
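// tf.shape(x)[1]; x.shape = (5, ?)  - cannot be inferred, since the sliced
//                                     dimension itself is unknown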
Status TryToInferTensorOutputFromStridedSliceNode(const Node& node,
                                                  const ShapeRefiner& refiner,
                                                  Tensor* output,
                                                  bool* success) {
  *success = false;
  const Edge* edge;
  TF_RETURN_IF_ERROR(node.input_edge(0, &edge));
  const Node* shape_node = edge->src();
  const Node* stride_node = edge->dst();
  InferenceContext* shape_c = refiner.GetContext(shape_node);
  InferenceContext* stride_c = refiner.GetContext(stride_node);

  if (stride_c == nullptr || shape_c == nullptr) return OkStatus();
  if (stride_node == nullptr || shape_node == nullptr) return OkStatus();
  if (stride_node->type_string() != "StridedSlice") return OkStatus();
  if (shape_node->type_string() != "Shape") return OkStatus();

  // Only attempt to evaluate if the rank of the input to the Shape node is
  // known.
  if (!shape_c->RankKnown(shape_c->input(0))) return OkStatus();

  // Only attempt to evaluate if begin/end/strides values of the StridedSlice
  // node are all scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = stride_c->input(i);
    if (stride_c->Value(stride_c->Dim(input_shape, 0)) != 1) {
      return OkStatus();
    }
  }

  // Only attempt to evaluate cases with non-complex masks.
  int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(stride_c->GetAttr("begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("end_mask", &end_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(stride_c->GetAttr("shrink_axis_mask", &shrink_axis_mask));

  // Case where user has sliced a single element of a collection. E.g.
  // collection[i].
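  // With all other masks zero, shrink_axis_mask == 1 means the single sliced
  // dimension is collapsed away, i.e. shape[i] (a scalar) rather than
  // shape[i:i+1] (a length-1 vector).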
  bool accesses_single_element = begin_mask == 0 && end_mask == 0 &&
                                 ellipsis_mask == 0 && new_axis_mask == 0 &&
                                 shrink_axis_mask == 1;

  if (!accesses_single_element) return OkStatus();

  // Calculate the output tensor from the Shape node.
  Tensor shape_output;
  TF_RETURN_IF_ERROR(TryToInferTensorOutputFromShapeNode(
      *shape_node, shape_c, &shape_output, success, /*allow_partial=*/true));
  if (!*success) return OkStatus();

  // Discard the output tensor computed above if the StridedSlice points to an
  // unknown dimension.
  int64 begin_value = 0;
  bool evaluated = false;
  *success = false;
  TF_RETURN_IF_ERROR(EvaluateConstantIntFromScalarEdge(
      *stride_node, 1, &begin_value, &evaluated));

  if (evaluated && node.output_type(0) == shape_output.dtype()) {
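    // A negative begin index counts from the end; e.g. tf.shape(x)[-1] on a
    // rank-2 input refers to dimension 1, so wrap it by adding the rank.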
    begin_value = begin_value < 0
                      ? begin_value + shape_c->Rank(shape_c->input(0))
                      : begin_value;
    Tensor t(node.output_type(0), TensorShape({}));
    if (shape_output.dtype() == DT_INT32 &&
        shape_output.flat<int>()(begin_value) != -1) {
      t.flat<int32>()(0) = shape_output.flat<int>()(begin_value);
      *output = t;
      *success = true;
    } else if (shape_output.dtype() == DT_INT64 &&
               shape_output.flat<int64_t>()(begin_value) != -1) {
      t.flat<int64_t>()(0) = shape_output.flat<int64_t>()(begin_value);
      *output = t;
      *success = true;
    }
  }

  return OkStatus();
}

// Tries to infer tensor output based on the input shapes of the node. In some
// cases, the shapes of the inputs are sufficient for inferring the contents of
// the output tensor. For example, a Shape op with fully defined input shapes
// can have its output tensor inferred.
Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
                                             const ShapeRefiner& refiner,
                                             Tensor* output, bool* success) {
  *success = false;
  const Node* node = edge.src();
  InferenceContext* c = refiner.GetContext(node);
  if (c == nullptr) {
    // An input without context is a soft failure; we sometimes need to break
    // control flow loops by running shape inference on a node without first
    // adding its input.
    return OkStatus();
  }

  if (node->type_string() == "StridedSlice") {
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromStridedSliceNode(
        *node, refiner, output, success));
  } else if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
    TF_RETURN_IF_ERROR(
        TryToInferTensorOutputFromShapeNode(*node, c, output, success));
  } else if (node->type_string() == "Rank") {
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
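    // Size is the total number of elements, i.e. the product of all
    // dimensions; e.g. a fully defined input shape of (2, 3, 4) gives 24.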
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64_t>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return OkStatus();
}

// Returns true if 'node' has a registered CPU kernel.
bool HasCpuKernel(const Node& node) {
  return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
                       /*kernel_class_name=*/nullptr)
      .ok();
}

Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) {
  DCHECK(node->IsArg());
  TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index));
  if (*index < 0 || num_function_inputs <= *index) {
    return errors::Internal(
        "Function instantiation included invalid input index: ", *index,
        " not in [0, ", num_function_inputs, ").");
  }
  return OkStatus();
}

// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts it into 'out_graph'. If the subgraph is statically computable,
// 'is_constant_graph' will be set to true.
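// 'const_inputs' is filled with the names and values of tensors that are
// already known (from 'cached_values', from shape-based inference, or from the
// 'outer_context' for Arg nodes); these are fed to the extracted subgraph
// instead of being recomputed.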
Status ExtractConstantSubgraph(
    const Node& target_node, const ShapeRefiner& refiner,
    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
    bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs,
    InferenceContext* outer_context) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;
  if (target_node.op_def().is_stateful()) {
    return OkStatus();
  }

  if (IsMerge(&target_node)) {
    return OkStatus();
  }

  if (target_node.type_string() == "PlaceholderWithDefault") {
    return OkStatus();
  }

  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel.
  if (!HasCpuKernel(target_node)) {
    return OkStatus();
  }

  // TODO(skyewm): should more of the filtering applied to input nodes below be
  // applied to target_node here?

  // Identify the possibly constant subgraph by recursively iterating backwards
  // through the inputs to 'target_node' until we either 1) find an already
  // existing input to our subgraph 'const_inputs', 2) discover our graph is
  // not constant, or 3) hit a root node.

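  // For each original node, tracks its copy in 'out_graph' and whether its
  // own input edges have already been queued for traversal.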
  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(&target_node);
  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
  old_to_new_and_recursed[&target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node.in_edges()) {
    // TODO(skyewm): control edges will be meaningful if/when we handle control
    // flow (e.g. constants in cond branches are triggered via control edges).
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant unless it is
    // an Arg node which is handled later on.
    if (!current_node->IsArg() && current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in. In addition, control flow constructs may depend on control
    // edges which aren't handled by this method. Don't constant fold through
    // merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return OkStatus();
    }

    if (!HasCpuKernel(*current_node)) {
      *is_constant_graph = false;
      return OkStatus();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant or an Arg node whose value is available
    // in the `outer_context`.
    if (current_node->num_inputs() == 0) {
      if (outer_context && current_node->IsArg()) {
        const string& tensor_name =
            strings::StrCat(current_node->name(), ":", 0);
        // If we do not already have a constant Tensor for this Arg try to
        // fetch it from the outer context.
        if (const_inputs_added.count(tensor_name) == 0) {
          int index;
          TF_RETURN_IF_ERROR(GetArgNodeIndex(
              current_node, outer_context->num_inputs(), &index));
          const Tensor* const_tensor = outer_context->input_tensor(index);
          if (const_tensor) {
            const_inputs->emplace_back(tensor_name, *const_tensor);
            const_inputs_added.insert(tensor_name);
          } else {
            // Request a constant value for this Arg. If that is statically
            // computable, shape refiner will re-run the shape inference for
            // this function with this tensor's value.
            outer_context->request_input_tensor(index);
            *is_constant_graph = false;
            return OkStatus();
          }
        }
      } else if (!current_node->IsConstant()) {
        // Generator node is not a constant, so subgraph is not
        // constant.
        *is_constant_graph = false;
        return OkStatus();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed =
        &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination "
            "node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

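    // Tensors are identified by "<node name>:<output index>", the same
    // convention used for feed and fetch names.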
    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a Shape op
    // with fully defined input shapes can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        *current_edge, refiner, &tensor_inferred,
        &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already,
    // then add it to the list of inputs to feed and do not recurse further.
    if (cached_values != nullptr) {
      auto it = cached_values->find(output_tensor_name);
      if (it != cached_values->end() &&
          const_inputs_added.count(output_tensor_name) == 0) {
        const_inputs->emplace_back(output_tensor_name, it->second);
        const_inputs_added.insert(output_tensor_name);
        continue;
      }
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }
  return OkStatus();
}

}  // namespace

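// Tries to evaluate 'tensor' without running the full graph: first via the
// fast paths below (constant node, shape-slice inference, Arg lookup in
// 'outer_context'), and otherwise by extracting the constant subgraph that
// feeds 'tensor' and running it with a GraphRunner.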
Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
                              const OpRegistryInterface& ops,
                              int32 graph_def_version, bool* evaluated,
                              Tensor* result, GraphRunner* graph_runner,
                              std::unordered_map<string, Tensor>* cached_values,
                              int64 max_cached_value_size,
                              bool disable_constant_propagation,
                              InferenceContext* outer_context) {
  *evaluated = false;
  const Node* src = tensor.node;

  // Simple case: the source node is a constant.
  TF_RETURN_IF_ERROR(EvaluateConstantNode(*src, result, evaluated));
  if (*evaluated) return OkStatus();

  // Shape slice: the source node slices a single value out of a shape.
  // This is needed when the StridedSlice is itself the node being evaluated,
  // as in a simple expression such as tf.shape([-1, 10])[-1]: the
  // ExtractConstantSubgraph call below only looks at the input srcs of the
  // various edges, so there is never a chance to evaluate the StridedSlice
  // node, since it is never an input src.
  if (src->type_string() == "StridedSlice") {
    Tensor slice_output;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromStridedSliceNode(
        *src, refiner, &slice_output, evaluated));
    if (*evaluated) {
      *result = slice_output;
      return OkStatus();
    }
  }

  // If the source node is an Arg, return its value if it is available in the
  // outer context.
  if (src->IsArg() && outer_context) {
    int index;
    TF_RETURN_IF_ERROR(
        GetArgNodeIndex(src, outer_context->num_inputs(), &index));
    const Tensor* const_tensor = outer_context->input_tensor(index);
    if (const_tensor) {
      *evaluated = true;
      *result = *const_tensor;
    } else {
      outer_context->request_input_tensor(index);
    }
    return OkStatus();
  }

  if (disable_constant_propagation) {
    return OkStatus();
  }

  bool is_constant_graph = false;
  Graph subgraph(&ops);
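  // Run the subgraph with the same producer version as the original graph;
  // op behavior and validity checks can depend on graph_def_version.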
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version);
  subgraph.set_versions(versions);

  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
                                             &subgraph, &is_constant_graph,
                                             &const_inputs, outer_context));
  if (!is_constant_graph) {
    return OkStatus();
  }
  const string output_tensor_name =
      strings::StrCat(src->name(), ":", tensor.index);
  std::vector<Tensor> outputs;

  std::unique_ptr<GraphRunner> graph_runner_storage;
  if (graph_runner == nullptr) {
    // TODO(skyewm): Convert to std::make_unique when available.
    graph_runner_storage.reset(new GraphRunner(Env::Default()));
    graph_runner = graph_runner_storage.get();
  }

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If not all kernels in the constant graph are registered
  // in the process, GraphRunner::Run may fail, in which case
  // we cannot propagate constants, so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph. As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (cached_values != nullptr &&
        outputs[0].TotalBytes() <= max_cached_value_size) {
      (*cached_values)[output_tensor_name] = outputs[0];
    }
  }
  return OkStatus();
}

}  // namespace tensorflow