/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "shape_inference.h"
#include "onnx/defs/tensor_proto_util.h"

namespace ONNX_NAMESPACE {

// Note: for all methods below that propagate type or shape, callers are
// responsible for handling optional inputs/outputs and for ensuring that the
// specified index value is less than NumInputs/NumOutputs.
// Supports tensor and sparse tensor types, including mixed dense/sparse input
// and output.
void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto input_type = ctx.getInputType(inputIndex);
  if (nullptr == input_type) {
    fail_type_inference("Input type was null");
  }

  const auto input_value_case = input_type->value_case();
  if (input_value_case != TypeProto::kTensorType && input_value_case != TypeProto::kSparseTensorType) {
    fail_type_inference(
        "Input ", inputIndex, " expected to have tensor or sparse tensor type. Got: ", input_value_case);
  }

  const auto input_elem_type = getTensorElementType(*input_type);
  if (input_elem_type == TensorProto::UNDEFINED) {
    fail_type_inference("Element type of input ", inputIndex, " unknown");
  }
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
    setTensorElementType(input_elem_type, output_value_case, *output_type);
  } else if (output_value_case == TypeProto::VALUE_NOT_SET) {
    // Assume the output will have the same container kind (dense or sparse) as the input.
    setTensorElementType(input_elem_type, input_value_case, *output_type);
  } else {
    // This is not expected to happen.
    fail_type_inference(
        "Output ", outputIndex, " expected to have tensor or sparse tensor type. Got: ", output_value_case);
  }
}

void propagateElemTypeFromSequenceInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto input_type = ctx.getInputType(inputIndex);
  if (nullptr == input_type || input_type->value_case() != TypeProto::kSequenceType) {
    fail_type_inference("Input ", inputIndex, " expected to have sequence type");
  }
  const auto& input_seq_type = input_type->sequence_type();
  if (!input_seq_type.has_elem_type()) {
    fail_type_inference("Element type of sequence input ", inputIndex, " unknown");
  }

  auto output_type = ctx.getOutputType(outputIndex);
  output_type->mutable_sequence_type()->mutable_elem_type()->CopyFrom(input_seq_type.elem_type());
}

void propagateElemTypeFromOptionalInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto input_type = ctx.getInputType(inputIndex);
  if (nullptr == input_type || input_type->value_case() != TypeProto::kOptionalType) {
    fail_type_inference("Input ", inputIndex, " expected to have optional type");
  }
  const auto& input_opt_type = input_type->optional_type();
  if (!input_opt_type.has_elem_type()) {
    fail_type_inference("Element type of optional input ", inputIndex, " unknown");
  }

  auto output_type = ctx.getOutputType(outputIndex);
  output_type->mutable_optional_type()->mutable_elem_type()->CopyFrom(input_opt_type.elem_type());
}

void propagateElemTypeFromMapInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto input_type = ctx.getInputType(inputIndex);
  if (nullptr == input_type || input_type->value_case() != TypeProto::kMapType) {
    fail_type_inference("Input ", inputIndex, " expected to have map type");
  }
  const auto& input_map_type = input_type->map_type();
  if (!input_map_type.has_key_type()) {
    fail_type_inference("Key type of map input ", inputIndex, " unknown");
  }
  if (!input_map_type.has_value_type()) {
    fail_type_inference("Value type of map input ", inputIndex, " unknown");
  }

  auto output_type = ctx.getOutputType(outputIndex);
  output_type->mutable_map_type()->set_key_type(input_map_type.key_type());
  output_type->mutable_map_type()->mutable_value_type()->CopyFrom(input_map_type.value_type());
}

void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto input_type = ctx.getInputType(inputIndex);
  if (nullptr == input_type) {
    fail_type_inference("Input ", inputIndex, " expected to have type but instead is null");
  }
  const auto input_value_case = input_type->value_case();
  if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) {
    propagateElemTypeFromTensorInputToOutput(ctx, inputIndex, outputIndex);
  } else if (input_value_case == TypeProto::kSequenceType) {
    propagateElemTypeFromSequenceInputToOutput(ctx, inputIndex, outputIndex);
  } else if (input_value_case == TypeProto::kOptionalType) {
    propagateElemTypeFromOptionalInputToOutput(ctx, inputIndex, outputIndex);
  } else if (input_value_case == TypeProto::kMapType) {
    propagateElemTypeFromMapInputToOutput(ctx, inputIndex, outputIndex);
  }
}
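
// Illustrative usage sketch (not part of this translation unit): an op schema
// typically wires the dispatcher above into its inference function. The
// schema name "CustomIdentity" below is hypothetical.
//
//   ONNX_OPERATOR_SET_SCHEMA(
//       CustomIdentity,
//       1,
//       OpSchema()
//           .Input(0, "X", "Input value", "T")
//           .Output(0, "Y", "Output value", "T")
//           .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
//             // Copy the element type of input 0 onto output 0, whatever
//             // container (tensor, sparse tensor, sequence, optional, map)
//             // the input uses.
//             propagateElemTypeFromInputToOutput(ctx, 0, 0);
//           }));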

/*
Merge shape information from a source shape into a target shape.
* Merges each TensorShapeProto_Dimension separately.
* Prefers values over params.
* If both have values, the values must match.
* Prefers the target param over the source param if they mismatch.
* Fails if there is a mismatch in the number of dimensions or in dimension values.
*/
void mergeInShapeInfo(const TensorShapeProto& source, TensorShapeProto& target) {
  auto num_source_dims = source.dim_size();
  auto num_target_dims = target.dim_size();
  if (num_source_dims != num_target_dims) {
    fail_shape_inference(
        "Mismatch between number of source and target dimensions. Source=",
        num_source_dims,
        " Target=",
        num_target_dims);
  }

  auto& source_dims = source.dim();
  auto* target_dims = target.mutable_dim();

  for (int i = 0, end = source_dims.size(); i < end; ++i) {
    auto& source_dim = source_dims.Get(i);
    auto& target_dim = *target_dims->Mutable(i);
    mergeInDimensionInfo(source_dim, target_dim, i);
  }
}
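
// Worked example (illustrative, following the rules above): merging a source
// shape [2, "M", "K"] into a target shape ["N", 5, "K"] yields [2, 5, "K"]:
//   - dim 0: the source value 2 wins over the target param "N";
//   - dim 1: the target value 5 wins over the source param "M";
//   - dim 2: the params match, so "K" is kept.
// Merging a dimension with value 2 into one with value 3 fails shape
// inference instead, because both carry values and the values disagree.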
135
136void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) {
137 if (target_type.has_shape()) {
138 // merge with existing info.
139 mergeInShapeInfo(source_shape, *target_type.mutable_shape());
140 } else {
141 // copy to target
142 (*target_type.mutable_shape()) = source_shape;
143 }
144}
145
146void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
147 if (target_type.has_shape()) {
148 // merge with existing info.
149 mergeInShapeInfo(source_shape, *target_type.mutable_shape());
150 } else {
151 // copy to target
152 (*target_type.mutable_shape()) = source_shape;
153 }
154}
155
156/*
157Merge the shape information from two TypeProto_Tensor instances.
158Values are merged into target from source.
159If target has no shape information, copy from source.
160If source has no shape information, ignore source.
161If both have shape information:
162- merge each TensorShapeProto_Dimension separately.
163- Prefer values over params. If both have values, values must match.
164- Prefer target param over source param if mismatched.
165Fail if there are mismatches in number of dimensions or dimension values.
166*/
167void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target) {
168 if (source.has_shape())
169 mergeInShapeInfo(source.shape(), target);
170}
171
172void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target) {
173 if (source.has_shape())
174 mergeInShapeInfo(source.shape(), target);
175}
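
// Illustrative sketch (hypothetical snippet, not part of this file): merging
// an inferred shape into an output's tensor type inside an inference function.
//
//   TypeProto_Tensor inferred;
//   inferred.mutable_shape()->add_dim()->set_dim_value(4);
//   inferred.mutable_shape()->add_dim()->set_dim_param("N");
//   // Merges [4, "N"] into whatever shape output 0 already carries, or
//   // copies it verbatim if output 0 has no shape yet.
//   mergeInShapeInfo(inferred, *ctx.getOutputType(0)->mutable_tensor_type());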

/// <summary>
/// Utility function for UnionShapeInfoForTensor.
/// Both shapes must be of the same rank.
/// </summary>
/// <param name="source_shape">shape to union into the target</param>
/// <param name="target_shape">destination shape</param>
void UnionShapeInfo(const TensorShapeProto& source_shape, TensorShapeProto& target_shape) {
  auto source_rank = source_shape.dim_size();
  for (int i = 0; i < source_rank; ++i) {
    const auto& source_dim = source_shape.dim(i);
    const auto& target_dim = target_shape.dim(i);
    bool is_dims_conflict = [&]() {
      if (source_dim.has_dim_value()) {
        if (target_dim.has_dim_value() && target_dim.dim_value() == source_dim.dim_value()) {
          return false;
        }
        return true;
      }

      if (source_dim.has_dim_param()) {
        if (target_dim.has_dim_param() && target_dim.dim_param() == source_dim.dim_param()) {
          return false;
        }
        return true;
      }

      return (target_dim.has_dim_value() || target_dim.has_dim_param());
    }();
    if (is_dims_conflict && (target_dim.has_dim_value() || target_dim.has_dim_param())) {
      // The dimensions disagree, so the union keeps neither the value nor the
      // param: the dimension becomes unknown.
      auto dim = target_shape.mutable_dim(i);
      dim->clear_dim_value();
      dim->clear_dim_param();
    }
  }
}
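
// Worked example (illustrative): unioning a source shape [2, "M", 7] into a
// target shape [2, "N", 8] leaves the target as [2, ?, ?]: dim 0 agrees and
// is kept, while dims 1 and 2 conflict, so their dim_value/dim_param are
// cleared and they become unknown dimensions.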

template <typename TENSOR_TYPE>
void UnionShapeInfoForTensor(const TensorShapeProto& source_shape, TENSOR_TYPE& target_type) {
  if (target_type.has_shape()) {
    TensorShapeProto* target_shape = target_type.mutable_shape();

    auto source_rank = source_shape.dim_size();
    auto target_rank = target_shape->dim_size();
    if (source_rank != target_rank) {
      // Even the rank is not common to both shapes, so drop the shape entirely.
      target_type.clear_shape();
      return;
    }

    UnionShapeInfo(source_shape, *target_shape);
  }
}

void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) {
  UnionShapeInfoForTensor(source_shape, target_type);
}

void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
  UnionShapeInfoForTensor(source_shape, target_type);
}

void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {
  if (source_type.value_case() != target_type.value_case()) {
    fail_type_inference("Mismatched type:", " source=", source_type.value_case(), " target=", target_type.value_case());
  }

  const auto target_case = target_type.value_case();
  if (target_case == TypeProto::ValueCase::kTensorType) {
    auto source_elem_type = source_type.tensor_type().elem_type();
    auto target_elem_type = target_type.tensor_type().elem_type();

    if (source_elem_type != target_elem_type) {
      fail_type_inference(
          "Mismatched tensor element type:", " source=", source_elem_type, " target=", target_elem_type);
    }

    UnionShapeInfoForTensor(source_type.tensor_type().shape(), *target_type.mutable_tensor_type());
  } else if (target_case == TypeProto::ValueCase::kSparseTensorType) {
    auto source_elem_type = source_type.sparse_tensor_type().elem_type();
    auto target_elem_type = target_type.sparse_tensor_type().elem_type();
    if (source_elem_type != target_elem_type) {
      fail_type_inference(
          "Mismatched sparse tensor element type:", " source=", source_elem_type, " target=", target_elem_type);
    }

    UnionShapeInfoForTensor(source_type.sparse_tensor_type().shape(), *target_type.mutable_sparse_tensor_type());
  } else if (target_case == TypeProto::ValueCase::kSequenceType) {
    if (!source_type.sequence_type().has_elem_type()) {
      fail_type_inference("source sequence type missing element type.");
    }
    if (!target_type.sequence_type().has_elem_type()) {
      fail_type_inference("target sequence type missing element type.");
    }
    UnionTypeInfo(source_type.sequence_type().elem_type(), *target_type.mutable_sequence_type()->mutable_elem_type());
  } else if (target_case == TypeProto::ValueCase::kOptionalType) {
    if (!source_type.optional_type().has_elem_type()) {
      fail_type_inference("source optional type missing element type.");
    }
    if (!target_type.optional_type().has_elem_type()) {
      fail_type_inference("target optional type missing element type.");
    }
    UnionTypeInfo(source_type.optional_type().elem_type(), *target_type.mutable_optional_type()->mutable_elem_type());
  } else if (target_case == TypeProto::ValueCase::kMapType) {
    if (!source_type.map_type().has_key_type()) {
      fail_type_inference("source map type missing key type.");
    }
    if (!target_type.map_type().has_key_type()) {
      fail_type_inference("target map type missing key type.");
    }
    auto source_key_type = source_type.map_type().key_type();
    auto target_key_type = target_type.map_type().key_type();
    if (source_key_type != target_key_type) {
      fail_type_inference(
          "Mismatched map key type:",
          " source=",
          Utils::DataTypeUtils::ToDataTypeString(source_key_type),
          " target=",
          Utils::DataTypeUtils::ToDataTypeString(target_key_type));
    }

    if (!source_type.map_type().has_value_type()) {
      fail_type_inference("source map type missing value type.");
    }
    if (!target_type.map_type().has_value_type()) {
      fail_type_inference("target map type missing value type.");
    }
    UnionTypeInfo(source_type.map_type().value_type(), *target_type.mutable_map_type()->mutable_value_type());
  }
}
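
// Illustrative sketch (hypothetical snippet): control-flow ops use this union
// to reconcile types inferred along different paths, e.g. the two branches of
// an If node. Here then_type and else_type are assumed TypeProto values
// already inferred for the corresponding branch outputs.
//
//   TypeProto merged = then_type; // start from one branch
//   UnionTypeInfo(else_type, merged); // union in the other; fails on element-type mismatch
//   ctx.getOutputType(0)->CopyFrom(merged);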

// Supports both Tensor and SparseTensor.
// This does not fail if input_type is Tensor and the output type is
// SparseTensor, or the other way around. This is to support mixed cases when
// an op receives sparse input and outputs dense, or vice versa.
// If the output value_case is not set, then the input value_case is
// propagated.
void propagateTensorElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
  if (nullptr == input_type) {
    fail_type_inference("Input type was null");
  }

  int32_t input_elem_type = TensorProto::UNDEFINED;
  const auto input_value_case = input_type->value_case();
  if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) {
    input_elem_type = getTensorElementType(*input_type);
    if (input_elem_type == TensorProto::UNDEFINED) {
      fail_type_inference("Element type of tensor or sparse tensor input was unknown");
    }
  } else {
    fail_type_inference("Input was expected to have tensor or sparse tensor type. Got ", input_value_case);
  }

  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::VALUE_NOT_SET) {
    setTensorElementType(input_elem_type, input_value_case, *output_type);
  } else if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
    const auto output_elem_type = getTensorElementType(*output_type);
    if (output_elem_type != TensorProto::UNDEFINED) {
      if (input_elem_type != output_elem_type) {
        fail_type_inference(
            "Input element type of ", input_elem_type, " does not match existing output type of ", output_elem_type);
      }
    } else {
      setTensorElementType(input_elem_type, output_value_case, *output_type);
    }
  } else {
    // This is not expected to happen.
    fail_type_inference("Output was expected to have tensor or sparse tensor type. Got ", output_value_case);
  }
}
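
// Worked example (illustrative): if the input is a sparse FLOAT tensor and
// the output's value_case is already kTensorType with no element type
// recorded, FLOAT is copied onto the dense output; the container kind
// (sparse vs. dense) is deliberately left alone to support sparse-in /
// dense-out operators.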

void propagateSequenceElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
  if (nullptr == input_type) {
    fail_type_inference("Input type was null");
  }

  if (input_type->value_case() != TypeProto::kSequenceType) {
    fail_type_inference("Input was expected to have sequence type. Got ", input_type->value_case());
  }

  const auto& input_seq_type = input_type->sequence_type();

  if (input_seq_type.has_elem_type()) {
    propagateElemTypeWithValidation(
        &input_seq_type.elem_type(), output_type->mutable_sequence_type()->mutable_elem_type());
  } else {
    fail_type_inference("Element type of sequence input was unknown");
  }
}

void propagateOptionalElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
  if (nullptr == input_type) {
    fail_type_inference("Input type was null");
  }

  if (input_type->value_case() != TypeProto::kOptionalType) {
    fail_type_inference("Input was expected to have optional type. Got ", input_type->value_case());
  }

  const auto& input_opt_type = input_type->optional_type();

  if (input_opt_type.has_elem_type()) {
    propagateElemTypeWithValidation(
        &input_opt_type.elem_type(), output_type->mutable_optional_type()->mutable_elem_type());
  } else {
    fail_type_inference("Element type of optional input was unknown");
  }
}

void propagateMapElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
  if (nullptr == input_type) {
    fail_type_inference("Input type was null");
  }

  if (input_type->value_case() != TypeProto::kMapType) {
    fail_type_inference("Input was expected to have map type. Got ", input_type->value_case());
  }

  const auto& input_map_type = input_type->map_type();

  if (!input_map_type.has_key_type()) {
    fail_type_inference("Key type of map input was unknown");
  }
  if (!input_map_type.has_value_type()) {
    fail_type_inference("Value type of map input was unknown");
  }
  output_type->mutable_map_type()->set_key_type(input_map_type.key_type());
  propagateElemTypeWithValidation(&input_map_type.value_type(), output_type->mutable_map_type()->mutable_value_type());
}

// Propagate the element type from an input type to an output type.
// If an output element type already exists, validate that it matches.
void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
  if (nullptr == input_type) {
    fail_type_inference("Input type was null");
  }

  const auto input_value_case = input_type->value_case();
  if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) {
    propagateTensorElemTypeWithValidation(input_type, output_type);
  } else if (input_value_case == TypeProto::kSequenceType) {
    propagateSequenceElemTypeWithValidation(input_type, output_type);
  } else if (input_value_case == TypeProto::kOptionalType) {
    propagateOptionalElemTypeWithValidation(input_type, output_type);
  } else if (input_value_case == TypeProto::kMapType) {
    propagateMapElemTypeWithValidation(input_type, output_type);
  } else {
    fail_type_inference(
        "Input was expected to have either tensor, sequence, optional, or map type. Got ", input_value_case);
  }
}
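
// Illustrative sketch (hypothetical snippet): unlike
// propagateElemTypeFromInputToOutput, this variant cross-checks any element
// type already recorded on the output instead of overwriting it.
//
//   propagateElemTypeWithValidation(ctx.getInputType(0), ctx.getOutputType(0));
//
// If output 0 already carries FLOAT and input 0 carries INT64, this fails
// type inference rather than silently replacing the output's element type.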

TensorShapeProto getShapeInput(InferenceContext& ctx, size_t input_index, bool& found) {
  TensorShapeProto shape_input;

  // First, check whether the shape input is a constant initializer.
  const TensorProto* shape_initializer = ctx.getInputData(input_index);
  if (shape_initializer) {
    const std::vector<int64_t> shape_data = ParseData<int64_t>(shape_initializer);
    for (const int64_t& e : shape_data) {
      shape_input.add_dim()->set_dim_value(e);
    }
    found = true;
    return shape_input;
  }

  // Then, check whether a symbolic shape was recorded for the input.
  const TensorShapeProto* symbolic_input = ctx.getSymbolicInput(input_index);
  if (symbolic_input) {
    shape_input.CopyFrom(*symbolic_input);
    found = true;
    return shape_input;
  }

  // Finally, try rank inference from the shape input's own (1-D) shape.
  if (hasInputShape(ctx, input_index)) {
    const TensorShapeProto& shape_input_shape = getInputShape(ctx, input_index);
    if (shape_input_shape.dim_size() != 1) {
      fail_shape_inference("shape input must be a 1D tensor");
    }
    if (shape_input_shape.dim(0).has_dim_value()) {
      // The number of elements in the shape input gives the output rank, so
      // emit that many unknown dimensions.
      int64_t dim_value = shape_input_shape.dim(0).dim_value();
      for (int64_t i = 0; i < dim_value; ++i) {
        shape_input.add_dim();
      }
      found = true;
      return shape_input;
    }
  }

  // The shape input was not found.
  found = false;
  return shape_input;
}
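
// Illustrative sketch (hypothetical snippet): an op like Reshape can resolve
// its shape input (assumed here to be input 1) through all three fallbacks
// above.
//
//   bool found = false;
//   TensorShapeProto shape = getShapeInput(ctx, 1, found);
//   if (found) {
//     // shape holds concrete dim_values if input 1 was a constant
//     // initializer, symbolic dims if a symbolic shape was recorded, or the
//     // right number of unknown dims if only the 1-D length was known.
//     *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = shape;
//   }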

} // namespace ONNX_NAMESPACE