1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/shape_refiner.h"
16
17#include <deque>
18#include <memory>
19#include <unordered_set>
20#include <vector>
21
22#include "tensorflow/core/common_runtime/eval_const_tensor.h"
23#include "tensorflow/core/common_runtime/function_utils.h"
24#include "tensorflow/core/common_runtime/graph_constructor.h"
25#include "tensorflow/core/framework/bounds_check.h"
26#include "tensorflow/core/framework/common_shape_fns.h"
27#include "tensorflow/core/framework/node_def.pb.h"
28#include "tensorflow/core/framework/shape_inference.h"
29#include "tensorflow/core/framework/tensor.h"
30#include "tensorflow/core/framework/tensor.pb.h"
31#include "tensorflow/core/framework/versions.pb.h"
32#include "tensorflow/core/graph/algorithm.h"
33#include "tensorflow/core/lib/core/errors.h"
34
35namespace tensorflow {
36
37using shape_inference::DimensionHandle;
38using shape_inference::InferenceContext;
39using shape_inference::ShapeAndType;
40using shape_inference::ShapeHandle;
41
42ShapeRefiner::ShapeRefiner(int graph_def_version,
43 const OpRegistryInterface* ops)
44 : graph_def_version_(graph_def_version),
45 ops_registry_(ops),
46 graph_runner_(Env::Default()) {}
47
48ShapeRefiner::ShapeRefiner(const VersionDef& versions,
49 const OpRegistryInterface* ops)
50 : ShapeRefiner(versions.producer(), ops) {}
51
// Releases the cached constant tensors before any member is destroyed.
ShapeRefiner::~ShapeRefiner() {
  // The lifetime of the tensors are bound to the GraphRunner, so the tensors
  // should be deleted before it. Clearing in the destructor body (which runs
  // before member destructors) guarantees this regardless of the order in
  // which const_tensor_map_ and graph_runner_ are declared.
  const_tensor_map_.clear();
}
57
namespace {

// Op type names of the placeholder nodes that represent a function's return
// values and arguments inside an instantiated function body.
constexpr char kRetvalOp[] = "_Retval";
constexpr char kArgOp[] = "_Arg";

}  // namespace
64
65// Runs shape inference for the given node using the given ShapeRefiner.
66// The node must be a sub-node of a function node and the outer_context is
67// the inference context of that function node in the outer graph.
68Status ShapeRefiner::InferShapesForFunctionSubNode(
69 const Node* node, InferenceContext* outer_context) {
70 TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
71 InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));
72
73 if (StringPiece(node->type_string()) == kArgOp) {
74 // Handle special node: function input.
75 // Shapes for these nodes are provided in the outer inference
76 // context.
77
78 int index;
79 TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));
80
81 if (index < 0 || outer_context->num_inputs() <= index) {
82 return errors::Internal(
83 "Function instantiation included invalid input index: ", index,
84 " not in [0, ", outer_context->num_inputs(), ").");
85 }
86
87 // TODO(b/134547156): TEMPORARY WORKAROUND. If input shape handle is not set
88 // in outer context, set _Arg node output shape to unknown.
89 if (outer_context->input(index).SameHandle(ShapeHandle())) {
90 VLOG(1) << "Function instantiation has undefined input shape at "
91 << "index: " << index << " in the outer inference context.";
92 node_context->set_output(0, node_context->UnknownShape());
93 } else {
94 node_context->set_output(0, outer_context->input(index));
95 }
96
97 auto* resource = outer_context->input_handle_shapes_and_types(index);
98 if (resource) {
99 node_context->set_output_handle_shapes_and_types(0, *resource);
100 }
101 } else if (StringPiece(node->type_string()) == kRetvalOp) {
102 // Handle special node: function output.
103 // Shapes inferred for these nodes go into the outer inference
104 // context.
105
106 int index;
107 TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));
108
109 if (index < 0 || outer_context->num_outputs() <= index) {
110 return errors::Internal(
111 "Function instantiation included invalid output index: ", index,
112 " not in [0, ", outer_context->num_outputs(), ").");
113 }
114
115 // outer_context outlives node_context, therefore we need to create
116 // a new shape handle owned by outer_context instead.
117 ShapeHandle handle;
118 TensorShapeProto proto;
119 node_context->ShapeHandleToProto(node_context->input(0), &proto);
120 TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
121 outer_context->set_output(index, handle);
122
123 const std::vector<ShapeAndType>* resource =
124 node_context->input_handle_shapes_and_types(0);
125 if (resource) {
126 // `ShapesAndType`s contain `ShapeHandle`s. These `ShapeHandle`s point
127 // to `Shape`s that are owned by a different inference context too. We
128 // need to copy them to the outer context to prevent them from being
129 // destroyed before they are used.
130 std::vector<ShapeAndType> copied_shapes_and_types;
131 for (auto& shape_and_type : *resource) {
132 ShapeHandle handle;
133 TensorShapeProto proto;
134 node_context->ShapeHandleToProto(shape_and_type.shape, &proto);
135 TF_RETURN_IF_ERROR(
136 outer_context->MakeShapeFromShapeProto(proto, &handle));
137 copied_shapes_and_types.push_back(
138 ShapeAndType(handle, shape_and_type.dtype, shape_and_type.type));
139 }
140
141 outer_context->set_output_handle_shapes_and_types(
142 index, copied_shapes_and_types);
143 }
144 }
145
146 return OkStatus();
147}
148
149// TODO(cwhipkey): When an inference context inside function has
150// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
151// set when input(i) is an _Arg op, then this request should propagate to
152// context, and vice versa.
153//
154// NOTE: Recursive user-defined functions are not supported.
155// Maybe we won't support recursive functions at all in TF, because of
156// other maintainability issues.
157Status ShapeRefiner::InferShapesForFunction(const FunctionDef* function_def,
158 AttrSlice attributes,
159 InferenceContext* outer_context) {
160 const Graph* graph;
161 auto it = functions_.find(function_def);
162 if (it != functions_.end()) {
163 graph = it->second.get();
164 } else {
165 InstantiationResult result;
166 TF_RETURN_IF_ERROR(InstantiateFunction(
167 *function_def, attributes,
168 [this](const string& op, const OpDef** sig) {
169 return this->function_library_->LookUpOpDef(op, sig);
170 },
171 &result));
172
173 Graph* new_graph = new Graph(function_library_);
174 GraphConstructorOptions options;
175 options.allow_internal_ops = true;
176 TF_RETURN_IF_ERROR(
177 ConvertNodeDefsToGraph(options, result.nodes, new_graph));
178 functions_[function_def].reset(new_graph);
179 graph = new_graph;
180 }
181
182 std::unordered_set<const Node*> function_nodes;
183 Status inference_status = OkStatus();
184 {
185 auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
186 &inference_status](const Node* node) {
187 if (!inference_status.ok()) return;
188 inference_status = InferShapesForFunctionSubNode(node, outer_context);
189 function_nodes.insert(node);
190 };
191
192 // Calls inference lambda for each node after visiting all predecessors.
193 // Ensures that we are adding nodes to ShapeRefiner in the topological
194 // order.
195 ReverseDFS(*graph, {}, node_shape_inference_lambda);
196 }
197
198 // Delete the contexts created for the functions nodes to save memory.
199 for (const Node* node : function_nodes) {
200 node_to_context_.erase(node);
201 }
202
203 return inference_status;
204}
205
206Status ShapeRefiner::AddNode(const Node* node) {
207 return AddNodeInternal(node, /*outer_context=*/nullptr);
208}
209
210Status ShapeRefiner::AddNodeInternal(
211 const Node* node, shape_inference::InferenceContext* outer_context) {
212 // Create the inference context for this node with the existing input shapes.
213 std::unique_ptr<InferenceContext> ic(new InferenceContext(
214 graph_def_version_, node->def(), node->op_def(),
215 std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
216 TF_RETURN_IF_ERROR(ic->construction_status());
217
218 // For each 'input' of this node, fetch the corresponding shape
219 // from 'input's InferenceContext, and store into this node's
220 // InferenceContext.
221 for (const Edge* e : node->in_edges()) {
222 if (e->IsControlEdge()) continue;
223
224 if (e->dst_input() < 0) {
225 return tensorflow::errors::Internal(
226 "Index ", e->dst_input(), " is negative but not a control edge.");
227 }
228
229 const Node* input = e->src();
230 auto it = node_to_context_.find(input);
231 if (it == node_to_context_.end()) {
232 // v1 control flow adds loops to the graph; we have to break them
233 // somewhere, so we'll ignore this input and leave its shape undefined.
234 ic->SetInput(e->dst_input(), ic->UnknownShape());
235 continue;
236 }
237
238 InferenceContext* input_ic = it->second->get_context();
239 ic->SetInput(e->dst_input(), input_ic->output(e->src_output()));
240
241 const auto* in_v =
242 input_ic->output_handle_shapes_and_types(e->src_output());
243 if (in_v != nullptr) {
244 DataType input_type = e->src()->output_type(e->src_output());
245 DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
246 ic->set_input_handle_shapes_and_types(e->dst_input(),
247 std::vector<ShapeAndType>(*in_v));
248 }
249 }
250
251 // Get the shape function for this node
252 const OpRegistrationData* op_reg_data;
253 TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
254 if (op_reg_data->shape_inference_fn == nullptr &&
255 require_shape_inference_fns_) {
256 return errors::InvalidArgument(
257 "No shape inference function exists for op '", node->type_string(),
258 "', did you forget to define it?");
259 }
260
261 std::unique_ptr<ExtendedInferenceContext> ec(
262 new ExtendedInferenceContext(std::move(ic), node));
263
264 // Run the shape inference function, and return if there was an error.
265 TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));
266
267 // Store the resulting context object in the map.
268 node_to_context_[node].swap(ec);
269
270 return OkStatus();
271}
272
273Status ShapeRefiner::SetShape(const Node* node, int output_port,
274 ShapeHandle shape) {
275 auto c = GetContext(node);
276 if (c == nullptr) {
277 return errors::Internal("Could not find context for ", node->name());
278 }
279
280 if (output_port < 0 || output_port >= node->num_outputs()) {
281 return errors::InvalidArgument(
282 "output_port '", output_port, "' is out of range, ", "node '",
283 node->name(), "' has ", node->num_outputs(), " outputs");
284 }
285 // Note: it's possible, if the node's been updated, that the shape inference
286 // context doesn't have the right number of outputs.
287 if (node->num_outputs() > c->num_outputs()) {
288 TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
289 }
290
291 // Check compatibility, and merge the shapes.
292 ShapeHandle existing_shape = c->output(output_port);
293 TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
294 c->set_output(output_port, shape);
295
296 // TODO(vrv): Do we need to propagate the new shape through all
297 // consumers that change their outputs? At the moment, python
298 // does not do this, but this seems like a nice feature.
299
300 // TODO(vrv): We might need to keep track of the fact that the
301 // existing shape is invalidated, in case we need to propagate
302 // this information to remote workers.
303 return OkStatus();
304}
305
306Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
307 auto it = node_to_context_.find(node);
308 if (it == node_to_context_.end()) {
309 *refined = true;
310 return AddNode(node);
311 }
312 ExtendedInferenceContext* node_ext_context = it->second.get();
313 InferenceContext* node_context = node_ext_context->get_context();
314
315 // Give up if the context wasn't successfully built by the AddNode() method.
316 TF_RETURN_IF_ERROR(node_context->construction_status());
317
318 // Check if the shapes of the nodes in the fan-in of this node have changed,
319 // and if they have update the node input shapes.
320 for (const Edge* e : node->in_edges()) {
321 if (e->IsControlEdge()) continue;
322
323 int dst_input = e->dst_input();
324 int src_output = e->src_output();
325
326 Node* input = e->src();
327 auto iter = node_to_context_.find(input);
328 if (iter == node_to_context_.end()) {
329 return errors::FailedPrecondition(
330 "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
331 "' was not previously added to ShapeRefiner.");
332 }
333
334 InferenceContext* c = iter->second->get_context();
335 DCHECK_GE(dst_input, 0);
336 ShapeHandle existing_input = node_context->input(dst_input);
337 if (!relax) {
338 if (node_context->MergeInput(dst_input, c->output(src_output))) {
339 if (!SameDefinedShape(node_context, node_context->input(dst_input),
340 existing_input)) {
341 *refined = true;
342 }
343 }
344 } else {
345 if (node_context->RelaxInput(dst_input, c->output(src_output))) {
346 if (!SameDefinedShape(node_context, node_context->input(dst_input),
347 existing_input)) {
348 *refined = true;
349 }
350 }
351 }
352 if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
353 // The input value may have changed. Since we have no way to know if
354 // that's indeed the case, err on the safe side.
355 *refined = true;
356 }
357
358 // Also propagate handle shape and dtype of edges which are carrying
359 // resource handles.
360 if (e->src()->output_type(src_output) == DT_RESOURCE) {
361 auto* outputs = c->output_handle_shapes_and_types(src_output);
362 if (!outputs) continue;
363
364 if (!relax &&
365 node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
366 *refined = true;
367 } else if (relax) {
368 std::vector<ShapeAndType> existing_inputs;
369 const std::vector<ShapeAndType>* inputs =
370 node_context->input_handle_shapes_and_types(dst_input);
371 if (inputs) {
372 existing_inputs = *inputs;
373 }
374 if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
375 *outputs)) {
376 if (IsUpdatedShapesOrTypes(
377 node_context, existing_inputs,
378 *node_context->input_handle_shapes_and_types(dst_input))) {
379 *refined = true;
380 }
381 }
382 }
383 }
384 }
385
386 if (!*refined) {
387 // No input shape has changed, we're done
388 return OkStatus();
389 }
390
391 // Get and run the shape function for this node to update the shapes of the
392 // outputs.
393 const OpRegistrationData* op_reg_data;
394 TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
395 if (op_reg_data->shape_inference_fn == nullptr &&
396 require_shape_inference_fns_) {
397 return errors::InvalidArgument(
398 "No shape inference function exists for op '", node->type_string(),
399 "', did you forget to define it?");
400 }
401
402 if (!op_reg_data->shape_inference_fn) {
403 // There is nothing more we can infer
404 return OkStatus();
405 }
406
407 return RunShapeFn(node, op_reg_data, node_ext_context);
408}
409
410Status ShapeRefiner::EvaluateConstantTensorForEdge(
411 const Node* node, int dst_idx, bool* evaluated, Tensor* result,
412 InferenceContext* outer_context) {
413 *evaluated = false;
414 const Edge* input_edge;
415 TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
416 OutputTensor tensor(input_edge->src(), input_edge->src_output());
417 return EvaluateConstantTensor(
418 tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
419 &graph_runner_, &const_tensor_map_, kMaxTensorSize,
420 disable_constant_propagation_, outer_context);
421}
422
423Status ShapeRefiner::EvaluateConstantIntScalarEdge(
424 const Node* node, int dst_idx, bool* evaluated, int64_t* result,
425 shape_inference::InferenceContext* outer_context) {
426 Tensor scalar;
427 TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
428 &scalar, outer_context));
429 if (*evaluated) {
430 if (scalar.NumElements() != 1) {
431 return errors::InvalidArgument(
432 "EvaluateConstantIntScalarEdge called on non-scalar edge: ",
433 scalar.NumElements());
434 }
435 if (scalar.dtype() == DT_INT32) {
436 *result = scalar.scalar<int32>()();
437 } else {
438 if (scalar.dtype() != DT_INT64) {
439 return errors::InvalidArgument(
440 "EvaluateConstantIntScalarEdge called on non-integer edge: ",
441 scalar.dtype());
442 }
443 *result = scalar.scalar<int64_t>()();
444 }
445 }
446 return OkStatus();
447}
448
449Status ShapeRefiner::ConstantPartialShape(
450 InferenceContext* target_context, const Node* node, int dst_idx,
451 ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
452 const Edge* input_edge;
453 TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
454
455 InferenceContext* src_context = GetContext(input_edge->src());
456 if (src_context == nullptr) return errors::Internal("Missing src context");
457 ShapeHandle src_shape = src_context->output(input_edge->src_output());
458
459 // All shapes are expected to be 1D integer tensors with the exception of the
460 // sentinel that represents an unknown shape (scalar/rank 0 tensor with -1 as
461 // value). Handle the special case first before considering the more general
462 // rank 1 case.
463
464 if (src_context->Value(src_context->Rank(src_shape)) == 0) {
465 Tensor t;
466 bool evaluated = false;
467 TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
468 &t, outer_context));
469 if (!evaluated) {
470 return errors::InvalidArgument(
471 "Received a shape scalar with unknown static value. A static value "
472 "of '-1' is required to represent an unknown shape.");
473 }
474 if (t.dims() == 0) {
475 if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
476 *result = target_context->UnknownShape();
477 return OkStatus();
478 } else if (t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1) {
479 *result = target_context->UnknownShape();
480 return OkStatus();
481 }
482 }
483 return errors::InvalidArgument(
484 "Received an invalid shape scalar with a static value that is not "
485 "'-1': ",
486 t.DebugString());
487 }
488
489 TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));
490
491 const string& src_op = input_edge->src()->type_string();
492 if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
493 // Source tensor is a vector of length 0, so the shape it
494 // represents is as scalar.
495 *result = target_context->Scalar();
496 } else if (src_op == "Cast") {
497 // First try to evaluate the current tensor, as it might be a valid cast of
498 // a float.
499 Tensor t;
500 bool evaluated = false;
501 if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
502 outer_context)
503 .ok()) {
504 if (evaluated &&
505 target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
506 return OkStatus();
507 }
508 }
509
510 // Then try to infer partial shape from the input to the cast tensor.
511 ShapeHandle pre_cast_shape;
512 if (!ConstantPartialShape(target_context, input_edge->src(), 0,
513 &pre_cast_shape, outer_context)
514 .ok()) {
515 TF_RETURN_IF_ERROR(
516 target_context->MakeShapeFromTensor(nullptr, src_shape, result));
517 }
518 if (!target_context->RankKnown(pre_cast_shape)) {
519 // Failed to evaluate. Treat the output as completely unknown.
520 *result = target_context->UnknownShape();
521 return OkStatus();
522 }
523 auto* dest_type = input_edge->src()->attrs().Find("DstT");
524 if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
525 (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
526 // Casting to a weird type. Do not attempt to infer across it.
527 *result = target_context->MakeShape(std::vector<DimensionHandle>(
528 target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
529 return OkStatus();
530 }
531 *result = pre_cast_shape;
532 } else if (src_op == "Shape") {
533 *result = src_context->input(0);
534 } else if (src_op == "ShapeN") {
535 *result = src_context->input(input_edge->src_output());
536 } else if (src_op == "Pack") {
537 std::vector<DimensionHandle> dims;
538 // Pack is concatenating its input scalars to form the shape tensor vector.
539 for (int i = 0; i < src_context->num_inputs(); ++i) {
540 int64_t size;
541 bool evaluated;
542 TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
543 input_edge->src(), i, &evaluated, &size, outer_context));
544 if (evaluated) {
545 dims.push_back(size < 0 ? target_context->UnknownDim()
546 : target_context->MakeDim(size));
547 } else {
548 dims.push_back(target_context->UnknownDim());
549 }
550 }
551 *result = target_context->MakeShape(dims);
552 } else if (src_op == "Concat" || src_op == "ConcatV2") {
553 *result = target_context->Scalar();
554 // For Concat, input 0 is concat dim; for V2 it is the last input.
555 const int concat_dim =
556 src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
557 // Concat is concatenating its input shape vectors.
558 for (int i = 0; i < src_context->num_inputs(); ++i) {
559 // Concat dim is ignored (and will always be a scalar).
560 if (i == concat_dim) continue;
561 ShapeHandle sub_result;
562 TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
563 i, &sub_result, outer_context));
564 if (!target_context->RankKnown(sub_result)) {
565 // Failed to evaluate. Treat the output as completely unknown.
566 // TODO(cwhipkey): we could rely on all inputs being the same rank, so
567 // figure that rank out and append the right number of unknown dims.
568 *result = target_context->UnknownShape();
569 return OkStatus();
570 }
571 TF_RETURN_IF_ERROR(
572 target_context->Concatenate(*result, sub_result, result));
573 }
574 } else if (src_op == "StridedSlice") {
575 TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
576 result, outer_context));
577 } else if (src_op == "VariableShape") {
578 auto* handle_data = src_context->input_handle_shapes_and_types(0);
579 if (handle_data != nullptr && !handle_data->empty()) {
580 *result = handle_data->at(0).shape;
581 } else {
582 *result = target_context->UnknownShape();
583 }
584 } else {
585 Tensor t;
586 bool evaluated = false;
587 TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
588 &t, outer_context));
589 TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
590 evaluated ? &t : nullptr, src_shape, result));
591 }
592 return OkStatus();
593}
594
595Status ShapeRefiner::PartialStridedSliceShape(
596 Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
597 shape_inference::InferenceContext* outer_context) {
598 // Only attempt to evaluate if begin/end/strides all are scalars.
599 for (int i = 1; i <= 3; ++i) {
600 ShapeHandle input_shape = ctx->input(i);
601 if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
602 *result = ctx->UnknownShape();
603 return OkStatus();
604 }
605 }
606
607 int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
608 TF_RETURN_IF_ERROR(
609 GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
610 TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
611 TF_RETURN_IF_ERROR(
612 GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
613 TF_RETURN_IF_ERROR(
614 GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
615 TF_RETURN_IF_ERROR(
616 GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));
617
618 // Only attempt to evaluate if there are no special masks set (note that we
619 // can handle begin/end_mask == 1).
620 if (!(begin_mask == 0 || begin_mask == 1) ||
621 !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
622 new_axis_mask != 0 || shrink_axis_mask != 0) {
623 *result = ctx->UnknownShape();
624 return OkStatus();
625 }
626
627 bool evaluated;
628 int64_t begin;
629 if (begin_mask == 1) {
630 begin = 0;
631 } else {
632 TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
633 &begin, outer_context));
634 if (!evaluated) {
635 *result = ctx->UnknownShape();
636 return OkStatus();
637 }
638 }
639
640 int64_t end;
641 if (end_mask == 1) {
642 end = std::numeric_limits<int64_t>::max();
643 } else {
644 TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
645 &end, outer_context));
646 if (!evaluated) {
647 *result = ctx->UnknownShape();
648 return OkStatus();
649 }
650 }
651
652 int64_t stride;
653 TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
654 &stride, outer_context));
655 if (!evaluated) {
656 *result = ctx->UnknownShape();
657 return OkStatus();
658 }
659
660 // Apply stride to input interpreted as a partial shape.
661 ShapeHandle input;
662 TF_RETURN_IF_ERROR(
663 ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
664 TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
665 return OkStatus();
666}
667
668Status ShapeRefiner::RunShapeFn(const Node* node,
669 const OpRegistrationData* op_reg_data,
670 ExtendedInferenceContext* ec,
671 InferenceContext* outer_context) {
672 // This will be filled in with real data in a second pass.
673 std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
674 std::vector<Tensor> real_tensors(node->num_inputs());
675 std::vector<bool> attempted_materialization(node->num_inputs());
676 std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
677 std::vector<ShapeHandle> input_tensors_as_shapes;
678
679 auto* c = ec->get_context();
680
681 c->set_input_tensors(input_tensors);
682 c->set_input_tensors_as_shapes(input_tensors_as_shapes);
683
684 // Run the shape inference function, and return if there was an error.
685 // Capture as lambda, because we might need to re-run inference later on.
686 auto run_inference_lambda = [&]() {
687 if (function_library_ && IsFunctionCall(*function_library_, *node)) {
688 bool disable_shape_inference;
689 if (!GetNodeAttr(AttrSlice(node->def()), "_disable_call_shape_inference",
690 &disable_shape_inference)
691 .ok() ||
692 !disable_shape_inference) {
693 // Special inference logic for user-defined functions.
694 NameAttrList function;
695 TF_RETURN_IF_ERROR(
696 NameAndAttrsFromFunctionCall(node->def(), &function));
697 const FunctionDef* function_def =
698 function_library_->Find(function.name());
699 if (function_def != nullptr) {
700 // The constant Tensor map we have for the outside context is not
701 // valid inside the function. We need to push a new clean map while
702 // performing inference on the function body.
703 auto const_tensor_map_copy = const_tensor_map_;
704 const_tensor_map_.clear();
705 Status function_inference_status = InferShapesForFunction(
706 function_def, AttrSlice(&function.attr()), c);
707 const_tensor_map_ = const_tensor_map_copy;
708 return function_inference_status;
709 }
710 }
711 }
712
713 if (op_reg_data->shape_inference_fn) {
714 TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
715 } else {
716 TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
717 }
718 return OkStatus();
719 };
720 TF_RETURN_IF_ERROR(run_inference_lambda());
721
722 // We must run the shape function repeatedly, in case users write
723 // shape functions where they only conditionally call input_tensor()
724 // based on the values of another input tensor.
725 bool rerun_shape_fn;
726 do {
727 // If the result of running shape inference would have benefitted
728 // from knowing the values of input tensors, try to materialize
729 // the results of those tensors, and then run the shape inference
730 // function again using those known tensors.
731 rerun_shape_fn = false;
732
733 // NOTE: It is possible to batch the extraction and
734 // materialization of inputs, instead of materializing one input
735 // at a time like we do below. If input-at-a-time computation
736 // becomes a bottleneck, we could separate ExtractConstantSubgraph
737 // into two functions: one that returns true if an input is
738 // derivable from constants, and another function that extracts
739 // the subgraph for multiple target nodes and executes the whole
740 // subgraph once.
741
742 for (int i = 0; i < c->num_inputs(); ++i) {
743 if (!c->requested_input_tensor(i)) {
744 continue;
745 }
746 // Check if we have not already filled in the requested input,
747 // and if not, try to materialize the tensors.
748 if (!attempted_materialization[i]) {
749 attempted_materialization[i] = true;
750
751 Tensor result;
752 bool evaluated = false;
753 TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
754 node, i, &evaluated, &result, outer_context));
755 if (evaluated) {
756 real_tensors[i] = result;
757 input_tensors[i] = &real_tensors[i];
758 // We have more concrete information about a shape,
759 // so re-run shape inference.
760 rerun_shape_fn = true;
761 }
762 }
763 if (c->requested_input_tensor_as_partial_shape(i) &&
764 !attempted_tensor_as_shape_conversion[i]) {
765 attempted_tensor_as_shape_conversion[i] = true;
766 if (i >= input_tensors_as_shapes.size()) {
767 input_tensors_as_shapes.resize(i + 1);
768 }
769 ShapeHandle s;
770 TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context));
771 input_tensors_as_shapes[i] = s;
772 rerun_shape_fn = true;
773 }
774 }
775
776 if (rerun_shape_fn) {
777 // We have more information about the shapes on this pass,
778 // so re-run shape inference.
779 c->set_input_tensors(input_tensors);
780 c->set_input_tensors_as_shapes(input_tensors_as_shapes);
781 TF_RETURN_IF_ERROR(run_inference_lambda());
782 }
783 } while (rerun_shape_fn);
784
785 return OkStatus();
786}
787
788bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
789 ShapeHandle s1) {
790 if (s0.SameHandle(s1)) {
791 return true;
792 }
793 if (c->Rank(s0) != c->Rank(s1)) {
794 return false;
795 }
796 if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
797 return false;
798 }
799 for (int i = 0; i < c->Rank(s0); ++i) {
800 if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
801 int64_t val0 = c->Value(c->Dim(s0, i));
802 int64_t val1 = c->Value(c->Dim(s1, i));
803 if (val0 < 0 || val1 < 0 || val0 != val1) {
804 return false;
805 }
806 }
807 }
808
809 return true;
810}
811
812bool ShapeRefiner::IsUpdatedShapesOrTypes(
813 InferenceContext* c, const std::vector<ShapeAndType>& existing,
814 const std::vector<ShapeAndType>& updated) {
815 if (existing.size() != updated.size()) {
816 return true;
817 }
818 for (int i = 0; i < existing.size(); i++) {
819 if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
820 existing[i].dtype != updated[i].dtype) {
821 return true;
822 }
823 }
824 return false;
825}
826
827} // namespace tensorflow
828