1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #pragma once |
6 | |
7 | #include <functional> |
8 | #include "onnx/defs/data_type_utils.h" |
9 | #include "onnx/proto_utils.h" |
10 | #include "onnx/string_utils.h" |
11 | |
12 | namespace ONNX_NAMESPACE { |
13 | |
14 | using Dim = TensorShapeProto_Dimension; |
15 | |
// Options controlling how type/shape inference behaves.
struct ShapeInferenceOptions {
  // Whether input/output type equality is validated during inference.
  bool check_type{false};
  // 1: Will throw any node level shape infer errors
  // 0: Won't throw node-level shape infer errors, but other errors
  //    (like merging an existing shape with an inferred one) are thrown
  int error_mode{0};
  // Enables data propagation for limited operators
  // to perform shape computation
  bool enable_data_propagation{false};

  ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
      : check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val) {}
};
29 | |
// Maintains a SymbolTable for symbolic shape inference.
// Implementations track the set of dim_param symbols already in use so that
// freshly generated symbols never collide with existing ones.
class SymbolTable {
 public:
  // Adds existing symbols from a main graph or subgraph
  virtual void addFromGraph(const GraphProto& g) = 0;
  // Creates a new symbol, prefixed with symbol_prefix, that does not
  // duplicate any existing symbol in the table.
  virtual std::string createNew(const std::string& symbol_prefix) = 0;
  virtual ~SymbolTable() = default;
};
39 | |
// Interface for running type/shape inference over a graph (e.g. the subgraph
// held by a control-flow node's attribute).
class GraphInferencer {
 public:
  // Perform inferencing on the graph contained in GraphInferencer.
  // inputTypes:  types of the graph inputs (one entry per input).
  // inputData:   statically known input values, where available.
  // Returns the graph output types post-inferencing.
  virtual std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) = 0;
  virtual ~GraphInferencer() = default;
};
49 | |
// Exception class used for handling errors in type and shape inference.
// The message can be enriched after the throw site via AppendContext
// (e.g. to identify the node or graph being processed).

class InferenceError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  InferenceError(const std::string& message) : std::runtime_error(message) {}

  // Returns the expanded message when context has been appended;
  // otherwise falls back to the base runtime_error message.
  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  // Appends context to the message reported by what(). Each call rebuilds
  // the expanded message from the original base message, so only the most
  // recently appended context is retained.
  void AppendContext(const std::string& context) {
    expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
  }

 private:
  // Base message plus appended context; empty until AppendContext is called.
  std::string expanded_message_;
};
72 | |
73 | #define fail_type_inference(...) \ |
74 | ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__))); |
75 | |
76 | #define fail_shape_inference(...) \ |
77 | ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__))); |
78 | |
// Context handed to an operator's type/shape inference function. Exposes the
// node's attributes, the input types, any statically known input values, and
// the mutable output types the inference function is expected to fill in.
struct InferenceContext {
  // Returns the attribute named `name` if present on the node, else nullptr.
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  // Number of formal inputs (including absent optional ones).
  virtual size_t getNumInputs() const = 0;
  // Type of input `index`, or nullptr when unknown/missing.
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual bool hasInput(size_t index) const {
    // The default implementation below is used for backward-compatibility
    // for implementations of InferenceContext that don't provide an explicit
    // implementation. This works for normal usage, but may be imprecise in
    // the edge-case where an input is supplied but has no known type.
    // However, inference-methods work only under the assumption that the
    // input-types of all inputs are known.
    return ((index < getNumInputs()) && (getInputType(index) != nullptr));
  }
  // Statically known value of input `index`, or nullptr if not known.
  virtual const TensorProto* getInputData(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  // Mutable output type, to be populated by the inference function.
  virtual TypeProto* getOutputType(size_t index) = 0;
  // Inferencer for the graph stored in the graph-valued attribute
  // `attribute_name` (used by control-flow operators).
  virtual GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) = 0;
  virtual ~InferenceContext() {}
  // Statically known sparse value of input `index`, or nullptr if not known.
  virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
  // Gets the shape inputs computed by partial data propagation.
  virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
};
101 | |
// We use data propagation to perform partial evaluation of the model, to compute statically
// known information about tensor values. It is intended to improve the precision of shape
// inference. We reuse TensorShapeProto to represent the statically known values. One
// limitation of this is that TensorShapeProto can represent only integer values.
// As an example, data-propagation is intended to handle code-fragments like below:
//   shape = Shape(X)
//   batchsize = Slice(shape, [0], [1])
//   newshape = Concat (batchsize, [1024, 1024])
//   Z = Reshape(Y, newshape)
// If the shape of X is statically known, then data-propagation should be able to determine
// the value of newshape, as well as the shape of Z.
struct DataPropagationContext {
  // Returns the attribute named `name` if present on the node, else nullptr.
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  virtual const TypeProto* getOutputType(size_t index) const = 0;
  virtual ~DataPropagationContext() {}
  // Statically known integer values of input `index`, encoded as a
  // TensorShapeProto (see comment above), or nullptr if not known.
  virtual const TensorShapeProto* getInputData(size_t index) = 0;
  // Records the statically computed value of output `index`.
  virtual void addOutputData(size_t index, TensorShapeProto&& tp) = 0;
};
123 | |
124 | using InferenceFunction = std::function<void(InferenceContext&)>; |
125 | using DataPropagationFunction = std::function<void(DataPropagationContext&)>; |
126 | |
127 | // This no-op inference function is used for operators without an |
128 | // inference implementation. |
129 | inline void dummyInferenceFunction(InferenceContext&){}; |
130 | |
131 | // This no-op data propagation function is used for operators without a defined data propagator |
132 | inline void dummyDataPropagationFunction(DataPropagationContext&){}; |
133 | |
134 | template <typename T> |
135 | inline bool getRepeatedAttribute(InferenceContext& ctx, std::string attr_name, std::vector<T>& values) { |
136 | const auto* attr = ctx.getAttribute(attr_name); |
137 | if (attr) { |
138 | values = RetrieveValues<T>(*attr); |
139 | return true; |
140 | } else { |
141 | return false; |
142 | } |
143 | } |
144 | |
145 | inline int64_t getAttribute(InferenceContext& ctx, const std::string& attributeName, int64_t defaultValue) { |
146 | auto attr_proto = ctx.getAttribute(attributeName); |
147 | if ((nullptr != attr_proto) && attr_proto->has_i()) |
148 | return attr_proto->i(); |
149 | return defaultValue; |
150 | } |
151 | |
152 | inline int64_t getAttribute(DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) { |
153 | auto attr_proto = ctx.getAttribute(attributeName); |
154 | if ((nullptr != attr_proto) && attr_proto->has_i()) |
155 | return attr_proto->i(); |
156 | return defaultValue; |
157 | } |
158 | |
159 | inline std::string |
160 | getAttribute(InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) { |
161 | auto attr_proto = ctx.getAttribute(attributeName); |
162 | if ((nullptr != attr_proto) && attr_proto->has_s()) |
163 | return attr_proto->s(); |
164 | return defaultValue; |
165 | } |
166 | |
167 | inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, TensorShapeProto::Dimension dim2) { |
168 | TensorShapeProto::Dimension result; |
169 | if (dim1.has_dim_value() && dim2.has_dim_value()) { |
170 | result.set_dim_value(dim1.dim_value() * dim2.dim_value()); |
171 | } else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) { |
172 | return dim2; |
173 | } else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) { |
174 | return dim1; |
175 | } |
176 | return result; |
177 | } |
178 | |
179 | inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, int64_t dim2) { |
180 | TensorShapeProto::Dimension result; |
181 | if (dim1.has_dim_value()) { |
182 | result.set_dim_value(dim1.dim_value() * dim2); |
183 | } else if (dim2 == 1) { |
184 | return dim1; |
185 | } |
186 | return result; |
187 | } |
188 | |
189 | inline TensorShapeProto::Dimension operator/(TensorShapeProto::Dimension dim1, int64_t dim2) { |
190 | TensorShapeProto::Dimension result; |
191 | if (dim1.has_dim_value()) { |
192 | result.set_dim_value(dim1.dim_value() / dim2); |
193 | } else if (dim2 == 1) { |
194 | return dim1; |
195 | } |
196 | return result; |
197 | } |
198 | |
199 | // if from >= upto_exclusive, return 1. |
200 | // Caller must make sure upto_exclusive is less than or equal to shape.size() |
201 | // Caller must make sure from>=0 |
202 | inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) { |
203 | TensorShapeProto::Dimension dim; |
204 | dim.set_dim_value(1); |
205 | for (int i = from; i < upto_exclusive; ++i) { |
206 | dim = dim * shape.dim(i); |
207 | } |
208 | return dim; |
209 | } |
210 | |
211 | inline int32_t getTensorElementType(const TypeProto& type) { |
212 | int32_t result = TensorProto::UNDEFINED; |
213 | const auto value_case = type.value_case(); |
214 | if (value_case == TypeProto::kTensorType) { |
215 | result = type.tensor_type().elem_type(); |
216 | } else if (value_case == TypeProto::kSparseTensorType) { |
217 | result = type.sparse_tensor_type().elem_type(); |
218 | } |
219 | return result; |
220 | } |
221 | |
222 | inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto& type) { |
223 | if (value_case == TypeProto::kTensorType) { |
224 | type.mutable_tensor_type()->set_elem_type(elem_type); |
225 | } else if (value_case == TypeProto::kSparseTensorType) { |
226 | type.mutable_sparse_tensor_type()->set_elem_type(elem_type); |
227 | } |
228 | } |
229 | |
230 | void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type); |
231 | |
232 | void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex); |
233 | |
234 | void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex); |
235 | |
236 | inline void propagateElemTypeFromDtypeToOutput( |
237 | InferenceContext& ctx, |
238 | const int data_type, |
239 | size_t outputIndex, |
240 | TypeProto::ValueCase expected_value_case) { |
241 | const auto attribute_tensor_datatype = data_type; |
242 | auto output_type = ctx.getOutputType(outputIndex); |
243 | const auto output_value_case = output_type->value_case(); |
244 | if (output_value_case == TypeProto::VALUE_NOT_SET || output_value_case == expected_value_case) { |
245 | setTensorElementType(attribute_tensor_datatype, expected_value_case, *output_type); |
246 | } else { |
247 | // This is not expected to happen |
248 | fail_type_inference( |
249 | "Output " , outputIndex, " expected to have: " , expected_value_case, " or UNDEFINED. Got: " , output_value_case); |
250 | } |
251 | } |
252 | |
// Convenience overload: defaults the expected output kind to dense tensor.
inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const int data_type, size_t outputIndex) {
  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, TypeProto::kTensorType);
}
256 | |
257 | inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const AttributeProto* attr, size_t outputIndex) { |
258 | int32_t data_type = TensorProto::UNDEFINED; |
259 | TypeProto::ValueCase expected_value_case = TypeProto::VALUE_NOT_SET; |
260 | const auto attr_type = attr->type(); |
261 | if (attr_type == AttributeProto::TENSOR) { |
262 | if (attr->t().dims().size() != 1) { |
263 | fail_type_inference("Attribute expected to have a one-dim tensor" ); |
264 | } |
265 | data_type = attr->t().data_type(); |
266 | expected_value_case = TypeProto::kTensorType; |
267 | } else if (attr_type == AttributeProto::SPARSE_TENSOR) { |
268 | if (attr->sparse_tensor().dims().size() != 1) { |
269 | fail_type_inference("Attribute expected to have a one-dim sparse tensor" ); |
270 | } |
271 | data_type = attr->sparse_tensor().values().data_type(); |
272 | expected_value_case = TypeProto::kSparseTensorType; |
273 | } else { |
274 | fail_type_inference("Attribute expected to have tensor or sparse tensor type" ); |
275 | } |
276 | |
277 | propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case); |
278 | } |
279 | |
280 | inline bool hasShape(const TypeProto& type) { |
281 | if (type.has_tensor_type()) { |
282 | return type.tensor_type().has_shape(); |
283 | } else if (type.has_sparse_tensor_type()) { |
284 | return type.sparse_tensor_type().has_shape(); |
285 | } else if (type.has_sequence_type() && type.sequence_type().has_elem_type()) { |
286 | return hasShape(type.sequence_type().elem_type()); |
287 | } else if (type.has_optional_type() && type.optional_type().has_elem_type()) { |
288 | return hasShape(type.optional_type().elem_type()); |
289 | } |
290 | return false; |
291 | } |
292 | |
// Returns true iff input n exists, has a known type, and that type carries
// shape information (see hasShape).
template <typename Context>
inline bool hasInputShape(Context& ctx, size_t n) {
  return ctx.getNumInputs() > static_cast<size_t>(n) && ctx.getInputType(n) && hasShape(*ctx.getInputType(n));
}
297 | |
// Returns true iff each of the first n inputs has shape information.
template <typename Context>
inline bool hasNInputShapes(Context& ctx, size_t n) {
  bool all_present = true;
  for (size_t i = 0; all_present && i < n; ++i) {
    all_present = hasInputShape(ctx, i);
  }
  return all_present;
}
307 | |
308 | inline const TensorShapeProto& getInputShape(InferenceContext& ctx, size_t n) { |
309 | const auto* input_type = ctx.getInputType(n); |
310 | const auto value_case = input_type->value_case(); |
311 | if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { |
312 | fail_type_inference("Attribute expected to have tensor or sparse tensor type" ); |
313 | } |
314 | if (value_case == TypeProto::kTensorType) { |
315 | return input_type->tensor_type().shape(); |
316 | } else { |
317 | return input_type->sparse_tensor_type().shape(); |
318 | } |
319 | } |
320 | |
321 | inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size_t n) { |
322 | const auto* input_type = ctx.getInputType(n); |
323 | |
324 | if (input_type == nullptr) { |
325 | return nullptr; |
326 | } |
327 | |
328 | const auto value_case = input_type->value_case(); |
329 | if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) { |
330 | fail_type_inference("Attribute expected to have tensor or sparse tensor type" ); |
331 | } |
332 | if (value_case == TypeProto::kTensorType) { |
333 | return &input_type->tensor_type().shape(); |
334 | } else { |
335 | return &input_type->sparse_tensor_type().shape(); |
336 | } |
337 | } |
338 | |
// Appends one dimension, copied from input `inputIndex`'s shape at position
// `fromDimIndex`, to the end of output `outputIndex`'s shape. The input and
// output must have the same type case (both tensor, or both sparse tensor).
// Caller must make sure fromDimIndex is strictly less than shape.dim_size()
inline void appendSingleDimCopiedFromInputTypeToOutputType(
    InferenceContext& ctx,
    size_t inputIndex,
    size_t outputIndex,
    size_t fromDimIndex) {
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  auto input_type = ctx.getInputType(inputIndex);
  const auto input_value_case = input_type->value_case();
  // Input and output must agree on dense vs sparse representation.
  if (output_value_case != input_value_case) {
    fail_type_inference(
        "Input: ",
        inputIndex,
        " type: ",
        input_value_case,
        " does not match type of output: ",
        outputIndex,
        "type: ",
        output_value_case);
  }
  if (TypeProto::kTensorType == input_value_case) {
    auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else if (TypeProto::kSparseTensorType == input_value_case) {
    auto* dim = output_type->mutable_sparse_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else {
    fail_type_inference(
        "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type");
  }
}
371 | |
// Copies shape information from `from_type` to `to_type`. Both must already
// have the same type case. For tensor/sparse-tensor types the shape proto is
// copied wholesale (when present); for sequence/optional/map types the copy
// recurses into the element/value type.
inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
  const auto from_type_case = from_type->value_case();
  const auto to_type_case = to_type->value_case();
  if (from_type_case != to_type_case) {
    fail_shape_inference("Mismatch between source and target type. Source=", from_type_case, " Target=", to_type_case);
  }

  if (TypeProto::kTensorType == from_type_case || TypeProto::kSparseTensorType == from_type_case) {
    // If the input shape is "unknown", the corresponding output shape should
    // be "unknown" too. The way to make an output shape unknown is to not
    // assign it any value.
    if (hasShape(*from_type)) {
      if (TypeProto::kTensorType == from_type_case) {
        *to_type->mutable_tensor_type()->mutable_shape() = from_type->tensor_type().shape();
      } else {
        *to_type->mutable_sparse_tensor_type()->mutable_shape() = from_type->sparse_tensor_type().shape();
      }
    }
  } else if (TypeProto::kSequenceType == from_type_case) {
    propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
  } else if (TypeProto::kOptionalType == from_type_case) {
    propagateShape(&from_type->optional_type().elem_type(), to_type->mutable_optional_type()->mutable_elem_type());
  } else if (TypeProto::kMapType == from_type_case) {
    propagateShape(&from_type->map_type().value_type(), to_type->mutable_map_type()->mutable_value_type());
  } else {
    fail_shape_inference("Unsupported Source/Target type=", from_type_case);
  }
}
399 | |
400 | inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) { |
401 | auto output_type = ctx.getOutputType(outputIndex); |
402 | auto input_type = ctx.getInputType(inputIndex); |
403 | |
404 | propagateShape(input_type, output_type); |
405 | } |
406 | |
407 | inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) { |
408 | propagateElemTypeFromInputToOutput(ctx, 0, 0); |
409 | if (!hasNInputShapes(ctx, 1)) { |
410 | return; |
411 | } |
412 | propagateShapeFromInputToOutput(ctx, 0, 0); |
413 | } |
414 | |
415 | inline void |
416 | updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type) { |
417 | auto output_type = ctx.getOutputType(outputIndex); |
418 | if (output_type == nullptr) { |
419 | fail_type_inference("Output " , outputIndex, " is null" ); |
420 | } |
421 | if (output_type->value_case() == expected_type || output_type->value_case() == TypeProto::VALUE_NOT_SET) { |
422 | setTensorElementType(elemType, expected_type, *output_type); |
423 | } else { |
424 | // This is not expected to happen |
425 | fail_type_inference("Output " , outputIndex, " expected to have tensor or sparse tensor type: " , expected_type); |
426 | } |
427 | } |
428 | |
// Convenience overload: defaults the expected output kind to dense tensor.
inline void updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType) {
  updateOutputElemType(ctx, outputIndex, elemType, TypeProto::kTensorType);
}
432 | |
// Infer type of an output from the value of a specified attribute, which is
// expected to have a valid value representing a TensorProto_DataType.
// When the attribute is absent, `default_value` is used instead; if that is
// also UNDEFINED, inference fails.
inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase expected_type,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if (nullptr == attr_proto) { // attribute not present
    if (default_value != TensorProto::UNDEFINED) {
      updateOutputElemType(ctx, outputIndex, default_value, expected_type);
      return;
    } else {
      // fail_type_inference throws, so control never reaches the code below.
      fail_type_inference("Value of attribute ", attributeName, " not specified");
    }
  }
  if (!attr_proto->has_i()) {
    fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type.");
  }
  auto attr_value = attr_proto->i();
  auto elem_type = static_cast<TensorProto_DataType>(attr_value);
  if (!TensorProto_DataType_IsValid(elem_type)) {
    fail_type_inference("Attribute ", attributeName, " does not specify a valid type.");
  }
  updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
}
460 | |
// Convenience overload: defaults the expected output kind to dense tensor.
inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  propagateElemTypeFromAttributeToOutput(ctx, attributeName, outputIndex, TypeProto::kTensorType, default_value);
}
468 | |
469 | inline TensorShapeProto* getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto& type) { |
470 | if (value_case == TypeProto::kTensorType) { |
471 | return type.mutable_tensor_type()->mutable_shape(); |
472 | } else if (value_case == TypeProto::kSparseTensorType) { |
473 | return type.mutable_tensor_type()->mutable_shape(); |
474 | } |
475 | return nullptr; |
476 | } |
477 | |
478 | inline TensorShapeProto* |
479 | getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) { |
480 | auto output_type = ctx.getOutputType(n); |
481 | if (output_type == nullptr) { |
482 | fail_type_inference("Output " , n, " expected to have tensor or sparse type" ); |
483 | } |
484 | const auto output_value_case = output_type->value_case(); |
485 | if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) { |
486 | return getTensorMutableShape(output_value_case, *output_type); |
487 | } else if (output_value_case == TypeProto::VALUE_NOT_SET) { |
488 | return getTensorMutableShape(default_type, *output_type); |
489 | } else { |
490 | fail_type_inference("Output " , n, " expected to have tensor type" ); |
491 | } |
492 | } |
493 | |
// Appends one statically-known dimension of size `dim_value` to `shape`.
inline void appendDim(TensorShapeProto* shape, int64_t dim_value) {
  shape->add_dim()->set_dim_value(dim_value);
}
497 | |
498 | inline void updateOutputShape( |
499 | InferenceContext& ctx, |
500 | size_t outputIndex, |
501 | const TensorShapeProto& shape, |
502 | TypeProto::ValueCase default_type = TypeProto::kTensorType) { |
503 | auto* output_shape = getOutputShape(ctx, outputIndex, default_type); |
504 | *output_shape = shape; |
505 | } |
506 | |
507 | inline void updateOutputShape( |
508 | InferenceContext& ctx, |
509 | size_t outputIndex, |
510 | const TensorProto& tensorProto, |
511 | TypeProto::ValueCase default_type = TypeProto::kTensorType) { |
512 | auto* output_shape = getOutputShape(ctx, outputIndex, default_type); |
513 | for (auto d : tensorProto.dims()) { |
514 | auto* dim = output_shape->add_dim(); |
515 | dim->set_dim_value(d); |
516 | } |
517 | } |
518 | |
519 | inline void updateOutputShape( |
520 | InferenceContext& ctx, |
521 | size_t outputIndex, |
522 | std::initializer_list<TensorShapeProto::Dimension> dims, |
523 | TypeProto::ValueCase default_type = TypeProto::kTensorType) { |
524 | auto* output_shape = getOutputShape(ctx, outputIndex, default_type); |
525 | for (auto& d : dims) { |
526 | auto* dim = output_shape->add_dim(); |
527 | *dim = d; |
528 | } |
529 | } |
530 | |
531 | // Get shape input by first checking initializer and then propagated symbolic data. |
532 | // If neither is available, try rank inference. |
533 | // When one of above succeeds, `true` is stored in `found`. |
534 | // Otherwise, `false` is stored, which means that returned TensorShapeProto does not make sense. |
535 | TensorShapeProto getShapeInput(InferenceContext& ctx, size_t input_index, bool& found); |
536 | |
// Infer shape of an output from the value of a specified attribute, which is
// expected to be a list of integers specifying a valid shape.
// Fails if the attribute is absent, is not an INTS attribute, or contains a
// negative dimension.
inline void propagateShapeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
      (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
    fail_shape_inference("Attribute ", attributeName, " should specify a shape");
  }
  auto& int_list = attr_proto->ints();
  TensorShapeProto shape;
  for (auto dim_size : int_list) {
    if (dim_size < 0) {
      fail_shape_inference("Negative values are not allowed in a shape specification");
    }
    shape.add_dim()->set_dim_value(dim_size);
  }

  updateOutputShape(ctx, outputIndex, shape, default_type);
}
560 | |
// Computes the multidirectional (numpy-style) broadcast of the given shapes
// into resultShape. Shapes are right-aligned; shorter shapes are implicitly
// padded with 1s on the left. For each result dimension:
//   - any concrete dim != 1 must agree across shapes (else inference fails);
//   - a single distinct symbolic dim (with all other dims == 1) is kept;
//   - multiple distinct symbolic dims yield an unknown (unset) dimension.
inline void multidirectionalBroadcastShapeInference(
    const std::vector<const TensorShapeProto*>& shapes,
    TensorShapeProto& resultShape) {
  int result_shape_size = 0;
  // Get the result shape size: the rank of the result is the maximum rank.
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (shapes[i]->dim_size() > result_shape_size) {
      result_shape_size = shapes[i]->dim_size();
    }
  }

  for (int i = 0; i < result_shape_size; ++i) {
    // Best concrete value seen so far for this dimension (1 == neutral).
    int64_t dim_value = 1;
    // First symbolic (dim_param) dimension seen, if any.
    TensorShapeProto_Dimension symbolic_dim;
    // Count of *distinct* symbolic dimensions seen.
    int num_symbolic_dims = 0;
    for (size_t j = 0; j < shapes.size(); ++j) {
      if (i < result_shape_size - shapes[j]->dim_size()) {
        // Shape j will be filled with 1 at dimension i;
        continue;
      }

      auto dim_i_j = shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
      if (dim_i_j.has_dim_value()) {
        if (dim_i_j.dim_value() != 1) {
          // Concrete non-1 dims must all agree.
          if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
            fail_shape_inference("Incompatible dimensions");
          } else {
            dim_value = dim_i_j.dim_value();
          }
        }
      } else {
        if (num_symbolic_dims == 0) {
          symbolic_dim = dim_i_j;
          ++num_symbolic_dims;
        } else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
          // A different symbol: the result dimension cannot be pinned down.
          ++num_symbolic_dims;
        }
      }
    }

    if (dim_value != 1 || num_symbolic_dims == 0) {
      // A concrete dimension (or all-1s): emit the concrete value.
      resultShape.add_dim()->set_dim_value(dim_value);
    } else if (num_symbolic_dims == 1) {
      // Exactly one symbol and no conflicting concrete value: keep the symbol.
      *resultShape.add_dim() = symbolic_dim;
    } else {
      // Multiple distinct symbols: dimension is unknown (left unset).
      resultShape.add_dim();
    }
  }
}
610 | |
611 | inline void bidirectionalBroadcastShapeInference( |
612 | const TensorShapeProto& shapeL, |
613 | const TensorShapeProto& shapeR, |
614 | TensorShapeProto& resultShape) { |
615 | std::vector<const TensorShapeProto*> shapes; |
616 | shapes.push_back(&shapeL); |
617 | shapes.push_back(&shapeR); |
618 | multidirectionalBroadcastShapeInference(shapes, resultShape); |
619 | } |
620 | |
621 | /* |
622 | Merge the dimension information from two TensorShapeProto_Dimension instances. |
623 | Values are merged into target from source. |
624 | If target has no dimension information, copy from source. |
625 | If source has no dimension information, ignore source. |
626 | If both have dimension information: |
627 | - Prefer values over params. If both have values, values must match. |
628 | - Prefer target param over source param if mismatched. |
629 | Fail if there are mismatches in dimension values. |
630 | Currently, there is no way to refine/update dimension information for the |
631 | source from information available in the target. |
632 | */ |
633 | inline void mergeInDimensionInfo( |
634 | const TensorShapeProto_Dimension& source_dim, |
635 | TensorShapeProto_Dimension& target_dim, |
636 | int dim_index) { |
637 | // if source has value, merge into target |
638 | // else if target has value, preserve it |
639 | // else merge params |
640 | if (source_dim.has_dim_value()) { |
641 | auto source_value = source_dim.dim_value(); |
642 | if (target_dim.has_dim_value()) { |
643 | auto target_value = target_dim.dim_value(); |
644 | if (target_value != source_value) { |
645 | fail_shape_inference( |
646 | "Can't merge shape info. " |
647 | "Both source and target dimension have values but they differ. Source=" , |
648 | source_value, |
649 | " Target=" , |
650 | target_value, |
651 | " Dimension=" , |
652 | dim_index); |
653 | } |
654 | } else { |
655 | target_dim.set_dim_value(source_value); |
656 | } |
657 | } else if (target_dim.has_dim_value()) { |
658 | // if target has a value we preserve it so do nothing |
659 | } else if (target_dim.has_dim_param()) { |
660 | // prefer target param over source |
661 | } else if (source_dim.has_dim_param()) { |
662 | target_dim.set_dim_param(source_dim.dim_param()); |
663 | } |
664 | } |
665 | |
666 | void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type); |
667 | |
668 | void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type); |
669 | |
670 | /* |
671 | Merge the shape information from two TypeProto_Tensor instances. |
672 | Values are merged into target from source. |
673 | If target has no shape information, copy from source. |
674 | If source has no shape information, ignore source. |
675 | If both have shape information: |
676 | - merge each TensorShapeProto_Dimension separately. |
677 | - Prefer values over params. If both have values, values must match. |
678 | - Prefer target param over source param if mismatched. |
679 | Fail if there are mismatches in number of dimensions or dimension values. |
680 | */ |
681 | void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target); |
682 | |
683 | void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target); |
684 | |
685 | // Return a copy of a type, with a specified dimension removed from its shape. |
686 | inline TypeProto RemoveIthDimensionFromShape(const TypeProto& proto, int removed_dim) { |
687 | TypeProto t(proto); |
688 | auto mutable_shape = t.mutable_tensor_type()->mutable_shape(); |
689 | mutable_shape->clear_dim(); |
690 | |
691 | const auto& dims = proto.tensor_type().shape().dim(); |
692 | |
693 | for (int j = 0, end = dims.size(); j < end; ++j) { |
694 | if (j != removed_dim) |
695 | (*mutable_shape->add_dim()) = dims.Get(j); |
696 | } |
697 | |
698 | return t; |
699 | } |
700 | |
701 | // Return a copy of a type, with specified number of dimensions removed from the |
702 | // beginning. |
703 | inline TypeProto RemoveDimensionsFromShape(const TypeProto& proto, int num_dimensions) { |
704 | TypeProto t(proto); |
705 | auto mutable_shape = t.mutable_tensor_type()->mutable_shape(); |
706 | mutable_shape->clear_dim(); |
707 | |
708 | const auto& dims = proto.tensor_type().shape().dim(); |
709 | |
710 | // skip first num_dimensions |
711 | for (int j = num_dimensions, end = dims.size(); j < end; ++j) { |
712 | (*mutable_shape->add_dim()) = dims.Get(j); |
713 | } |
714 | |
715 | return t; |
716 | } |
717 | |
// copied from GSL:
// https://github.com/Microsoft/GSL/blob/master/include/gsl/gsl_util
// Performs an explicit, intentionally-unchecked narrowing conversion
// (e.g. int64_t -> int). The name exists to make such conversions searchable.
template <class T, class U>
static constexpr T narrow_cast(U&& u) noexcept {
  return static_cast<T>(std::forward<U>(u));
}
724 | |
725 | inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expected_rank) { |
726 | // We check the rank only if a rank is known for the input: |
727 | if (hasInputShape(ctx, input_index)) { |
728 | auto rank = getInputShape(ctx, input_index).dim_size(); |
729 | if (rank != expected_rank) { |
730 | fail_shape_inference("Input " , input_index, " expected to have rank " , expected_rank, " but has rank " , rank); |
731 | } |
732 | } |
733 | } |
734 | |
735 | // Unification (between dimensions and/or shapes) is at the heart of |
736 | // shape-inference. The current inference algorithm can check input |
737 | // shapes/dimensions of a node and update the output shapes/dimensions. It |
738 | // cannot currently update input shapes and dimensions (even though in some |
739 | // contexts this inference is possible). Hence, we have the variants below to |
740 | // support "const" and "mutable" dimensions/shapes in unification. |
741 | |
742 | inline void checkDimEquality(int64_t value1, int64_t value2) { |
743 | if (value1 != value2) { |
744 | fail_shape_inference("Dimension mismatch in unification between " , value1, " and " , value2); |
745 | } |
746 | } |
747 | |
748 | inline void unifyDim(const Dim& dim1, const Dim& dim2) { |
749 | if (dim1.has_dim_value() && dim2.has_dim_value()) |
750 | checkDimEquality(dim1.dim_value(), dim2.dim_value()); |
751 | } |
752 | |
753 | // TODO: The functionality of unifyDim is similar to that of |
754 | // mergeInDimensionInfo. However, the error messages are different. Leaving this |
755 | // duplication in-place to preserve error message content. |
756 | inline void unifyDim(const Dim& source_dim, Dim& target_dim) { |
757 | if (source_dim.has_dim_value()) { |
758 | auto source_value = source_dim.dim_value(); |
759 | if (target_dim.has_dim_value()) { |
760 | auto target_value = target_dim.dim_value(); |
761 | checkDimEquality(source_value, target_value); |
762 | } else { |
763 | target_dim.set_dim_value(source_value); |
764 | } |
765 | } else if (target_dim.has_dim_value()) { |
766 | // if target has a value we preserve it. |
767 | // we cannot set source dim value. |
768 | } else if (target_dim.has_dim_param()) { |
769 | // prefer target param over source |
770 | // we cannot currently unify the dim_params |
771 | } else if (source_dim.has_dim_param()) { |
772 | target_dim.set_dim_param(source_dim.dim_param()); |
773 | } |
774 | } |
775 | |
776 | inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_index, Dim& dim) { |
777 | // We unify the dimensions only if it is available for specified input: |
778 | if (hasInputShape(ctx, input_index)) { |
779 | auto& input_shape = getInputShape(ctx, input_index); |
780 | // This shape is expected to have rank > dim_index: |
781 | if (input_shape.dim_size() <= dim_index) { |
782 | fail_shape_inference( |
783 | "Input " , input_index, " expected to have rank >" , dim_index, " but has rank " , input_shape.dim_size()); |
784 | } |
785 | const Dim& input_dim = input_shape.dim(dim_index); |
786 | // Now, unify dim and input_dim: |
787 | unifyDim(input_dim, dim); |
788 | } |
789 | } |
790 | |
791 | // unifyDim: unifies a dimension with a constant value. If the dimension |
792 | // already has a value, we check for equality of new value with old value. |
793 | inline void unifyDim(Dim& dim, int64_t value) { |
794 | if (dim.has_dim_value()) { |
795 | checkDimEquality(dim.dim_value(), value); |
796 | } else |
797 | dim.set_dim_value(value); |
798 | } |
799 | |
// target-shape = Union (target-shape, source_shape)
// Example 1: same rank, different dimensions
// input1 shape: (2, 3, 4, 'x')
// input2 shape: (2, 'y', 5, 'x')
// output shape: (2, None, None, 'x')
// Example 2: different rank
// input1 shape: (2, 3, 4, 'x')
// input2 shape: (2, 3, 4)
// output shape: None
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

// Overload for sparse-tensor types; same union semantics as above.
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);
812 | |
// target-type = Union (target-type, source-type)
// target and source are required to have the same type.
// NOTE(review): presumably a type/elem_type mismatch triggers an inference
// failure in the implementation — confirm against the defining .cc file.
// Example 1: same tensor type, different shape
// source: tensor elem_type: int64, shape: (2, 3, 4, 'x')
// target: tensor elem_type: int64, shape: (2, 'y', 5, 'x')
// output: tensor elem_type: int64, shape: (2, None, None, 'x')
// Example 2: same sequence type, different shape
// source: sequence of tensor, elem_type: float, shape: (2, 3, 4)
// target: sequence of tensor, elem_type: float, shape: None
// output: sequence of tensor, elem_type: float, shape: None
void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type);
824 | |
825 | } // namespace ONNX_NAMESPACE |
826 | |