/*
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <functional>
#include "onnx/defs/data_type_utils.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {

using Dim = TensorShapeProto_Dimension;

struct ShapeInferenceOptions {
  // When true, checks that input and output types match during inference.
  bool check_type;
  // 1: node-level shape-inference errors are thrown.
  // 0: node-level shape-inference errors are not thrown, but other errors,
  // such as failing to merge an existing shape with an inferred one, are still thrown.
  int error_mode;
  // Enables data propagation for a limited set of operators
  // to perform shape computation.
  bool enable_data_propagation;
  ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
      : check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val) {}
};
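
// Illustrative sketch (not part of the API): constructing options for a
// strict inference pass with data propagation enabled. The arguments map
// directly onto the constructor parameters above.
//
//   ShapeInferenceOptions options(/*check_type_val=*/true, /*strict_mode_val=*/1, /*data_prop_val=*/true);
//   // options.check_type == true, options.error_mode == 1,
//   // options.enable_data_propagation == true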

// Maintains a SymbolTable for symbolic shape inference
class SymbolTable {
 public:
  // Adds existing symbols from a main graph or subgraph
  virtual void addFromGraph(const GraphProto& g) = 0;
  // Creates a new symbol that does not duplicate any existing one
  virtual std::string createNew(const std::string& symbol_prefix) = 0;
  virtual ~SymbolTable() = default;
};

class GraphInferencer {
 public:
  // Perform inferencing on the graph contained in GraphInferencer.
  // Returns the graph output types post-inferencing.
  virtual std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) = 0;
  virtual ~GraphInferencer() = default;
};

// Exception class used for handling errors in type and shape inference
class InferenceError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  explicit InferenceError(const std::string& message) : std::runtime_error(message) {}

  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  void AppendContext(const std::string& context) {
    expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
  }

 private:
  std::string expanded_message_;
};

#define fail_type_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));

#define fail_shape_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
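
// Illustrative sketch (hypothetical op, `input_rank` stands for a locally
// computed value): how inference methods report failures via the macros
// above. The variadic arguments are concatenated into the exception message.
//
//   if (input_rank != 2) {
//     fail_shape_inference("Input expected to have rank 2, got rank ", input_rank);
//   }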

struct InferenceContext {
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual bool hasInput(size_t index) const {
    // The default implementation below is provided for backward compatibility
    // with implementations of InferenceContext that do not provide an explicit
    // implementation. It works for normal usage, but may be imprecise in the
    // edge case where an input is supplied but has no known type.
    // However, inference methods work only under the assumption that the
    // types of all inputs are known.
    return ((index < getNumInputs()) && (getInputType(index) != nullptr));
  }
  virtual const TensorProto* getInputData(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  virtual TypeProto* getOutputType(size_t index) = 0;
  virtual GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) = 0;
  virtual ~InferenceContext() {}
  virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
  // Gets the shape inputs computed by partial data propagation.
  virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
};

// We use data propagation to perform partial evaluation of the model, to compute
// statically known information about tensor values. It is intended to improve the
// precision of shape inference. We reuse TensorShapeProto to represent the statically
// known values. One limitation of this is that TensorShapeProto can represent only
// integer values.
// As an example, data propagation is intended to handle code fragments like the following:
//   shape = Shape(X)
//   batchsize = Slice(shape, [0], [1])
//   newshape = Concat(batchsize, [1024, 1024])
//   Z = Reshape(Y, newshape)
// If the shape of X is statically known, then data propagation should be able to determine
// the value of newshape, as well as the shape of Z.
struct DataPropagationContext {
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  virtual const TypeProto* getOutputType(size_t index) const = 0;
  virtual ~DataPropagationContext() {}
  virtual const TensorShapeProto* getInputData(size_t index) = 0;
  virtual void addOutputData(size_t index, TensorShapeProto&& tp) = 0;
};
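
// Minimal sketch (hypothetical, not a registered implementation): what a
// data propagation function for an identity-like op could look like. It
// forwards the statically known integer values of input 0 to output 0,
// using only the DataPropagationContext interface above.
//
//   void IdentityDataPropagator(DataPropagationContext& ctx) {
//     const TensorShapeProto* input_data = ctx.getInputData(0);
//     if (input_data == nullptr) {
//       return; // values not statically known; nothing to propagate
//     }
//     TensorShapeProto output_data(*input_data);
//     ctx.addOutputData(0, std::move(output_data));
//   }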

using InferenceFunction = std::function<void(InferenceContext&)>;
using DataPropagationFunction = std::function<void(DataPropagationContext&)>;

// This no-op inference function is used for operators without an
// inference implementation.
inline void dummyInferenceFunction(InferenceContext&) {}

// This no-op data propagation function is used for operators without a defined data propagator.
inline void dummyDataPropagationFunction(DataPropagationContext&) {}

template <typename T>
inline bool getRepeatedAttribute(InferenceContext& ctx, const std::string& attr_name, std::vector<T>& values) {
  const auto* attr = ctx.getAttribute(attr_name);
  if (attr) {
    values = RetrieveValues<T>(*attr);
    return true;
  } else {
    return false;
  }
}
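
// Illustrative sketch: reading a repeated integer attribute such as "pads"
// inside an inference method. The attribute name is only an example; any
// ints attribute works the same way.
//
//   std::vector<int64_t> pads;
//   if (getRepeatedAttribute(ctx, "pads", pads)) {
//     // pads now holds the attribute values
//   } else {
//     // attribute absent; fall back to defaults
//   }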

inline int64_t getAttribute(InferenceContext& ctx, const std::string& attributeName, int64_t defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_i())
    return attr_proto->i();
  return defaultValue;
}

inline int64_t getAttribute(DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_i())
    return attr_proto->i();
  return defaultValue;
}

inline std::string
getAttribute(InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_s())
    return attr_proto->s();
  return defaultValue;
}

inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, TensorShapeProto::Dimension dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value() && dim2.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() * dim2.dim_value());
  } else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) {
    return dim2;
  } else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) {
    return dim1;
  }
  return result;
}

inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, int64_t dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() * dim2);
  } else if (dim2 == 1) {
    return dim1;
  }
  return result;
}

inline TensorShapeProto::Dimension operator/(TensorShapeProto::Dimension dim1, int64_t dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() / dim2);
  } else if (dim2 == 1) {
    return dim1;
  }
  return result;
}

// If from >= upto_exclusive, returns 1.
// Caller must ensure that upto_exclusive <= shape.dim_size()
// and that from >= 0.
inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
  TensorShapeProto::Dimension dim;
  dim.set_dim_value(1);
  for (int i = from; i < upto_exclusive; ++i) {
    dim = dim * shape.dim(i);
  }
  return dim;
}
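
// Illustrative sketch: dimension arithmetic with the operators above. For a
// shape such as (N, 3, 4, 5), multiplyDims over dimensions [1, 4) yields a
// dimension with value 60; if any dimension in the range is symbolic (and
// not 1), the result is unknown (neither dim_value nor dim_param set).
//
//   TensorShapeProto shape; // assume it describes (N, 3, 4, 5)
//   TensorShapeProto::Dimension flattened = multiplyDims(shape, 1, 4);
//   // flattened.dim_value() == 60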

inline int32_t getTensorElementType(const TypeProto& type) {
  int32_t result = TensorProto::UNDEFINED;
  const auto value_case = type.value_case();
  if (value_case == TypeProto::kTensorType) {
    result = type.tensor_type().elem_type();
  } else if (value_case == TypeProto::kSparseTensorType) {
    result = type.sparse_tensor_type().elem_type();
  }
  return result;
}

inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto& type) {
  if (value_case == TypeProto::kTensorType) {
    type.mutable_tensor_type()->set_elem_type(elem_type);
  } else if (value_case == TypeProto::kSparseTensorType) {
    type.mutable_sparse_tensor_type()->set_elem_type(elem_type);
  }
}

void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type);

void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);

void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);

inline void propagateElemTypeFromDtypeToOutput(
    InferenceContext& ctx,
    const int data_type,
    size_t outputIndex,
    TypeProto::ValueCase expected_value_case) {
  const auto attribute_tensor_datatype = data_type;
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::VALUE_NOT_SET || output_value_case == expected_value_case) {
    setTensorElementType(attribute_tensor_datatype, expected_value_case, *output_type);
  } else {
    // This is not expected to happen
    fail_type_inference(
        "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case);
  }
}

inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const int data_type, size_t outputIndex) {
  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, TypeProto::kTensorType);
}

inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const AttributeProto* attr, size_t outputIndex) {
  int32_t data_type = TensorProto::UNDEFINED;
  TypeProto::ValueCase expected_value_case = TypeProto::VALUE_NOT_SET;
  const auto attr_type = attr->type();
  if (attr_type == AttributeProto::TENSOR) {
    if (attr->t().dims().size() != 1) {
      fail_type_inference("Attribute expected to have a one-dim tensor");
    }
    data_type = attr->t().data_type();
    expected_value_case = TypeProto::kTensorType;
  } else if (attr_type == AttributeProto::SPARSE_TENSOR) {
    if (attr->sparse_tensor().dims().size() != 1) {
      fail_type_inference("Attribute expected to have a one-dim sparse tensor");
    }
    data_type = attr->sparse_tensor().values().data_type();
    expected_value_case = TypeProto::kSparseTensorType;
  } else {
    fail_type_inference("Attribute expected to have tensor or sparse tensor type");
  }

  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case);
}

inline bool hasShape(const TypeProto& type) {
  if (type.has_tensor_type()) {
    return type.tensor_type().has_shape();
  } else if (type.has_sparse_tensor_type()) {
    return type.sparse_tensor_type().has_shape();
  } else if (type.has_sequence_type() && type.sequence_type().has_elem_type()) {
    return hasShape(type.sequence_type().elem_type());
  } else if (type.has_optional_type() && type.optional_type().has_elem_type()) {
    return hasShape(type.optional_type().elem_type());
  }
  return false;
}

template <typename Context>
inline bool hasInputShape(Context& ctx, size_t n) {
  return ctx.getNumInputs() > static_cast<size_t>(n) && ctx.getInputType(n) && hasShape(*ctx.getInputType(n));
}

template <typename Context>
inline bool hasNInputShapes(Context& ctx, size_t n) {
  for (size_t i = 0; i < n; i++) {
    if (!hasInputShape(ctx, i)) {
      return false;
    }
  }
  return true;
}
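
// Illustrative sketch: a binary op's inference method typically bails out
// until both input shapes are known (getInputShape is defined just below).
//
//   if (!hasNInputShapes(ctx, 2)) {
//     return; // cannot infer the output shape yet
//   }
//   const TensorShapeProto& lhs = getInputShape(ctx, 0);
//   const TensorShapeProto& rhs = getInputShape(ctx, 1);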

inline const TensorShapeProto& getInputShape(InferenceContext& ctx, size_t n) {
  const auto* input_type = ctx.getInputType(n);
  const auto value_case = input_type->value_case();
  if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
    fail_type_inference("Input ", n, " expected to have tensor or sparse tensor type");
  }
  if (value_case == TypeProto::kTensorType) {
    return input_type->tensor_type().shape();
  } else {
    return input_type->sparse_tensor_type().shape();
  }
}

inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size_t n) {
  const auto* input_type = ctx.getInputType(n);

  if (input_type == nullptr) {
    return nullptr;
  }

  const auto value_case = input_type->value_case();
  if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
    fail_type_inference("Input ", n, " expected to have tensor or sparse tensor type");
  }
  if (value_case == TypeProto::kTensorType) {
    return &input_type->tensor_type().shape();
  } else {
    return &input_type->sparse_tensor_type().shape();
  }
}

// Caller must make sure fromDimIndex is strictly less than shape.dim_size()
inline void appendSingleDimCopiedFromInputTypeToOutputType(
    InferenceContext& ctx,
    size_t inputIndex,
    size_t outputIndex,
    size_t fromDimIndex) {
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  auto input_type = ctx.getInputType(inputIndex);
  const auto input_value_case = input_type->value_case();
  if (output_value_case != input_value_case) {
    fail_type_inference(
        "Input: ",
        inputIndex,
        " type: ",
        input_value_case,
        " does not match type of output: ",
        outputIndex,
        " type: ",
        output_value_case);
  }
  if (TypeProto::kTensorType == input_value_case) {
    auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else if (TypeProto::kSparseTensorType == input_value_case) {
    auto* dim = output_type->mutable_sparse_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else {
    fail_type_inference(
        "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type");
  }
}

inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
  const auto from_type_case = from_type->value_case();
  const auto to_type_case = to_type->value_case();
  if (from_type_case != to_type_case) {
    fail_shape_inference("Mismatch between source and target type. Source=", from_type_case, " Target=", to_type_case);
  }

  if (TypeProto::kTensorType == from_type_case || TypeProto::kSparseTensorType == from_type_case) {
    // If the input shape is "unknown", the corresponding output shape should be "unknown" too.
    // The way to make an output shape unknown is to not assign it any value.
    if (hasShape(*from_type)) {
      if (TypeProto::kTensorType == from_type_case) {
        *to_type->mutable_tensor_type()->mutable_shape() = from_type->tensor_type().shape();
      } else {
        *to_type->mutable_sparse_tensor_type()->mutable_shape() = from_type->sparse_tensor_type().shape();
      }
    }
  } else if (TypeProto::kSequenceType == from_type_case) {
    propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
  } else if (TypeProto::kOptionalType == from_type_case) {
    propagateShape(&from_type->optional_type().elem_type(), to_type->mutable_optional_type()->mutable_elem_type());
  } else if (TypeProto::kMapType == from_type_case) {
    propagateShape(&from_type->map_type().value_type(), to_type->mutable_map_type()->mutable_value_type());
  } else {
    fail_shape_inference("Unsupported Source/Target type=", from_type_case);
  }
}

inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto output_type = ctx.getOutputType(outputIndex);
  auto input_type = ctx.getInputType(inputIndex);

  propagateShape(input_type, output_type);
}

inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  if (!hasNInputShapes(ctx, 1)) {
    return;
  }
  propagateShapeFromInputToOutput(ctx, 0, 0);
}
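
// Illustrative sketch (assumes the schema-registration machinery from
// onnx/defs/schema.h; the op and elided details are examples only): unary
// elementwise operators typically register the helper above directly as
// their inference function.
//
//   ONNX_OPERATOR_SET_SCHEMA(
//       Relu, 14,
//       OpSchema()
//           ...
//           .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));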

inline void
updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type) {
  auto output_type = ctx.getOutputType(outputIndex);
  if (output_type == nullptr) {
    fail_type_inference("Output ", outputIndex, " is null");
  }
  if (output_type->value_case() == expected_type || output_type->value_case() == TypeProto::VALUE_NOT_SET) {
    setTensorElementType(elemType, expected_type, *output_type);
  } else {
    // This is not expected to happen
    fail_type_inference(
        "Output ", outputIndex, " expected to have: ", expected_type, " or UNDEFINED. Got: ", output_type->value_case());
  }
}

inline void updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType) {
  updateOutputElemType(ctx, outputIndex, elemType, TypeProto::kTensorType);
}

// Infer type of an output from the value of a specified attribute, which is
// expected to have a valid value representing a TensorProto_DataType.
inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase expected_type,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if (nullptr == attr_proto) { // attribute not present
    if (default_value != TensorProto::UNDEFINED) {
      updateOutputElemType(ctx, outputIndex, default_value, expected_type);
      return;
    } else {
      fail_type_inference("Value of attribute ", attributeName, " not specified");
    }
  }
  if (!attr_proto->has_i()) {
    fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type.");
  }
  auto attr_value = attr_proto->i();
  auto elem_type = static_cast<TensorProto_DataType>(attr_value);
  if (!TensorProto_DataType_IsValid(elem_type)) {
    fail_type_inference("Attribute ", attributeName, " does not specify a valid type.");
  }
  updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
}

inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  propagateElemTypeFromAttributeToOutput(ctx, attributeName, outputIndex, TypeProto::kTensorType, default_value);
}
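
// Illustrative sketch: the pattern used by Cast-like operators, whose "to"
// attribute holds an integer TensorProto_DataType value that determines the
// output element type.
//
//   propagateElemTypeFromAttributeToOutput(ctx, "to", 0);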

inline TensorShapeProto* getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto& type) {
  if (value_case == TypeProto::kTensorType) {
    return type.mutable_tensor_type()->mutable_shape();
  } else if (value_case == TypeProto::kSparseTensorType) {
    return type.mutable_sparse_tensor_type()->mutable_shape();
  }
  return nullptr;
}

inline TensorShapeProto*
getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto output_type = ctx.getOutputType(n);
  if (output_type == nullptr) {
    fail_type_inference("Output ", n, " expected to have tensor or sparse tensor type");
  }
  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
    return getTensorMutableShape(output_value_case, *output_type);
  } else if (output_value_case == TypeProto::VALUE_NOT_SET) {
    return getTensorMutableShape(default_type, *output_type);
  } else {
    fail_type_inference("Output ", n, " expected to have tensor or sparse tensor type");
  }
}

inline void appendDim(TensorShapeProto* shape, int64_t dim_value) {
  shape->add_dim()->set_dim_value(dim_value);
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    const TensorShapeProto& shape,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  *output_shape = shape;
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    const TensorProto& tensorProto,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  for (auto d : tensorProto.dims()) {
    auto* dim = output_shape->add_dim();
    dim->set_dim_value(d);
  }
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    std::initializer_list<TensorShapeProto::Dimension> dims,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  for (auto& d : dims) {
    auto* dim = output_shape->add_dim();
    *dim = d;
  }
}
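
// Illustrative sketch: building an output shape dimension by dimension. Here
// the output is rank 2 with a static first dimension and a symbolic second
// dimension ("N" is an example symbol).
//
//   TensorShapeProto shape;
//   appendDim(&shape, 64);
//   shape.add_dim()->set_dim_param("N");
//   updateOutputShape(ctx, 0, shape);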

// Gets the shape input by first checking the initializer and then the propagated symbolic data.
// If neither is available, rank inference is attempted.
// If any of the above succeeds, `found` is set to `true`; otherwise `found` is
// set to `false` and the returned TensorShapeProto is meaningless.
TensorShapeProto getShapeInput(InferenceContext& ctx, size_t input_index, bool& found);

// Infer shape of an output from the value of a specified attribute, which is
// expected to be a list of integers specifying a valid shape.
inline void propagateShapeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
      (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
    fail_shape_inference("Attribute ", attributeName, " should specify a shape");
  }
  auto& int_list = attr_proto->ints();
  TensorShapeProto shape;
  for (auto dim_size : int_list) {
    if (dim_size < 0) {
      fail_shape_inference("Negative values are not allowed in a shape specification");
    }
    shape.add_dim()->set_dim_value(dim_size);
  }

  updateOutputShape(ctx, outputIndex, shape, default_type);
}

inline void multidirectionalBroadcastShapeInference(
    const std::vector<const TensorShapeProto*>& shapes,
    TensorShapeProto& resultShape) {
  int result_shape_size = 0;
  // Get the result shape size.
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (shapes[i]->dim_size() > result_shape_size) {
      result_shape_size = shapes[i]->dim_size();
    }
  }

  for (int i = 0; i < result_shape_size; ++i) {
    int64_t dim_value = 1;
    TensorShapeProto_Dimension symbolic_dim;
    int num_symbolic_dims = 0;
    for (size_t j = 0; j < shapes.size(); ++j) {
      if (i < result_shape_size - shapes[j]->dim_size()) {
        // Shape j is implicitly padded with 1 at dimension i.
        continue;
      }

      auto dim_i_j = shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
      if (dim_i_j.has_dim_value()) {
        if (dim_i_j.dim_value() != 1) {
          if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
            fail_shape_inference("Incompatible dimensions");
          } else {
            dim_value = dim_i_j.dim_value();
          }
        }
      } else {
        if (num_symbolic_dims == 0) {
          symbolic_dim = dim_i_j;
          ++num_symbolic_dims;
        } else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
          ++num_symbolic_dims;
        }
      }
    }

    if (dim_value != 1 || num_symbolic_dims == 0) {
      resultShape.add_dim()->set_dim_value(dim_value);
    } else if (num_symbolic_dims == 1) {
      *resultShape.add_dim() = symbolic_dim;
    } else {
      resultShape.add_dim();
    }
  }
}

inline void bidirectionalBroadcastShapeInference(
    const TensorShapeProto& shapeL,
    const TensorShapeProto& shapeR,
    TensorShapeProto& resultShape) {
  std::vector<const TensorShapeProto*> shapes;
  shapes.push_back(&shapeL);
  shapes.push_back(&shapeR);
  multidirectionalBroadcastShapeInference(shapes, resultShape);
}
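
// Illustrative sketch: numpy-style broadcasting of (2, 3, 1) against (3, 4).
// The shorter shape is implicitly padded with leading 1s, so the result is
// (2, 3, 4).
//
//   TensorShapeProto lhs, rhs, result;
//   appendDim(&lhs, 2); appendDim(&lhs, 3); appendDim(&lhs, 1);
//   appendDim(&rhs, 3); appendDim(&rhs, 4);
//   bidirectionalBroadcastShapeInference(lhs, rhs, result);
//   // result: (2, 3, 4)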

/*
Merge the dimension information from two TensorShapeProto_Dimension instances.
Values are merged into target from source.
If target has no dimension information, copy from source.
If source has no dimension information, ignore source.
If both have dimension information:
 - Prefer values over params. If both have values, values must match.
 - Prefer target param over source param if mismatched.
Fail if there are mismatches in dimension values.
Currently, there is no way to refine/update dimension information for the
source from information available in the target.
*/
inline void mergeInDimensionInfo(
    const TensorShapeProto_Dimension& source_dim,
    TensorShapeProto_Dimension& target_dim,
    int dim_index) {
  // If source has a value, merge it into target;
  // else if target has a value, preserve it;
  // else merge the params.
  if (source_dim.has_dim_value()) {
    auto source_value = source_dim.dim_value();
    if (target_dim.has_dim_value()) {
      auto target_value = target_dim.dim_value();
      if (target_value != source_value) {
        fail_shape_inference(
            "Can't merge shape info. "
            "Both source and target dimension have values but they differ. Source=",
            source_value,
            " Target=",
            target_value,
            " Dimension=",
            dim_index);
      }
    } else {
      target_dim.set_dim_value(source_value);
    }
  } else if (target_dim.has_dim_value()) {
    // if target has a value we preserve it, so do nothing
  } else if (target_dim.has_dim_param()) {
    // prefer target param over source
  } else if (source_dim.has_dim_param()) {
    target_dim.set_dim_param(source_dim.dim_param());
  }
}

void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

/*
Merge the shape information from two TypeProto_Tensor instances.
Values are merged into target from source.
If target has no shape information, copy from source.
If source has no shape information, ignore source.
If both have shape information:
- merge each TensorShapeProto_Dimension separately.
- Prefer values over params. If both have values, values must match.
- Prefer target param over source param if mismatched.
Fail if there are mismatches in number of dimensions or dimension values.
*/
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target);

void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target);

// Return a copy of a type, with a specified dimension removed from its shape.
inline TypeProto RemoveIthDimensionFromShape(const TypeProto& proto, int removed_dim) {
  TypeProto t(proto);
  auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
  mutable_shape->clear_dim();

  const auto& dims = proto.tensor_type().shape().dim();

  for (int j = 0, end = dims.size(); j < end; ++j) {
    if (j != removed_dim)
      (*mutable_shape->add_dim()) = dims.Get(j);
  }

  return t;
}

// Return a copy of a type, with the specified number of dimensions removed from
// the beginning of its shape.
inline TypeProto RemoveDimensionsFromShape(const TypeProto& proto, int num_dimensions) {
  TypeProto t(proto);
  auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
  mutable_shape->clear_dim();

  const auto& dims = proto.tensor_type().shape().dim();

  // skip the first num_dimensions
  for (int j = num_dimensions, end = dims.size(); j < end; ++j) {
    (*mutable_shape->add_dim()) = dims.Get(j);
  }

  return t;
}

// copied from GSL:
// https://github.com/Microsoft/GSL/blob/master/include/gsl/gsl_util
template <class T, class U>
static constexpr T narrow_cast(U&& u) noexcept {
  return static_cast<T>(std::forward<U>(u));
}

inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expected_rank) {
  // We check the rank only if it is known for the input:
  if (hasInputShape(ctx, input_index)) {
    auto rank = getInputShape(ctx, input_index).dim_size();
    if (rank != expected_rank) {
      fail_shape_inference("Input ", input_index, " expected to have rank ", expected_rank, " but has rank ", rank);
    }
  }
}

// Unification (between dimensions and/or shapes) is at the heart of
// shape inference. The current inference algorithm can check the input
// shapes/dimensions of a node and update the output shapes/dimensions. It
// cannot currently update input shapes and dimensions (even though in some
// contexts this inference is possible). Hence, we have the variants below to
// support "const" and "mutable" dimensions/shapes in unification.

inline void checkDimEquality(int64_t value1, int64_t value2) {
  if (value1 != value2) {
    fail_shape_inference("Dimension mismatch in unification between ", value1, " and ", value2);
  }
}

inline void unifyDim(const Dim& dim1, const Dim& dim2) {
  if (dim1.has_dim_value() && dim2.has_dim_value())
    checkDimEquality(dim1.dim_value(), dim2.dim_value());
}

// TODO: The functionality of unifyDim is similar to that of
// mergeInDimensionInfo. However, the error messages are different. Leaving this
// duplication in place to preserve the error message content.
inline void unifyDim(const Dim& source_dim, Dim& target_dim) {
  if (source_dim.has_dim_value()) {
    auto source_value = source_dim.dim_value();
    if (target_dim.has_dim_value()) {
      auto target_value = target_dim.dim_value();
      checkDimEquality(source_value, target_value);
    } else {
      target_dim.set_dim_value(source_value);
    }
  } else if (target_dim.has_dim_value()) {
    // if target has a value we preserve it;
    // we cannot set the source dim value.
  } else if (target_dim.has_dim_param()) {
    // prefer target param over source;
    // we cannot currently unify the dim_params.
  } else if (source_dim.has_dim_param()) {
    target_dim.set_dim_param(source_dim.dim_param());
  }
}

inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_index, Dim& dim) {
  // We unify the dimensions only if a shape is available for the specified input:
  if (hasInputShape(ctx, input_index)) {
    auto& input_shape = getInputShape(ctx, input_index);
    // This shape is expected to have rank > dim_index:
    if (input_shape.dim_size() <= dim_index) {
      fail_shape_inference(
          "Input ", input_index, " expected to have rank > ", dim_index, " but has rank ", input_shape.dim_size());
    }
    const Dim& input_dim = input_shape.dim(dim_index);
    // Now, unify dim and input_dim:
    unifyDim(input_dim, dim);
  }
}

// unifyDim: unifies a dimension with a constant value. If the dimension
// already has a value, we check the new value for equality with the old value.
inline void unifyDim(Dim& dim, int64_t value) {
  if (dim.has_dim_value()) {
    checkDimEquality(dim.dim_value(), value);
  } else
    dim.set_dim_value(value);
}
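
// Illustrative sketch: a typical use of unification in a matmul-like
// inference method. The shared inner dimension of two rank-2 inputs is
// unified, failing if both dimensions are known and disagree.
//
//   Dim inner_dim;
//   unifyInputDim(ctx, 0, 1, inner_dim); // columns of input 0
//   unifyInputDim(ctx, 1, 0, inner_dim); // rows of input 1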

// target-shape = Union (target-shape, source-shape)
// Example 1: same rank, different dimensions
//   input1 shape: (2, 3, 4, 'x')
//   input2 shape: (2, 'y', 5, 'x')
//   output shape: (2, None, None, 'x')
// Example 2: different rank
//   input1 shape: (2, 3, 4, 'x')
//   input2 shape: (2, 3, 4)
//   output shape: None
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

// target-type = Union (target-type, source-type)
// target and source are required to have the same type.
// Example 1: same tensor type, different shape
//   source: tensor elem_type: int64, shape: (2, 3, 4, 'x')
//   target: tensor elem_type: int64, shape: (2, 'y', 5, 'x')
//   output: tensor elem_type: int64, shape: (2, None, None, 'x')
// Example 2: same sequence type, different shape
//   source: sequence of tensor, elem_type: float, shape: (2, 3, 4)
//   target: sequence of tensor, elem_type: float, shape: None
//   output: sequence of tensor, elem_type: float, shape: None
void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type);

} // namespace ONNX_NAMESPACE