/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <limits>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetimes of the tensors are bound to the GraphRunner, so the
  // tensors should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

}  // namespace

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status ShapeRefiner::InferShapesForFunctionSubNode(
    const Node* node, InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
  InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    // TODO(b/134547156): TEMPORARY WORKAROUND. If input shape handle is not
    // set in the outer context, set the _Arg node's output shape to unknown.
    if (outer_context->input(index).SameHandle(ShapeHandle())) {
      VLOG(1) << "Function instantiation has undefined input shape at "
              << "index: " << index << " in the outer inference context.";
      node_context->set_output(0, node_context->UnknownShape());
    } else {
      node_context->set_output(0, outer_context->input(index));
    }

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    const std::vector<ShapeAndType>* resource =
        node_context->input_handle_shapes_and_types(0);
    if (resource) {
      // `ShapeAndType`s contain `ShapeHandle`s. These `ShapeHandle`s point
      // to `Shape`s that are owned by a different inference context too. We
      // need to copy them to the outer context to prevent them from being
      // destroyed before they are used.
      std::vector<ShapeAndType> copied_shapes_and_types;
      for (auto& shape_and_type : *resource) {
        ShapeHandle handle;
        TensorShapeProto proto;
        node_context->ShapeHandleToProto(shape_and_type.shape, &proto);
        TF_RETURN_IF_ERROR(
            outer_context->MakeShapeFromShapeProto(proto, &handle));
        copied_shapes_and_types.push_back(
            ShapeAndType(handle, shape_and_type.dtype, shape_and_type.type));
      }

      outer_context->set_output_handle_shapes_and_types(
          index, copied_shapes_and_types);
    }
  }

  return OkStatus();
}
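
// For example (a hedged illustrative trace, not code in this file): if a
// function call node's input 0 has inferred shape [2,?], the body's _Arg
// node with index=0 gets output shape [2,?] copied from the outer context;
// and if the body's _Retval node with index=0 sees shape [2,?,3], that
// shape is serialized to a TensorShapeProto and re-created as a handle
// owned by the outer context before being set as the call node's output 0.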

// TODO(cwhipkey): When an inference context inside a function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to
// the outer context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(const FunctionDef* function_def,
                                            AttrSlice attributes,
                                            InferenceContext* outer_context) {
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, attributes,
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = OkStatus();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status = InferShapesForFunctionSubNode(node, outer_context);
      function_nodes.insert(node);
    };

    // Calls the inference lambda for each node after visiting all of its
    // predecessors. Ensures that we are adding nodes to the ShapeRefiner
    // in topological order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  // Delete the contexts created for the function nodes to save memory.
  for (const Node* node : function_nodes) {
    node_to_context_.erase(node);
  }

  return inference_status;
}

Status ShapeRefiner::AddNode(const Node* node) {
  return AddNodeInternal(node, /*outer_context=*/nullptr);
}
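
// A minimal usage sketch (hypothetical client code; assumes a `graph` whose
// nodes are visited with inputs before their consumers, e.g. via
// GetReversePostOrder from graph/algorithm.h):
//
//   ShapeRefiner refiner(graph.versions().producer(), OpRegistry::Global());
//   std::vector<Node*> order;
//   GetReversePostOrder(graph, &order);
//   for (const Node* n : order) {
//     TF_RETURN_IF_ERROR(refiner.AddNode(n));
//   }
//   shape_inference::InferenceContext* c = refiner.GetContext(output_node);
//   // c->output(0) now holds the inferred, possibly partial, output shape.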

Status ShapeRefiner::AddNodeInternal(
    const Node* node, shape_inference::InferenceContext* outer_context) {
  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> ic(new InferenceContext(
      graph_def_version_, node->def(), node->op_def(),
      std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
  TF_RETURN_IF_ERROR(ic->construction_status());

  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store into this node's
  // InferenceContext.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    if (e->dst_input() < 0) {
      return tensorflow::errors::Internal(
          "Index ", e->dst_input(), " is negative but not a control edge.");
    }

    const Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      // v1 control flow adds loops to the graph; we have to break them
      // somewhere, so we'll ignore this input and leave its shape undefined.
      ic->SetInput(e->dst_input(), ic->UnknownShape());
      continue;
    }

    InferenceContext* input_ic = it->second->get_context();
    ic->SetInput(e->dst_input(), input_ic->output(e->src_output()));

    const auto* in_v =
        input_ic->output_handle_shapes_and_types(e->src_output());
    if (in_v != nullptr) {
      DataType input_type = e->src()->output_type(e->src_output());
      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
      ic->set_input_handle_shapes_and_types(e->dst_input(),
                                            std::vector<ShapeAndType>(*in_v));
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(ic), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return OkStatus();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }
  // Note: it's possible, if the node's been updated, that the shape inference
  // context doesn't have the right number of outputs.
  if (node->num_outputs() > c->num_outputs()) {
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs? At the moment, Python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return OkStatus();
}
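
// For example (illustrative): merging an existing output shape [2,?] with a
// fed shape [?,3] via InferenceContext::Merge yields [2,3], while merging
// [2] with [3] fails, so an incompatible SetShape request surfaces as an
// error rather than silently overwriting the inferred shape.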

Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }
    if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
      // The input value may have changed. Since we have no way to know if
      // that's indeed the case, err on the safe side.
      *refined = true;
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done.
    return OkStatus();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return OkStatus();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(
    const Node* node, int dst_idx, bool* evaluated, Tensor* result,
    InferenceContext* outer_context) {
  *evaluated = false;
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
  OutputTensor tensor(input_edge->src(), input_edge->src_output());
  return EvaluateConstantTensor(
      tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
      &graph_runner_, &const_tensor_map_, kMaxTensorSize,
      disable_constant_propagation_, outer_context);
}

Status ShapeRefiner::EvaluateConstantIntScalarEdge(
    const Node* node, int dst_idx, bool* evaluated, int64_t* result,
    shape_inference::InferenceContext* outer_context) {
  Tensor scalar;
  TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
                                                   &scalar, outer_context));
  if (*evaluated) {
    if (scalar.NumElements() != 1) {
      return errors::InvalidArgument(
          "EvaluateConstantIntScalarEdge called on non-scalar edge: ",
          scalar.NumElements());
    }
    if (scalar.dtype() == DT_INT32) {
      *result = scalar.scalar<int32>()();
    } else {
      if (scalar.dtype() != DT_INT64) {
        return errors::InvalidArgument(
            "EvaluateConstantIntScalarEdge called on non-integer edge: ",
            scalar.dtype());
      }
      *result = scalar.scalar<int64_t>()();
    }
  }
  return OkStatus();
}

Status ShapeRefiner::ConstantPartialShape(
    InferenceContext* target_context, const Node* node, int dst_idx,
    ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());

  // All shapes are expected to be 1D integer tensors with the exception of the
  // sentinel that represents an unknown shape (scalar/rank 0 tensor with -1 as
  // value). Handle the special case first before considering the more general
  // rank 1 case.
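  // For example (illustrative): a rank-0 constant holding -1 is mapped to
  // UnknownShape() below, whereas the general rank-1 path maps a constant
  // vector such as [2, -1, 3] to the partial shape [2,?,3], with negative
  // entries becoming unknown dimensions.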

  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    if (!evaluated) {
      return errors::InvalidArgument(
          "Received a shape scalar with unknown static value. A static value "
          "of '-1' is required to represent an unknown shape.");
    }
    if (t.dims() == 0) {
      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
        *result = target_context->UnknownShape();
        return OkStatus();
      } else if (t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1) {
        *result = target_context->UnknownShape();
        return OkStatus();
      }
    }
    return errors::InvalidArgument(
        "Received an invalid shape scalar with a static value that is not "
        "'-1': ",
        t.DebugString());
  }

  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Cast") {
    // First try to evaluate the current tensor, as it might be a valid cast of
    // a float.
    Tensor t;
    bool evaluated = false;
    if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
                                      outer_context)
            .ok()) {
      if (evaluated &&
          target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
        return OkStatus();
      }
    }

    // Then try to infer partial shape from the input to the cast tensor.
    ShapeHandle pre_cast_shape;
    if (!ConstantPartialShape(target_context, input_edge->src(), 0,
                              &pre_cast_shape, outer_context)
             .ok()) {
      TF_RETURN_IF_ERROR(
          target_context->MakeShapeFromTensor(nullptr, src_shape, result));
    }
    if (!target_context->RankKnown(pre_cast_shape)) {
      // Failed to evaluate. Treat the output as completely unknown.
      *result = target_context->UnknownShape();
      return OkStatus();
    }
    auto* dest_type = input_edge->src()->attrs().Find("DstT");
    if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
        (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
      // Casting to a weird type. Do not attempt to infer across it.
      *result = target_context->MakeShape(std::vector<DimensionHandle>(
          target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
      return OkStatus();
    }
    *result = pre_cast_shape;
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      int64_t size;
      bool evaluated;
      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
          input_edge->src(), i, &evaluated, &size, outer_context));
      if (evaluated) {
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // Concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(
          target_context, input_edge->src(), i, &sub_result, outer_context));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return OkStatus();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else if (src_op == "StridedSlice") {
    TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
                                                result, outer_context));
  } else if (src_op == "VariableShape") {
    auto* handle_data = src_context->input_handle_shapes_and_types(0);
    if (handle_data != nullptr && !handle_data->empty()) {
      *result = handle_data->at(0).shape;
    } else {
      *result = target_context->UnknownShape();
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return OkStatus();
}
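
// For example (illustrative): if the shape tensor feeding `node` was built
// as Pack(a, 2) where the scalar `a` could not be evaluated statically, the
// Pack branch above yields the partial shape [?,2]; similarly, the Concat
// branch concatenates the partial shapes recovered for each non-axis input,
// giving up with UnknownShape() if any of their ranks is unknown.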

Status ShapeRefiner::PartialStridedSliceShape(
    Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
    shape_inference::InferenceContext* outer_context) {
  // Only attempt to evaluate if begin/end/strides all are scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = ctx->input(i);
    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
      *result = ctx->UnknownShape();
      return OkStatus();
    }
  }

  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));

  // Only attempt to evaluate if there are no special masks set (note that we
  // can handle begin/end_mask == 1).
  if (!(begin_mask == 0 || begin_mask == 1) ||
      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
      new_axis_mask != 0 || shrink_axis_mask != 0) {
    *result = ctx->UnknownShape();
    return OkStatus();
  }

  bool evaluated;
  int64_t begin;
  if (begin_mask == 1) {
    begin = 0;
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
                                                     &begin, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return OkStatus();
    }
  }

  int64_t end;
  if (end_mask == 1) {
    end = std::numeric_limits<int64_t>::max();
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
                                                     &end, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return OkStatus();
    }
  }

  int64_t stride;
  TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
                                                   &stride, outer_context));
  if (!evaluated) {
    *result = ctx->UnknownShape();
    return OkStatus();
  }

  // Apply stride to input interpreted as a partial shape.
  ShapeHandle input;
  TF_RETURN_IF_ERROR(
      ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
  return OkStatus();
}
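
// For example (illustrative): a slice such as `s[1:]` lowered to
// StridedSlice(s, begin=1, end=0, strides=1) with end_mask == 1 resolves
// begin to 1 and end to the int64 maximum above, so the result is the
// partial shape of `s` with its first entry dropped.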

Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec,
                                InferenceContext* outer_context) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && IsFunctionCall(*function_library_, *node)) {
      bool disable_shape_inference;
      if (!GetNodeAttr(AttrSlice(node->def()), "_disable_call_shape_inference",
                       &disable_shape_inference)
               .ok() ||
          !disable_shape_inference) {
        // Special inference logic for user-defined functions.
        NameAttrList function;
        TF_RETURN_IF_ERROR(
            NameAndAttrsFromFunctionCall(node->def(), &function));
        const FunctionDef* function_def =
            function_library_->Find(function.name());
        if (function_def != nullptr) {
          // The constant Tensor map we have for the outside context is not
          // valid inside the function. We need to push a new clean map while
          // performing inference on the function body.
          auto const_tensor_map_copy = const_tensor_map_;
          const_tensor_map_.clear();
          Status function_inference_status = InferShapesForFunction(
              function_def, AttrSlice(&function.attr()), c);
          const_tensor_map_ = const_tensor_map_copy;
          return function_inference_status;
        }
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return OkStatus();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
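  // For example (an illustrative scenario): a shape function may call
  // c->input_tensor(1) only when input 0 evaluates to a particular value.
  // The first pass then records the request for input 1, the loop below
  // materializes that tensor, and inference is re-run with it available.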
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below. If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
            node, i, &evaluated, &result, outer_context));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(
            ConstantPartialShape(c, node, i, &s, outer_context));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return OkStatus();
}

bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64_t val0 = c->Value(c->Dim(s0, i));
      int64_t val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}
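
// For example (illustrative): [2,3] and [2,3] compare as the same defined
// shape even across different handles, while [2,?] and [2,?] with distinct
// unknown-dimension handles do not, since an unknown dimension is only
// considered "the same" as itself.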

bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow