1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include "shape_inference.h" |
6 | #include "onnx/defs/tensor_proto_util.h" |
7 | |
8 | namespace ONNX_NAMESPACE { |
9 | |
10 | // Note: for all methods below for propagating type or shape, callers are |
11 | // responsible to handle optional inputs/outputs and ensure that the specified |
12 | // index value is less than NumInputs/NumOutputs. |
13 | // Supports mixed tensor and sparse tensor |
14 | void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) { |
15 | auto input_type = ctx.getInputType(inputIndex); |
16 | if (nullptr == input_type) { |
17 | fail_type_inference("Input type was null" ); |
18 | } |
19 | |
20 | const auto input_value_case = input_type->value_case(); |
21 | if (input_value_case != TypeProto::kTensorType && input_value_case != TypeProto::kSparseTensorType) { |
22 | fail_type_inference( |
23 | "Input " , inputIndex, " expected to have tensor or sparse tensor type. Got: " , input_value_case); |
24 | } |
25 | |
26 | const auto input_elem_type = getTensorElementType(*input_type); |
27 | if (input_elem_type == TensorProto::UNDEFINED) { |
28 | fail_type_inference("Element type of input " , inputIndex, " unknown" ); |
29 | } |
30 | auto output_type = ctx.getOutputType(outputIndex); |
31 | const auto output_value_case = output_type->value_case(); |
32 | if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) { |
33 | setTensorElementType(input_elem_type, output_value_case, *output_type); |
34 | } else if (output_value_case == TypeProto::VALUE_NOT_SET) { |
35 | // Assume output will have the same type |
36 | setTensorElementType(input_elem_type, input_value_case, *output_type); |
37 | } else { |
38 | // This is not expected to happen |
39 | fail_type_inference( |
40 | "Output " , outputIndex, " expected to have tensor or sparse tensor type. Got: " , output_value_case); |
41 | } |
42 | } |
43 | |
44 | void propagateElemTypeFromSequenceInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) { |
45 | auto input_type = ctx.getInputType(inputIndex); |
46 | if (nullptr == input_type || input_type->value_case() != TypeProto::kSequenceType) { |
47 | fail_type_inference("Input " , inputIndex, " expected to have sequence type" ); |
48 | } |
49 | auto input_seq_type = input_type->sequence_type(); |
50 | if (!input_seq_type.has_elem_type()) { |
51 | fail_type_inference("Element type of sequence input " , inputIndex, " unknown" ); |
52 | } |
53 | |
54 | auto output_type = ctx.getOutputType(outputIndex); |
55 | output_type->mutable_sequence_type()->mutable_elem_type()->CopyFrom(input_seq_type.elem_type()); |
56 | } |
57 | |
58 | void propagateElemTypeFromOptionalInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) { |
59 | auto input_type = ctx.getInputType(inputIndex); |
60 | if (nullptr == input_type || input_type->value_case() != TypeProto::kOptionalType) { |
61 | fail_type_inference("Input " , inputIndex, " expected to have optional type" ); |
62 | } |
63 | auto input_opt_type = input_type->optional_type(); |
64 | if (!input_opt_type.has_elem_type()) { |
65 | fail_type_inference("Element type of optional input " , inputIndex, " unknown" ); |
66 | } |
67 | |
68 | auto output_type = ctx.getOutputType(outputIndex); |
69 | output_type->mutable_optional_type()->mutable_elem_type()->CopyFrom(input_opt_type.elem_type()); |
70 | } |
71 | |
72 | void propagateElemTypeFromMapInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) { |
73 | auto input_type = ctx.getInputType(inputIndex); |
74 | if (nullptr == input_type || input_type->value_case() != TypeProto::kMapType) { |
75 | fail_type_inference("Input " , inputIndex, " expected to have map type" ); |
76 | } |
77 | auto input_map_type = input_type->map_type(); |
78 | if (!input_map_type.has_key_type()) { |
79 | fail_type_inference("Key type of map input " , inputIndex, " unknown" ); |
80 | } |
81 | if (!input_map_type.has_value_type()) { |
82 | fail_type_inference("Value type of map input " , inputIndex, " unknown" ); |
83 | } |
84 | |
85 | auto output_type = ctx.getOutputType(outputIndex); |
86 | output_type->mutable_map_type()->set_key_type(input_map_type.key_type()); |
87 | output_type->mutable_map_type()->mutable_value_type()->CopyFrom(input_map_type.value_type()); |
88 | } |
89 | |
90 | void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) { |
91 | auto input_type = ctx.getInputType(inputIndex); |
92 | if (nullptr == input_type) { |
93 | fail_type_inference("Input " , inputIndex, " expected to have type but instead is null" ); |
94 | } |
95 | const auto input_value_case = input_type->value_case(); |
96 | if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) { |
97 | propagateElemTypeFromTensorInputToOutput(ctx, inputIndex, outputIndex); |
98 | } else if (input_value_case == TypeProto::kSequenceType) { |
99 | propagateElemTypeFromSequenceInputToOutput(ctx, inputIndex, outputIndex); |
100 | } else if (input_value_case == TypeProto::kOptionalType) { |
101 | propagateElemTypeFromOptionalInputToOutput(ctx, inputIndex, outputIndex); |
102 | } else if (input_value_case == TypeProto::kMapType) { |
103 | propagateElemTypeFromMapInputToOutput(ctx, inputIndex, outputIndex); |
104 | } |
105 | } |
106 | |
107 | /* |
108 | Merge shape information from a source shape into a target shape. |
109 | * merges each TensorShapeProto_Dimension separately. |
110 | * prefer values over params. |
111 | * If both have values, values must match. |
112 | * prefer target param over source param if mismatched. |
113 | * Fail if there are mismatches in number of dimensions or dimension values. |
114 | */ |
115 | void mergeInShapeInfo(const TensorShapeProto& source, TensorShapeProto& target) { |
116 | auto num_source_dims = source.dim_size(); |
117 | auto num_target_dims = target.dim_size(); |
118 | if (num_source_dims != num_target_dims) { |
119 | fail_shape_inference( |
120 | "Mismatch between number of source and target dimensions. Source=" , |
121 | num_source_dims, |
122 | " Target=" , |
123 | num_target_dims); |
124 | } |
125 | |
126 | auto& source_dims = source.dim(); |
127 | auto* target_dims = target.mutable_dim(); |
128 | |
129 | for (int i = 0, end = source_dims.size(); i < end; ++i) { |
130 | auto& source_dim = source_dims.Get(i); |
131 | auto& target_dim = *target_dims->Mutable(i); |
132 | mergeInDimensionInfo(source_dim, target_dim, i); |
133 | } |
134 | } |
135 | |
136 | void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) { |
137 | if (target_type.has_shape()) { |
138 | // merge with existing info. |
139 | mergeInShapeInfo(source_shape, *target_type.mutable_shape()); |
140 | } else { |
141 | // copy to target |
142 | (*target_type.mutable_shape()) = source_shape; |
143 | } |
144 | } |
145 | |
146 | void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) { |
147 | if (target_type.has_shape()) { |
148 | // merge with existing info. |
149 | mergeInShapeInfo(source_shape, *target_type.mutable_shape()); |
150 | } else { |
151 | // copy to target |
152 | (*target_type.mutable_shape()) = source_shape; |
153 | } |
154 | } |
155 | |
156 | /* |
157 | Merge the shape information from two TypeProto_Tensor instances. |
158 | Values are merged into target from source. |
159 | If target has no shape information, copy from source. |
160 | If source has no shape information, ignore source. |
161 | If both have shape information: |
162 | - merge each TensorShapeProto_Dimension separately. |
163 | - Prefer values over params. If both have values, values must match. |
164 | - Prefer target param over source param if mismatched. |
165 | Fail if there are mismatches in number of dimensions or dimension values. |
166 | */ |
167 | void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target) { |
168 | if (source.has_shape()) |
169 | mergeInShapeInfo(source.shape(), target); |
170 | } |
171 | |
172 | void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target) { |
173 | if (source.has_shape()) |
174 | mergeInShapeInfo(source.shape(), target); |
175 | } |
176 | |
177 | /// <summary> |
178 | /// Utility function for UnionShapeInfoForTensor. |
179 | /// Both shapes must be of the same rank |
180 | /// </summary> |
181 | /// <param name="source_shape"></param> |
182 | /// <param name="target_shape">destination shape</param> |
183 | void UnionShapeInfo(const TensorShapeProto& source_shape, TensorShapeProto& target_shape) { |
184 | auto source_rank = source_shape.dim_size(); |
185 | for (int i = 0; i < source_rank; ++i) { |
186 | const auto source_dim = source_shape.dim(i); |
187 | const auto target_dim = target_shape.dim(i); |
188 | bool is_dims_conflict = [&]() { |
189 | if (source_dim.has_dim_value()) { |
190 | if (target_dim.has_dim_value() && target_dim.dim_value() == source_dim.dim_value()) { |
191 | return false; |
192 | } |
193 | return true; |
194 | } |
195 | |
196 | if (source_dim.has_dim_param()) { |
197 | if (target_dim.has_dim_param() && target_dim.dim_param() == source_dim.dim_param()) { |
198 | return false; |
199 | } |
200 | return true; |
201 | } |
202 | |
203 | return (target_dim.has_dim_value() || target_dim.has_dim_param()); |
204 | }(); |
205 | if (is_dims_conflict && (target_dim.has_dim_value() || target_dim.has_dim_param())) { |
206 | auto dim = target_shape.mutable_dim(i); |
207 | dim->clear_dim_value(); |
208 | dim->clear_dim_param(); |
209 | } |
210 | } |
211 | } |
212 | |
213 | template <typename TENSOR_TYPE> |
214 | void UnionShapeInfoForTensor(const TensorShapeProto& source_shape, TENSOR_TYPE& target_type) { |
215 | if (target_type.has_shape()) { |
216 | TensorShapeProto* target_shape = target_type.mutable_shape(); |
217 | |
218 | auto source_rank = source_shape.dim_size(); |
219 | auto target_rank = target_shape->dim_size(); |
220 | if (source_rank != target_rank) { |
221 | target_type.clear_shape(); |
222 | return; |
223 | } |
224 | |
225 | UnionShapeInfo(source_shape, *target_shape); |
226 | } |
227 | } |
228 | |
229 | void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) { |
230 | UnionShapeInfoForTensor(source_shape, target_type); |
231 | } |
232 | |
233 | void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) { |
234 | UnionShapeInfoForTensor(source_shape, target_type); |
235 | } |
236 | |
// Computes (in place) the union of the type information in source_type and
// target_type, storing the result in target_type. The two types must be in
// the same category (tensor, sparse tensor, sequence, optional, or map) and
// have matching element/key types; only shape information is weakened to the
// union. Recurses through sequence/optional element types and map value
// types. Fails type inference on any category or element-type mismatch.
void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {
  if (source_type.value_case() != target_type.value_case()) {
    fail_type_inference("Mismatched type:" , " source=" , source_type.value_case(), " target=" , target_type.value_case());
  }

  const auto target_case = target_type.value_case();
  if (target_case == TypeProto::ValueCase::kTensorType) {
    auto source_elem_type = source_type.tensor_type().elem_type();
    auto target_elem_type = target_type.tensor_type().elem_type();

    if (source_elem_type != target_elem_type) {
      fail_type_inference(
          "Mismatched tensor element type:" , " source=" , source_elem_type, " target=" , target_elem_type);
    }

    // Element types agree; union the shapes (clears dims/shape on conflict).
    UnionShapeInfoForTensor(source_type.tensor_type().shape(), *target_type.mutable_tensor_type());
  } else if (target_case == TypeProto::ValueCase::kSparseTensorType) {
    auto source_elem_type = source_type.sparse_tensor_type().elem_type();
    auto target_elem_type = target_type.sparse_tensor_type().elem_type();
    if (source_elem_type != target_elem_type) {
      fail_type_inference(
          "Mismatched sparse tensor element type:" , " source=" , source_elem_type, " target=" , target_elem_type);
    }

    // Element types agree; union the shapes (clears dims/shape on conflict).
    UnionShapeInfoForTensor(source_type.sparse_tensor_type().shape(), *target_type.mutable_sparse_tensor_type());
  } else if (target_case == TypeProto::ValueCase::kSequenceType) {
    if (!source_type.sequence_type().has_elem_type()) {
      fail_type_inference("source sequence type missing element type." );
    }
    if (!target_type.sequence_type().has_elem_type()) {
      fail_type_inference("target sequence type missing element type." );
    }
    // Recurse into the sequence element type.
    UnionTypeInfo(source_type.sequence_type().elem_type(), *target_type.mutable_sequence_type()->mutable_elem_type());
  } else if (target_case == TypeProto::ValueCase::kOptionalType) {
    if (!source_type.optional_type().has_elem_type()) {
      fail_type_inference("source optional type missing element type." );
    }
    if (!target_type.optional_type().has_elem_type()) {
      fail_type_inference("target optional type missing element type." );
    }
    // Recurse into the optional element type.
    UnionTypeInfo(source_type.optional_type().elem_type(), *target_type.mutable_optional_type()->mutable_elem_type());
  } else if (target_case == TypeProto::ValueCase::kMapType) {
    if (!source_type.map_type().has_key_type()) {
      fail_type_inference("source map type missing key type." );
    }
    if (!target_type.map_type().has_key_type()) {
      fail_type_inference("target map type missing key type." );
    }
    // Map key types must match exactly; they carry no shape to union.
    auto source_key_type = source_type.map_type().key_type();
    auto target_key_type = target_type.map_type().key_type();
    if (source_key_type != target_key_type) {
      fail_type_inference(
          "Mismatched map tensor key type:" ,
          " source=" ,
          Utils::DataTypeUtils::ToDataTypeString(source_key_type),
          " target=" ,
          Utils::DataTypeUtils::ToDataTypeString(target_key_type));
    }

    if (!source_type.map_type().has_value_type()) {
      fail_type_inference("source map type missing value type." );
    }
    if (!target_type.map_type().has_value_type()) {
      fail_type_inference("target map type missing value type." );
    }
    // Recurse into the map value type.
    UnionTypeInfo(source_type.map_type().value_type(), *target_type.mutable_map_type()->mutable_value_type());
  }
}
305 | |
306 | // Supports both Tensor and SparseTensor |
307 | // This does not fail if input_type is Tensor and output type is SparseTensor |
308 | // or the other way around. This is to support mixed cases when an op receives |
309 | // sparse input and outputs dense or vice-versa. |
310 | // If the output value_case is not set, then |
311 | // the input value_case is propagated. |
// Propagates the element type from a tensor/sparse-tensor input type to the
// output type, validating against any element type already present on the
// output. Mixed dense/sparse input-output pairs are permitted (see the
// comment above): only the element type is compared, not the value_case.
// When the output's value_case is unset, the input's value_case is adopted.
void propagateTensorElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
  if (nullptr == input_type) {
    fail_type_inference("Input type was null" );
  }

  int32_t input_elem_type = TensorProto::UNDEFINED;
  const auto input_value_case = input_type->value_case();
  if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) {
    input_elem_type = getTensorElementType(*input_type);
    if (input_elem_type == TensorProto::UNDEFINED) {
      fail_type_inference("Element type of tensor or sparse tensor input was unknown" );
    }
  } else {
    fail_type_inference("Input was expected to have tensor or sparse tensor type. Got " , input_value_case);
  }

  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::VALUE_NOT_SET) {
    // Output type category not yet decided: propagate the input's category
    // along with its element type.
    setTensorElementType(input_elem_type, input_value_case, *output_type);
  } else if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
    const auto output_elem_type = getTensorElementType(*output_type);
    if (output_elem_type != TensorProto::UNDEFINED) {
      // Output already carries an element type: it must match the input's.
      if (input_elem_type != output_elem_type) {
        fail_type_inference(
            "Input element type of " , input_elem_type, " does not match existing output type of " , output_elem_type);
      }
    } else {
      // Output category is known but element type is not: fill it in.
      setTensorElementType(input_elem_type, output_value_case, *output_type);
    }
  } else {
    // This is not expected to happen
    fail_type_inference("Output was expected to have tensor type. Got " , output_value_case);
  }
}
346 | |
347 | void propagateSequenceElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) { |
348 | if (nullptr == input_type) { |
349 | fail_type_inference("Input type was null" ); |
350 | } |
351 | |
352 | if (input_type->value_case() != TypeProto::kSequenceType) { |
353 | fail_type_inference("Input was expected to have sequence type. Got " , input_type->value_case()); |
354 | } |
355 | |
356 | auto input_seq_type = input_type->sequence_type(); |
357 | |
358 | if (input_seq_type.has_elem_type()) { |
359 | propagateElemTypeWithValidation( |
360 | &input_seq_type.elem_type(), output_type->mutable_sequence_type()->mutable_elem_type()); |
361 | } else { |
362 | fail_type_inference("Element type of sequence input was unknown" ); |
363 | } |
364 | } |
365 | |
366 | void propagateOptionalElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) { |
367 | if (nullptr == input_type) { |
368 | fail_type_inference("Input type was null" ); |
369 | } |
370 | |
371 | if (input_type->value_case() != TypeProto::kOptionalType) { |
372 | fail_type_inference("Input was expected to have optional type. Got " , input_type->value_case()); |
373 | } |
374 | |
375 | auto input_opt_type = input_type->optional_type(); |
376 | |
377 | if (input_opt_type.has_elem_type()) { |
378 | propagateElemTypeWithValidation( |
379 | &input_opt_type.elem_type(), output_type->mutable_optional_type()->mutable_elem_type()); |
380 | } else { |
381 | fail_type_inference("Element type of optional input was unknown" ); |
382 | } |
383 | } |
384 | |
385 | void propagateMapElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) { |
386 | if (nullptr == input_type) { |
387 | fail_type_inference("Input type was null" ); |
388 | } |
389 | |
390 | if (input_type->value_case() != TypeProto::kMapType) { |
391 | fail_type_inference("Input was expected to have map type. Got " , input_type->value_case()); |
392 | } |
393 | |
394 | auto input_map_type = input_type->map_type(); |
395 | |
396 | if (!input_map_type.has_key_type()) { |
397 | fail_type_inference("Key type of map input was unknown" ); |
398 | } |
399 | if (!input_map_type.has_value_type()) { |
400 | fail_type_inference("Value type of map input was unknown" ); |
401 | } |
402 | output_type->mutable_map_type()->set_key_type(input_map_type.key_type()); |
403 | propagateElemTypeWithValidation(&input_map_type.value_type(), output_type->mutable_map_type()->mutable_value_type()); |
404 | } |
405 | |
406 | // propagate the element type from an input type to an output type. |
407 | // if an existing output element type exists, validate it matches. |
408 | void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) { |
409 | if (nullptr == input_type) { |
410 | fail_type_inference("Input type was null" ); |
411 | } |
412 | |
413 | const auto input_value_case = input_type->value_case(); |
414 | if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) { |
415 | propagateTensorElemTypeWithValidation(input_type, output_type); |
416 | } else if (input_value_case == TypeProto::kSequenceType) { |
417 | propagateSequenceElemTypeWithValidation(input_type, output_type); |
418 | } else if (input_value_case == TypeProto::kOptionalType) { |
419 | propagateOptionalElemTypeWithValidation(input_type, output_type); |
420 | } else if (input_value_case == TypeProto::kMapType) { |
421 | propagateMapElemTypeWithValidation(input_type, output_type); |
422 | } else { |
423 | fail_type_inference( |
424 | "Input was expected to have either tensor, sequence, optional or map type. Got " , input_value_case); |
425 | } |
426 | } |
427 | |
// Attempts to recover the shape described by a 1-D "shape" input tensor
// (e.g. the second input of Reshape/Expand), trying three sources in order:
//   1. a constant initializer (yields exact dim values),
//   2. symbolic/partial data inference results,
//   3. the shape input's own static shape (yields only the output rank).
// Sets `found` to indicate whether any of the three succeeded; on failure
// the returned shape is empty. Fails shape inference if the shape input is
// statically known not to be a 1-D tensor.
TensorShapeProto getShapeInput(InferenceContext& ctx, size_t input_index, bool& found) {
  TensorShapeProto shape_input;

  // First, check initializer.
  const TensorProto* shape_initializer = ctx.getInputData(input_index);
  if (shape_initializer) {
    // Each element of the 1-D initializer becomes a concrete dim value.
    const std::vector<int64_t>& shape_data = ParseData<int64_t>(shape_initializer);
    for (const int64_t& e : shape_data) {
      shape_input.add_dim()->set_dim_value(e);
    }
    found = true;
    return shape_input;
  }

  // Then, check symbolic input.
  const TensorShapeProto* symbolic_input = ctx.getSymbolicInput(input_index);
  if (symbolic_input) {
    shape_input.CopyFrom(*symbolic_input);
    found = true;
    return shape_input;
  }

  // Try rank inference.
  if (hasInputShape(ctx, input_index)) {
    const TensorShapeProto& shape_input_shape = getInputShape(ctx, input_index);
    if (shape_input_shape.dim_size() != 1) {
      fail_shape_inference("shape input must be 1D tensor" );
    }
    if (shape_input_shape.dim(0).has_dim_value()) {
      // Attempt rank inference using shape of shape input: the length of the
      // shape tensor is the output rank, so emit that many unknown dims.
      int64_t dim_value = shape_input_shape.dim(0).dim_value();
      for (int64_t i = 0; i < dim_value; ++i) {
        shape_input.add_dim();
      }
      found = true;
      return shape_input;
    }
  }

  // Shape input was not found.
  found = false;
  return shape_input;
}
471 | |
472 | } // namespace ONNX_NAMESPACE |