1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ |
16 | #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ |
17 | |
18 | #include <vector> |
19 | |
20 | #include "absl/memory/memory.h" |
21 | #include "tensorflow/core/framework/full_type.pb.h" |
22 | #include "tensorflow/core/framework/node_def_util.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | #include "tensorflow/core/platform/macros.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | namespace grappler { |
30 | class GraphProperties; |
31 | class SymbolicShapeManager; |
32 | } // namespace grappler |
33 | |
34 | namespace shape_inference { |
35 | |
36 | struct DimensionOrConstant; |
37 | class InferenceContext; |
38 | |
39 | // This header contains the InferenceContext that is used to infer the shape of |
40 | // the results of an operation or flag an operation with invalid inputs (e.g., |
41 | // mismatched shapes for elementwise operation) by ShapeRefiner. The shape of an |
42 | // operation is computed using the OpShapeInferenceFn set via SetShapeFn in op |
43 | // registration. The OpShapeInferenceFn uses a per op InferenceContext populated |
44 | // with input shapes to compute resultant shape (including resource shapes). |
45 | // |
// The shapes created in the InferenceContext are bound to the lifetime of the
// InferenceContext in which they were created. E.g., in
48 | // |
49 | // ```c++ |
50 | // InferenceContext c; |
51 | // // Below a ShapeHandle is returned by MakeShape, while UnknownDim returns a |
52 | // // DimensionHandle. |
53 | // ShapeHandle in0 = c.MakeShape({10, c.UnknownDim()}); |
54 | // ``` |
55 | // |
56 | // the ShapeHandle `in0` (and the nested unknown dim inside) is only valid while |
57 | // `c` is in scope, as ShapeHandle and DimensionHandle are effectively |
58 | // wrappers around pointers stored inside the context with the lifetime of the |
59 | // value pointed to managed by the context. The result from one operation's |
60 | // inference context will be passed as input to the inference of consumer |
61 | // operations. Hence it is possible for ShapeHandles produced by inference on a |
62 | // node to consist of ShapeHandles owned by different InferenceContexts. While |
63 | // inferring the shapes of a Graph, the InferenceContext of all nodes/operations |
64 | // in the Graph remain resident for the lifetime of the Graph (e.g, there is a |
65 | // map from each node to its InferenceContext, technically its |
// ExtendedInferenceContext, which additionally stores the element types of
67 | // & outputs, which remains resident). |
68 | // |
69 | // For functions, the body of the function is instantiated as a Graph while |
70 | // inferring the result shapes of a function call node. The rules above apply |
71 | // while the function's shape is being inferred, but the contexts associated |
72 | // with nodes in the function body are released once the function call's |
73 | // resultant shapes are inferred. The shapes of results returned by a function |
74 | // are propagated to the InferenceContext of the function call's op (which is |
75 | // associated with a Graph of nodes whose shape is being inferred) as the return |
76 | // values of a function call node are the inputs of its consumer, but the return |
77 | // values are produced by nodes inside the function whose InferenceContexts |
78 | // (which owns the values pointed to by ShapeHandle and DimensionHandle) are |
// reclaimed after inferring function result shapes. Recursive user-defined
// functions are not supported; hence inference of functions is fully nested,
// with the InferenceContexts of function calls forming a stack.
82 | // |
83 | // For example, consider the following call and function: |
84 | // |
85 | // ```python |
86 | // @tf.function |
87 | // def g(st): |
88 | // d = tf.add(st, st) |
89 | // return d |
90 | // |
91 | // @tf.function |
92 | // def f(): |
93 | // st = tf.A() |
94 | // result = g(st) |
95 | // return h(result) |
96 | // ``` |
97 | // |
98 | // During inference of f, the shape of `A` will be inferred and the results from |
99 | // its InferenceContext used as inputs to function call `g(st)`. The call node |
100 | // will have an InferenceContext created (call it outer context) and the graph |
101 | // corresponding to function `g` will be instantiated. The result shape of the |
102 | // Arg nodes of the function will be associated with input from outer context. |
103 | // During inference of `g` (for the callsite `g(st)` in `f`), the |
104 | // InferenceContext of all nodes inside `g` will remain alive. Thus, when shape |
105 | // of `tf.add` is computed it may rely on all inputs. Once the RetVal nodes of a |
106 | // function is reached, we know the shape of its input may correspond to a shape |
107 | // queried in the outer context and it is explicitly copied to outer context. In |
108 | // this case that means that the shape of `d` is copied to the InferenceContext |
109 | // of `g(st)` and so when `h(result)` is executed this shape may be queried. |
110 | // Furthermore, no shapes computed due to call `g(st)` can be queried post this |
// point and, as the RetVal shapes have been copied into the outer context, all
112 | // InferenceContexts associated with nodes in function `g` instantiated for |
113 | // `g(st)` may be and are released. |
114 | |
// Dimension values are accessed through InferenceContext.
class Dimension {
 private:
  // Constructors/destructor are private: Dimension objects are created and
  // owned by an InferenceContext (via its ShapeManager), and handles to them
  // (DimensionHandle) are handed out instead.
  Dimension();
  Dimension(int64_t value);
  ~Dimension() {}

  // The dimension size. A value of InferenceContext::kUnknownDim (-1) denotes
  // an unknown dimension (see InferenceContext::ValueKnown()).
  const int64_t value_;

  friend class InferenceContext;
  friend class ShapeManager;
  TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
};
128 | |
// A lightweight, intentionally-copyable handle to a Dimension owned by some
// InferenceContext. It wraps a raw non-owning pointer, so a handle must not
// outlive the InferenceContext that owns the underlying Dimension.
class DimensionHandle {
 public:
  // Default handle: refers to no dimension (IsSet() is false).
  DimensionHandle() {}
  // True iff both handles refer to the same underlying Dimension object
  // (identity comparison, not value comparison).
  bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
  // Opaque numeric identity of the underlying Dimension (the pointer value),
  // e.g. usable as a map key or for debugging.
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  DimensionHandle(const Dimension* dim) { ptr_ = dim; }

  const Dimension* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  // Non-owning; the pointee is owned by an InferenceContext.
  const Dimension* ptr_ = nullptr;

  friend struct DimensionOrConstant;
  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::grappler::GraphProperties;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};
152 | |
// Shape rank and dimensions are accessed through InferenceContext.
class Shape {
 private:
  // Constructors/destructor are private: Shape objects are created and owned
  // by an InferenceContext, and handles (ShapeHandle) are handed out instead.
  Shape();
  Shape(const std::vector<DimensionHandle>& dims);
  ~Shape() {}

  // Rank of the shape; InferenceContext::kUnknownRank (-1) when unknown
  // (see InferenceContext::Rank()).
  const int32 rank_;
  // One handle per dimension; presumably empty when rank_ is unknown —
  // accessors (e.g. DimKnownRank) guard on rank_ before indexing.
  const std::vector<DimensionHandle> dims_;

  friend class InferenceContext;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  TF_DISALLOW_COPY_AND_ASSIGN(Shape);
};
168 | |
// A lightweight, intentionally-copyable handle to a Shape owned by some
// InferenceContext. It wraps a raw non-owning pointer, so a handle must not
// outlive the InferenceContext that owns the underlying Shape.
class ShapeHandle {
 public:
  // Default handle: refers to no shape (IsSet() is false).
  ShapeHandle() {}
  // True iff both handles refer to the same underlying Shape object
  // (identity comparison, not structural shape equality).
  bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
  // Opaque numeric identity of the underlying Shape (the pointer value).
  std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }

 private:
  ShapeHandle(const Shape* shape) { ptr_ = shape; }
  const Shape* operator->() const { return ptr_; }
  bool IsSet() const { return ptr_ != nullptr; }

  // Non-owning; the pointee is owned by an InferenceContext.
  const Shape* ptr_ = nullptr;

  friend class InferenceContext;
  friend class ShapeInferenceTest;
  friend class ShapeInferenceTestutil;
  friend class ::tensorflow::grappler::SymbolicShapeManager;

  // Intentionally copyable.
};
189 | |
// Struct used to allow functions to take DimensionHandle or a dimension value.
// Not meant to be constructed directly.
struct DimensionOrConstant {
 public:
  // Intentionally not explicit.
  DimensionOrConstant(DimensionHandle dim);

  // val must be non-negative or InferenceContext::kUnknownDim.
  DimensionOrConstant(int64_t val);

  // dim takes precedence: if dim.IsSet(), val is ignored
  // (see InferenceContext::Value()).
  DimensionHandle dim;
  int64_t val;

 private:
  // No default construction: exactly one of the two public constructors must
  // be used so that either dim or val is meaningful.
  DimensionOrConstant();
};
207 | |
// A (shape, dtype) pair — plus full type information — describing one tensor.
// Used below for the shape/type side-information carried by resource handles.
struct ShapeAndType {
  ShapeAndType() {}
  ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
  // TODO(mdan): Remove dtype from constructor, and use type_ instead.
  // dtype is kept here for backward compatibility. Its information should
  // be redundant to that in type.
  ShapeAndType(ShapeHandle s, DataType t, FullTypeDef type_)
      : shape(s), dtype(t), type(type_) {}

  ShapeHandle shape;
  DataType dtype = DT_INVALID;
  FullTypeDef type;
};
221 | |
222 | // Shape inference functions registered on ops in REGISTER_OP implement |
223 | // their shape functions in terms of this InferenceContext. An InferenceContext |
224 | // is created by the framework and passed to a shape inference function. The |
225 | // shape inference function calls functions on the context, and should call |
226 | // set_output() to set the shape on all outputs. |
227 | // |
228 | // To infer shapes for user-defined functions see ShapeRefiner. |
229 | // |
230 | // All Shape* and Dimension* returned by functions of InferenceContext are owned |
231 | // by the InferenceContext. |
232 | class InferenceContext { |
233 | public: |
234 | static constexpr int64_t kUnknownDim = -1; |
235 | static constexpr int32_t kUnknownRank = -1; |
236 | |
237 | // <input_tensors> is NULL-padded to be the same size as <input_shapes>. |
238 | // |
239 | // Elements of <input_tensors_as_shapes> are used for when a shape function |
240 | // makes a call to MakeShapeFromShapeTensor; in particular, when the |
241 | // input_tensors[i] is nullptr but the shape represented by it is partially |
242 | // known from analysis of the graph. |
243 | // <input_tensors_as_shapes> can have fewer elements than <input_shapes>. |
244 | // Values of <input_tensors_as_shapes> do not need to outlive the context. |
245 | InferenceContext(int graph_def_version, const AttrSlice& attrs, |
246 | const OpDef& op_def, |
247 | const std::vector<ShapeHandle>& input_shapes, |
248 | const std::vector<const Tensor*>& input_tensors, |
249 | const std::vector<ShapeHandle>& input_tensors_as_shapes, |
250 | std::vector<std::unique_ptr<std::vector<ShapeAndType>>> |
251 | input_handle_shapes_and_types); |
252 | |
253 | // <input_tensors> is NULL-padded to be the same size as <input_shapes>. |
254 | // |
255 | // Elements of <input_tensors_as_shapes> are used for when a shape |
256 | // function makes a call to MakeShapeFromShapeTensor; in particular, when |
257 | // the input_tensors[i] is nullptr but the shape represented by it is |
258 | // partially known from analysis of the graph. <input_tensors_as_shapes> |
259 | // can have fewer elements than <input_shapes>. Values of |
260 | // <input_tensors_as_shapes> do not need to outlive the context. |
261 | InferenceContext( |
262 | int graph_def_version, const AttrSlice& attrs, const OpDef& op_def, |
263 | const std::vector<PartialTensorShape>& input_shapes, |
264 | const std::vector<const Tensor*>& input_tensors, |
265 | const std::vector<PartialTensorShape>& input_tensors_as_shapes, |
266 | const std::vector<std::unique_ptr< |
267 | std::vector<std::pair<PartialTensorShape, DataType>>>>& |
268 | input_handle_shapes_and_types); |
269 | |
270 | ~InferenceContext(); |
271 | |
272 | // Runs the shape inference function 'fn' with 'this' as the |
273 | // argument, returns the status of the inference. |
274 | // |
275 | // On error, additional context is provided in the error message. |
276 | Status Run( |
277 | const std::function<Status(shape_inference::InferenceContext* c)>& fn); |
278 | |
279 | // Merge the stored shape of the input in position idx with <shape> according |
280 | // to the following rules: |
281 | // |
282 | // - If the ShapeHandles are the same or <shape> is unknown, there will be no |
283 | // change. Otherwise if the stored shape is unknown, the new shape will be |
284 | // <shape>. |
285 | // - If both shapes are known, then they must have the same rank. |
286 | // - For any one dimension, if the values for that dimension in both shapes |
287 | // are known, then the values must match. |
288 | // - If one shape has equal or more information than the other shape in every |
289 | // dimension, the new shape will become the shape with more information. |
290 | // - Example: merging [2,?] and [?,2] results in [2,2] |
291 | // - Example: [2,2] cannot be merged with [1,2] |
292 | // |
293 | // This requires idx to be in the [0, num_inputs) range. If the merge is |
294 | // successful, return true. Return false otherwise. |
295 | bool MergeInput(int idx, ShapeHandle shape) { |
296 | ShapeHandle new_shape; |
297 | if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false; |
298 | inputs_[idx] = new_shape; |
299 | return true; |
300 | } |
301 | |
302 | // Relax the stored shape of the input in position idx with <shape> according |
303 | // to the following rules: |
304 | // |
305 | // - If the ShapeHandles are the same then the stored shape will be returned. |
306 | // - If either of the ShapeHandles are unknown, then a new UnknownShape will |
307 | // be returned. A new shape must be returned because we cannot claim that |
308 | // the resulting shape is necessarily the same as either of the input |
309 | // shapes. |
310 | // - If the shapes both have known ranks but their ranks are different, a new |
311 | // UnknownShape will be returned. |
312 | // - For any one dimension, if the value for that dimension in either of the |
313 | // shapes is unknown, a new shape will be returned with a new UnknownDim in |
314 | // that dimension. |
315 | // - For any one dimension, if the values for that dimension in both shapes |
316 | // are known but do not match, a new shape will be returned with a new |
317 | // UnknownDim in that dimension. |
318 | // - If both shapes have the same known rank and match in every dimension, |
319 | // the stored shape will be returned. |
320 | // - Example: relaxing [2,?] and [?,2] results in [?,?] |
321 | // - Example: relaxing [2,2] and [3,2] results in [?,2] |
322 | // - Example: relaxing [2,2] with [1,2,3] results in ? |
323 | // |
324 | // This requires idx to be in the [0, num_inputs) range. If the relax is |
325 | // successful and the new shape differs from the old one, store the new |
326 | // shape and return true. Return false otherwise. |
327 | bool RelaxInput(int idx, ShapeHandle shape) { |
328 | ShapeHandle new_shape; |
329 | Relax(inputs_[idx], shape, &new_shape); |
330 | if (inputs_[idx].SameHandle(new_shape)) { |
331 | return false; |
332 | } |
333 | inputs_[idx] = new_shape; |
334 | return true; |
335 | } |
336 | |
  // Unconditionally overwrites the stored shape of input <idx>.
  void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; }

  // Returns the stored shape of input <idx>.
  ShapeHandle input(int64_t idx) const { return inputs_[idx]; }
  // Returns in <*output> the shapes of all inputs matching <input_name>.
  Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
  int num_inputs() const { return inputs_.size(); }

  // Returns the input tensor at index <idx>, or nullptr if the input tensor is
  // not available at the time of shape inference.
  const Tensor* input_tensor(int idx) {
    // Mark that this idx was requested.
    request_input_tensor(idx);
    return input_tensors_[idx];
  }
350 | |
351 | // Notifies the shape refiner that the value of the tensor at index <idx> |
352 | // is needed. The shape refiner tries to statically compute this tensor, |
353 | // and if successful re-runs the shape function with this tensor available |
354 | // in the call to 'input_tensor(idx)'. |
355 | void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; } |
356 | |
357 | // Returns true iff input_tensor(idx) was called by the shape function. |
358 | bool requested_input_tensor(int idx) const { |
359 | return requested_input_tensor_[idx]; |
360 | } |
361 | |
362 | // Notifies the shape refiner that the value of the tensor at index <idx> |
363 | // as a partial shape is needed. The shape refiner tries to statically compute |
364 | // this, and if successful re-runs the shape function with the |
365 | // computed PartialTensorShape available in the call to |
366 | // 'MakeShapeFromShapeTensor(idx, handle)' or |
367 | // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'. |
368 | void request_input_tensor_as_partial_shape(int idx) { |
369 | requested_input_tensor_as_partial_shape_[idx] = true; |
370 | } |
371 | |
372 | // Returns true if MakeShapeFromInputTensor was called but the constant |
373 | // input_tensor was not present. |
374 | bool requested_input_tensor_as_partial_shape(int idx) const { |
375 | return requested_input_tensor_as_partial_shape_[idx]; |
376 | } |
377 | |
  // Replaces the full set of input tensors (NULL-padded to the number of
  // inputs, like the constructor argument of the same name).
  void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
    input_tensors_ = input_tensors;
  }

  // Replaces the partially-known shapes derived from the input tensors (see
  // the constructor's <input_tensors_as_shapes> documentation).
  void set_input_tensors_as_shapes(
      const std::vector<ShapeHandle>& input_tensors_as_shapes) {
    input_tensors_as_shapes_ = input_tensors_as_shapes;
  }

  const std::vector<ShapeHandle>& input_tensors_as_shapes() const {
    return input_tensors_as_shapes_;
  }
390 | |
  // Returns the stored shape of output <idx>; bounds-checked via .at().
  ShapeHandle output(int64_t idx) const { return outputs_.at(idx); }
  void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
  // Sets the shapes of all outputs matching <output_name>.
  Status set_output(StringPiece output_name,
                    const std::vector<ShapeHandle>& shapes);

  int num_outputs() const { return outputs_.size(); }
  // NOTE(review): int overload alongside the int64_t overload above; both
  // forward to outputs_.at().
  ShapeHandle output(int idx) const { return outputs_.at(idx); }
  // Returns in <*output> the shapes of all outputs matching <output_name>.
  Status output(StringPiece output_name,
                std::vector<ShapeHandle>* output) const;

  // Returns the value for attribute named `attr_name`.
  Status GetAttr(StringPiece attr_name, const AttrValue** attr_value) const {
    return attrs_.Find(attr_name, attr_value);
  }
  // As above, but returns the attr directly (null result when absent, per
  // AttrSlice::Find).
  const AttrValue* GetAttr(StringPiece attr_name) const {
    return attrs_.Find(attr_name);
  }

  // Full type information for the op's return values.
  const FullTypeDef& ret_types() const { return ret_types_; }
410 | |
411 | // idx can be negative for an offset from end of dimensions. |
412 | // idx must be in the range [-1 * s.rank, s.rank). |
413 | DimensionHandle Dim(ShapeHandle s, int64_t idx) { |
414 | if (!s.Handle() || s->rank_ == kUnknownRank) { |
415 | return UnknownDim(); |
416 | } |
417 | return DimKnownRank(s, idx); |
418 | } |
419 | // As above, but asserts that the rank of the shape is known. |
420 | static DimensionHandle DimKnownRank(ShapeHandle s, int64_t idx) { |
421 | CHECK_NE(s->rank_, kUnknownRank); |
422 | if (idx < 0) { |
423 | return s->dims_[s->dims_.size() + idx]; |
424 | } |
425 | return s->dims_[idx]; |
426 | } |
427 | |
428 | static int32 Rank(ShapeHandle s) { |
429 | return s.IsSet() ? s->rank_ : kUnknownRank; |
430 | } |
431 | static bool RankKnown(ShapeHandle s) { |
432 | return (s.IsSet() && (Rank(s) != kUnknownRank)); |
433 | } |
434 | static inline int64_t Value(DimensionOrConstant d) { |
435 | return d.dim.IsSet() ? d.dim->value_ : d.val; |
436 | } |
437 | static inline bool ValueKnown(DimensionOrConstant d) { |
438 | return Value(d) != kUnknownDim; |
439 | } |
440 | |
441 | // Fills the output proto with the shape defined by the handle. |
442 | // "proto" is expected to be empty prior to the call. |
443 | void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto); |
444 | TensorShapeProto ShapeHandleToProto(ShapeHandle handle); |
445 | |
446 | // Returns true if the rank and all dimensions of the Shape are known. |
447 | bool FullyDefined(ShapeHandle s); |
448 | |
449 | // Returns the total number of elements, or an unknown dimension for an |
450 | // incomplete shape. |
451 | DimensionHandle NumElements(ShapeHandle s); |
452 | |
453 | std::string DebugString(ShapeHandle s); |
454 | std::string DebugString(DimensionHandle d); |
455 | std::string DebugString(const ShapeAndType& shape_and_type); |
456 | std::string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types); |
457 | |
458 | // Describes the whole context, for debugging purposes. |
459 | std::string DebugString() const; |
460 | |
461 | // If <shape> has rank <rank>, or its rank is unknown, return OK and return |
462 | // the shape with asserted rank in <*out>. Otherwise return an error. |
463 | // |
464 | // Note that <*out> may be set to <shape>. |
465 | Status WithRank(ShapeHandle shape, int64_t rank, |
466 | ShapeHandle* out) TF_MUST_USE_RESULT; |
467 | Status WithRankAtLeast(ShapeHandle shape, int64_t rank, |
468 | ShapeHandle* out) TF_MUST_USE_RESULT; |
469 | Status WithRankAtMost(ShapeHandle shape, int64_t rank, |
470 | ShapeHandle* out) TF_MUST_USE_RESULT; |
471 | |
472 | // If <dim> has value <value>, or its value is unknown, returns OK and returns |
473 | // the dimension with asserted value in <*out>. Otherwise returns an error. |
474 | // |
475 | // Note that <*out> may be set to <dim>. |
476 | Status WithValue(DimensionHandle dim, int64_t value, |
477 | DimensionHandle* out) TF_MUST_USE_RESULT; |
478 | |
479 | // Merges <s0> and <s1> and returns the merged shape in <*out>. See |
480 | // 'MergeInput' function for full details and examples. |
481 | Status Merge(ShapeHandle s0, ShapeHandle s1, |
482 | ShapeHandle* out) TF_MUST_USE_RESULT; |
483 | |
484 | // Asserts that <s>'s rank >= <prefix>'s rank, and the first |
485 | // <prefix.rank> dimensions of <s> are compatible with the dimensions of |
486 | // <prefix>. |
487 | // Returns the merged results in <*s_out> and <*prefix_out>. |
488 | Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out, |
489 | ShapeHandle* prefix_out) TF_MUST_USE_RESULT; |
490 | |
491 | // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0> |
492 | // and <d1> have incompatible values, returns an error. |
493 | // |
494 | // Note that <*out> may be set to <d0> or <d1>. |
495 | Status Merge(DimensionHandle d0, DimensionHandle d1, |
496 | DimensionHandle* out) TF_MUST_USE_RESULT; |
497 | |
498 | // Returns in <*out> a sub-shape of <s> with dimensions [start:]. |
499 | // <start> can be negative to index from the end of the shape. If <start> > |
500 | // rank of <s>, then an empty subshape is returned. |
501 | Status Subshape(ShapeHandle s, int64_t start, |
502 | ShapeHandle* out) TF_MUST_USE_RESULT; |
503 | |
504 | // Returns in <*out> a sub-shape of <s>, with dimensions [start:end]. |
505 | // <start> and <end> can be negative, to index from the end of the shape. |
506 | // <start> and <end> are set to the rank of <s> if > rank of <s>. |
507 | Status Subshape(ShapeHandle s, int64_t start, int64_t end, |
508 | ShapeHandle* out) TF_MUST_USE_RESULT; |
509 | |
510 | // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride]. |
511 | // <start> and <end> can be negative, to index from the end of the shape. |
512 | // <start> and <end> are set to the rank of <s> if > rank of <s>. |
513 | // <stride> can be negative, to reverse the <s>. |
514 | Status Subshape(ShapeHandle s, int64_t start, int64_t end, int64_t stride, |
515 | ShapeHandle* out) TF_MUST_USE_RESULT; |
516 | |
517 | // Returns in <*out> the result of appending the dimensions of <s2> to those |
518 | // of <s1>. |
519 | Status Concatenate(ShapeHandle s1, ShapeHandle s2, |
520 | ShapeHandle* out) TF_MUST_USE_RESULT; |
521 | |
522 | // Returns in <out> the shape from replacing <s.dim[dim_index]> with |
523 | // <new_dim>. |
524 | Status ReplaceDim(ShapeHandle s, int64_t dim_index, DimensionHandle new_dim, |
525 | ShapeHandle* out) TF_MUST_USE_RESULT; |
526 | |
527 | // Returns a new shape with the given dims. The returned value is owned by |
528 | // this context. |
529 | ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims); |
530 | ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims); |
531 | |
532 | // Returns a new unknown shape. |
533 | ShapeHandle UnknownShape(); |
534 | |
535 | // Returns a shape with specified rank but unknown dims. |
536 | ShapeHandle UnknownShapeOfRank(int64_t rank); |
537 | |
538 | // Returns a new shape of zero dimensions. |
539 | ShapeHandle Scalar(); |
540 | |
541 | // Returns a new shape of one dimension. |
542 | ShapeHandle Vector(DimensionOrConstant dim); |
543 | |
544 | // Returns a new shape of two dimensions. |
545 | ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2); |
546 | |
547 | // Returns in <out> a new shape whose dimension sizes come from input tensor |
548 | // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If |
549 | // the input tensor is NULL, then an unknown shape is returned. |
550 | Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out); |
551 | |
552 | // Like the function above, but treats scalar values as unknown |
553 | // shapes. **NOTE** If the scalar is statically known, its value |
554 | // must be -1 or an error is returned. |
555 | Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx, |
556 | ShapeHandle* out); |
557 | |
558 | // Returns in <out> a new shape corresponding to <proto>. |
559 | Status MakeShapeFromShapeProto(const TensorShapeProto& proto, |
560 | ShapeHandle* out); |
561 | |
562 | // Returns in <out> a new shape corresponding to <partial_shape>. |
563 | Status MakeShapeFromPartialTensorShape( |
564 | const PartialTensorShape& partial_shape, ShapeHandle* out); |
565 | |
566 | // Returns in <out> a new shape corresponding to <shape>. |
567 | Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out); |
568 | StatusOr<ShapeHandle> MakeShapeFromShapeTensor(const TensorShape& shape); |
569 | |
570 | // Returns a new dimension of the given size. The returned value is owned by |
571 | // this context. |
572 | inline DimensionHandle MakeDim(DimensionOrConstant d) { |
573 | return shape_manager_.MakeDim(d); |
574 | } |
575 | |
576 | inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } |
577 | |
578 | // Returns in <val> a scalar value from an input tensor <t>. The input tensor |
579 | // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the |
580 | // input tensor is not NULL. |
581 | Status GetScalarFromTensor(const Tensor* t, int64_t* val); |
582 | |
583 | // Returns in <val> a scalar value from a 1D input tensor <t> with int32 or |
584 | // int64 elements. Caller must ensure that the input tensor is not NULL. |
585 | Status GetScalarFromTensor(const Tensor* t, int64_t idx, int64_t* val); |
586 | |
587 | // Returns a new dimension whose value is given by a scalar input tensor. |
588 | // The input tensor must be in host memory, since it is dereferenced to get |
589 | // the value. |
590 | Status MakeDimForScalarInput(int idx, DimensionHandle* out); |
591 | |
592 | // Returns a new dimension whose value is given by a scalar input tensor. |
593 | // This allows for a negative input dimension given the rank of a separate |
594 | // tensor. This rank can be negative if unknown. |
595 | // The input tensor must be in host memory, since it is dereferenced to get |
596 | // the value. |
597 | Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank, |
598 | DimensionHandle* out); |
599 | |
600 | // Look up the attr being evaluated with name attr_name and set *value to its |
601 | // value. If no attr with attr_name is found in def(), or the attr does not |
602 | // have a matching type, a non-ok status will be returned. |
603 | template <class T> |
604 | Status GetAttr(StringPiece attr_name, T* value) const; |
605 | |
606 | // Returns in <out> the result of dividing <dividend> by <divisor>. |
607 | // Returns an error if <divisor> is not positive or if <evenly_divisible> |
608 | // and <divisor> does not evenly divide <dividend>. |
609 | Status Divide(DimensionHandle dividend, DimensionOrConstant divisor, |
610 | bool evenly_divisible, DimensionHandle* out); |
611 | |
612 | // Returns in <out> the sum of <first> and <second>. |
613 | Status Add(DimensionHandle first, DimensionOrConstant second, |
614 | DimensionHandle* out); |
615 | |
616 | // Returns in <out> the dimension that is <first> minus <second>. |
617 | Status Subtract(DimensionHandle first, DimensionOrConstant second, |
618 | DimensionHandle* out); |
619 | |
620 | // Returns in <out> the product of <first> and <second>. |
621 | Status Multiply(DimensionHandle first, DimensionOrConstant second, |
622 | DimensionHandle* out); |
623 | |
  // Returns in <out> the minimum of <first> and <second>. If either <first> or
  // <second> is zero the result is zero. Otherwise, if either <first> or
  // <second> is unknown the result is unknown.
627 | Status Min(DimensionHandle first, DimensionOrConstant second, |
628 | DimensionHandle* out); |
629 | |
  // Returns in <out> the maximum of <first> and <second>. If either <first> or
  // <second> is unknown the result is unknown.
632 | Status Max(DimensionHandle first, DimensionOrConstant second, |
633 | DimensionHandle* out); |
634 | |
635 | Status construction_status() const { return construction_status_; } |
636 | |
637 | // Methods to propagate shape and dtype on edges of handles. Handles are the |
638 | // dtype DT_RESOURCE which can be used to access state stored in a |
639 | // ResourceManager. When ops (such as variables) consume these handles to |
640 | // produce tensors they might need to know side-information about the shapes |
641 | // and dtypes of tensors which can be accessed via the handle. These methods |
642 | // propagate that information. Output handle dtypes and shapes are ignored if |
643 | // the output tensor is not of type DT_RESOURCE. |
644 | |
645 | // Merge the stored shapes and types corresponding to the input handle in |
646 | // position idx with the specified shapes and types. This requires idx to be |
647 | // in the [0, num_inputs) range. |
648 | // |
649 | // If the merge is successful and any of the new shapes differs from the old |
650 | // one, or any of the old dtypes was DT_INVALID, store the new shapes and |
651 | // return true. Return false otherwise. |
652 | // |
653 | // See 'MergeInput' function for full details and examples. |
654 | bool MergeInputHandleShapesAndTypes( |
655 | int idx, |
656 | const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; |
657 | |
658 | // As MergeInputHandleShapesAndTypes, but for an output. |
659 | bool MergeOutputHandleShapesAndTypes( |
660 | int idx, |
661 | const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; |
662 | |
663 | // Relaxes the stored shapes and types corresponding to the input handle in |
664 | // position idx with the specified shapes and types. This requires idx to be |
665 | // in the [0, num_inputs) range. |
666 | // |
667 | // If the relax is successful (sizes are the same, old dtypes match new ones |
668 | // or are DT_INVALID), then store the relaxed shapes and return true. |
669 | // Return false otherwise. |
670 | // |
671 | // See 'RelaxInput' function for full details and examples. |
672 | bool RelaxInputHandleShapesAndMergeTypes( |
673 | int idx, |
674 | const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; |
675 | |
  // As RelaxInputHandleShapesAndMergeTypes, but for an output.
  bool RelaxOutputHandleShapesAndMergeTypes(
      int idx,
      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
680 | |
  // Overwrites the stored shapes and types for the resource tensor input at
  // index <idx> with a copy of <shapes_and_types>. Unlike the Merge/Relax
  // variants above, no reconciliation with previously stored data is done.
  void set_input_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    input_handle_shapes_and_types_[idx] =
        absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
  }
686 | |
  // Returns the output handle shapes and types, for the resource tensor output
  // at index <idx>. Returns NULL if the shape and types were never set.
  // The returned pointer is owned by this context; it is invalidated when the
  // corresponding set/merge/relax method replaces the stored vector.
  const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
    return output_handle_shapes_and_types_[idx].get();
  }
692 | |
  // Returns the inputs handle shapes and types, for the resource tensor input
  // at index <idx>. Returns NULL if the shape and types were not available.
  // The returned pointer is owned by this context; it is invalidated when the
  // corresponding set/merge/relax method replaces the stored vector.
  const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
    return input_handle_shapes_and_types_[idx].get();
  }
698 | |
  // Overwrites the stored shapes and types for the resource tensor output at
  // index <idx> with a copy of <shapes_and_types>. Unlike the Merge/Relax
  // variants above, no reconciliation with previously stored data is done.
  void set_output_handle_shapes_and_types(
      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
    output_handle_shapes_and_types_[idx] =
        absl::make_unique<std::vector<ShapeAndType>>(shapes_and_types);
  }
704 | |
  // Note that shape functions should usually call MakeShapeFromShapeTensor,
  // as it does more analysis to provide partial shapes.
  //
  // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
  // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
  // then an unknown shape is returned.
  // NOTE(review): <tensor_shape> presumably describes the shape of <t> itself,
  // used when <t> is NULL or partially known — confirm against
  // InternalMakeShapeFromTensor in the .cc.
  Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
                             ShapeHandle* out);
713 | |
714 | int graph_def_version() const { return graph_def_version_; } |
715 | |
  // Accessors for the pairs of shape (resp. dimension) handles recorded as
  // equivalent by Merge calls. Cleared by ForgetMerges().
  const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
    return merged_shapes_;
  }
  const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
      const {
    return merged_dims_;
  }
723 | |
  // Adds new outputs; useful when mutating the graph.
  // NOTE(review): newly added outputs presumably start unset/unknown — confirm
  // in the .cc.
  Status ExpandOutputs(int new_output_size);
726 | |
727 | private: |
728 | // Creates and stores shapes for use in InferenceContext. |
729 | class ShapeManager { |
730 | public: |
731 | ShapeManager(); |
732 | ~ShapeManager(); |
733 | |
734 | // Returns a new shape with the given dims. The returned value is owned by |
735 | // this class. |
736 | ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims); |
737 | |
738 | // Returns a new unknown shape. |
739 | ShapeHandle UnknownShape(); |
740 | |
741 | // Returns a new dimension of the given size. The returned value |
742 | // is owned by this class. |
743 | inline DimensionHandle MakeDim(DimensionOrConstant d) { |
744 | if (d.dim.IsSet()) { |
745 | return d.dim; |
746 | } else { |
747 | all_dims_.push_back(new Dimension(d.val)); |
748 | return all_dims_.back(); |
749 | } |
750 | } |
751 | |
752 | private: |
753 | std::vector<Shape*> all_shapes_; // values are owned. |
754 | std::vector<Dimension*> all_dims_; // values are owned. |
755 | }; |
756 | |
  // Grappler's GraphProperties reaches into private context state directly —
  // presumably for symbolic shape analysis; keep in sync with grappler.
  friend class ::tensorflow::grappler::GraphProperties;

  friend class ShapeInferenceTest;      // For testing Relax functions.
  friend class ShapeInferenceTestutil;  // For testing shapes.
761 | |
  // Shared initialization across the two constructors. Remove
  // once we get rid of one of them.
  // PreInputInit runs before input shapes are installed; PostInputInit then
  // installs the per-input resource-handle shape/type data.
  void PreInputInit(const OpDef& op_def,
                    const std::vector<const Tensor*>& input_tensors,
                    const std::vector<ShapeHandle>& input_tensors_as_shapes);
  void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
                         input_handle_data);
769 | |
  // Convenience helpers that set <*out> and return OK, letting shape functions
  // finish with a single `return` statement.

  // Sets <*out> to a new unknown shape and returns OK.
  Status ReturnUnknownShape(ShapeHandle* out) {
    *out = UnknownShape();
    return OkStatus();
  }
  // Sets <*out> to a new shape built from <dims> and returns OK.
  Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
                            ShapeHandle* out) {
    *out = MakeShape(dims);
    return OkStatus();
  }
779 | |
  // Adds additional context to the given status (which node/op failed;
  // see the .cc) and returns the decorated status.
  Status AttachContext(const Status& status);
782 | |
  // Relaxes an existing value <d_old> with a new value <d_new> and returns the
  // relaxed dimension in <*out>. This function cannot fail (note the void
  // return); when <d_old> and <d_new> have incompatible values the relaxed
  // result is presumably an unknown dimension — confirm in the .cc.
  //
  // Note that <*out> may be set to <d_old> or <d_new>.
  void Relax(DimensionHandle d_old, DimensionHandle d_new,
             DimensionHandle* out);
  // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
  // relaxed shape in <*out>. See 'RelaxInput' function for full details and
  // examples.
  void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
794 | |
  // Used to implement MergeInputHandleShapesAndTypes and
  // MergeOutputHandleShapesAndTypes. Merges <shapes_and_types> into
  // <*to_update>; returns true iff <*to_update> was changed (see the public
  // Merge*HandleShapesAndTypes docs above for the exact contract).
  bool MergeHandleShapesAndTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
  // Used to implement RelaxInputHandleShapesAndMergeTypes and
  // RelaxOutputHandleShapesAndMergeTypes.
  bool RelaxHandleShapesAndMergeTypes(
      const std::vector<ShapeAndType>& shapes_and_types,
      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
805 | |
  // Forgets all the shape/dim equivalences recorded so far by Merge calls
  // (clears merged_shapes_ and merged_dims_).
  void ForgetMerges() {
    merged_shapes_.clear();
    merged_dims_.clear();
  }
811 | |
  // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
  // NOTE(review): when <treat_unknown_scalar_tensor_as_unknown_shape> is true,
  // an unknown scalar tensor presumably yields an unknown shape instead of an
  // error — confirm in the .cc.
  Status InternalMakeShapeFromTensor(
      bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
      ShapeHandle tensor_shape, ShapeHandle* out);
816 | |
817 | ShapeManager shape_manager_; |
818 | |
819 | // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from |
820 | // `shape_manager_`. |
821 | std::vector<ShapeHandle> inputs_; |
822 | std::vector<const Tensor*> input_tensors_; |
823 | std::vector<bool> requested_input_tensor_; |
824 | std::vector<ShapeHandle> outputs_; |
825 | // Can have fewer elements than inputs_. |
826 | std::vector<ShapeHandle> input_tensors_as_shapes_; |
827 | std::vector<bool> requested_input_tensor_as_partial_shape_; |
828 | |
829 | // input_handle_shapes_and_types_[i] is the list of shape/type pairs available |
830 | // through the resource handle passed along input i of the node. |
831 | // |
832 | // Values may be NULL. |
833 | std::vector<std::unique_ptr<std::vector<ShapeAndType>>> |
834 | input_handle_shapes_and_types_; |
835 | |
836 | // output_handle_shapes_and_types_[i] is the list of shape/type pairs |
837 | // available through the resource handle passed along output i of the node. |
838 | // |
839 | // Values may be NULL. |
840 | std::vector<std::unique_ptr<std::vector<ShapeAndType>>> |
841 | output_handle_shapes_and_types_; |
842 | |
843 | // Return types for the node this context is associated with. This information |
844 | // is to eventually consolidate all the dtype and shape info, allowing for |
845 | // output_handle_shapes_and_types_ to be removed. |
846 | FullTypeDef ret_types_; |
847 | |
848 | const int graph_def_version_; |
849 | AttrSlice attrs_; |
850 | NameRangeMap input_name_map_; |
851 | NameRangeMap output_name_map_; |
852 | |
853 | // An error set during construction. TODO(cwhipkey): remove when test |
854 | // constructor is removed. |
855 | Status construction_status_; |
856 | |
  // Pairs of shape or dim handles that are equivalent, ie that represent the
  // same underlying shape or dimension. Note that for each pair at least one
  // of the handles must contain an unknown shape, since we don't keep track of
  // known shapes or dims here.
861 | std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_; |
862 | std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_; |
863 | |
864 | TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext); |
865 | }; |
866 | |
867 | // ----------------------------------------------------------------------------- |
868 | // Template and inline method implementations, please ignore |
869 | |
// A default-constructed Dimension has unknown size.
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
// A Dimension holds a non-negative size, or kUnknownDim (enforced below).
inline Dimension::Dimension(int64_t value) : value_(value) {
  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << value;
}
877 | |
// A default-constructed Shape has unknown rank (and no dimensions).
inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
// A known-rank Shape takes its rank from dims.size().
inline Shape::Shape(const std::vector<DimensionHandle>& dims)
    : rank_(dims.size()), dims_(dims) {}
881 | |
882 | inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim) |
883 | : dim(dim) { |
884 | DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension." ; |
885 | } |
886 | |
// Wraps a constant dimension value: must be non-negative or kUnknownDim
// (enforced below).
inline DimensionOrConstant::DimensionOrConstant(int64_t val) : val(val) {
  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
      << "Dimension must be non-negative or equal to "
         "InferenceContext::kUnknownDim but got "
      << val;
}
893 | |
// Looks up attr <attr_name> on the node this context is associated with and
// stores it in <*value>; error behavior (missing attr, type mismatch) is that
// of GetNodeAttr in node_def_util.h.
template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(attrs_, attr_name, value);
}
898 | |
899 | } // namespace shape_inference |
900 | } // namespace tensorflow |
901 | |
902 | #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ |
903 | |