/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/shape_inference/implementation.h"
#include <fstream>
#include <list>
#include "onnx/checker.h"
#include "onnx/common/file_utils.h"
#include "onnx/defs/data_type_utils.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace shape_inference {
namespace {

std::string GetValueCaseString(const TypeProto& type) {
  switch (type.value_case()) {
    case TypeProto::ValueCase::kTensorType:
      return "tensor_type";
    case TypeProto::ValueCase::kSequenceType:
      return "sequence_type";
    case TypeProto::ValueCase::kMapType:
      return "map_type";
    case TypeProto::ValueCase::kOptionalType:
      return "optional_type";
#ifdef ONNX_ML
    case TypeProto::ValueCase::kOpaqueType:
      return "opaque_type";
#endif
    case TypeProto::ValueCase::kSparseTensorType:
      return "sparse_tensor_type";
    case TypeProto::ValueCase::VALUE_NOT_SET:
      return "NOT_SET";
    default:
      return ONNX_NAMESPACE::to_string(type.value_case());
  }
}

std::string GetElemTypeString(const TypeProto_Tensor& type) {
#ifndef ONNX_USE_LITE_PROTO
  const std::string type_str = TensorProto::DataType_Name(static_cast<TensorProto_DataType>(type.elem_type()));
  if (!type_str.empty()) {
    return type_str;
  }
#endif
  return ONNX_NAMESPACE::to_string(type.elem_type());
}

std::string GetElemTypeString(const TypeProto_SparseTensor& type) {
#ifndef ONNX_USE_LITE_PROTO
  const std::string type_str = TensorProto::DataType_Name(static_cast<TensorProto_DataType>(type.elem_type()));
  if (!type_str.empty()) {
    return type_str;
  }
#endif
  return ONNX_NAMESPACE::to_string(type.elem_type());
}

} // namespace

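// Checks that an inferred tensor (or sparse tensor) type is compatible with an
// existing one: element types must match when both are defined, and shapes must
// agree in rank and in every dimension where both sides carry a concrete value.
// Fails type/shape inference otherwise.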
template <class T>
void CheckTensorShapesAndTypes(const T& inferred_type, const T& existing_type) {
  if (inferred_type.elem_type() != TensorProto::UNDEFINED && existing_type.elem_type() != TensorProto::UNDEFINED &&
      existing_type.elem_type() != inferred_type.elem_type()) {
    std::stringstream ss;
    ss << "Inferred elem type differs from existing elem type: (" << GetElemTypeString(inferred_type) << ") vs ("
       << GetElemTypeString(existing_type) << ")";
    fail_type_inference(ss.str());
  }

  if (!inferred_type.has_shape() || !existing_type.has_shape()) {
    return;
  }

  if (inferred_type.shape().dim_size() != existing_type.shape().dim_size()) {
    std::stringstream ss;
    ss << "Inferred shape and existing shape differ in rank: (" << inferred_type.shape().dim_size() << ") vs ("
       << existing_type.shape().dim_size() << ")";
    fail_shape_inference(ss.str());
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    const auto& existing_dim = existing_type.shape().dim(i);
    if (inferred_dim.has_dim_value() && existing_dim.has_dim_value() &&
        inferred_dim.dim_value() != existing_dim.dim_value()) {
      std::stringstream ss;
      ss << "Inferred shape and existing shape differ in dimension " << i << ": (" << inferred_dim.dim_value()
         << ") vs (" << existing_dim.dim_value() << ")";
      fail_shape_inference(ss.str());
    }
  }
}

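// Recursively checks compatibility of an inferred TypeProto with an existing
// TypeProto: the value cases must match, and tensor, sparse-tensor, sequence,
// optional, and map element types are checked structurally. A VALUE_NOT_SET on
// either side is treated as "no constraint".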
void checkShapesAndTypes(const TypeProto& inferred_type, const TypeProto& existing_type) {
  const auto inferred_value_case = inferred_type.value_case();
  const auto existing_value_case = existing_type.value_case();
  if (inferred_value_case == TypeProto::ValueCase::VALUE_NOT_SET ||
      existing_value_case == TypeProto::ValueCase::VALUE_NOT_SET) {
    // Nothing to check; the inferred type will later be assigned to the undefined existing type.
    return;
  }
  if (inferred_value_case != existing_value_case) {
    fail_type_inference(
        "type case mismatch. existing=",
        GetValueCaseString(existing_type),
        " inferred=",
        GetValueCaseString(inferred_type));
  }

  if (inferred_value_case == TypeProto::kTensorType && existing_value_case == TypeProto::kTensorType) {
    CheckTensorShapesAndTypes(inferred_type.tensor_type(), existing_type.tensor_type());
  } else if (
      inferred_value_case == TypeProto::kSparseTensorType && existing_value_case == TypeProto::kSparseTensorType) {
    CheckTensorShapesAndTypes(inferred_type.sparse_tensor_type(), existing_type.sparse_tensor_type());
  } else if (inferred_value_case == TypeProto::kSequenceType && existing_value_case == TypeProto::kSequenceType) {
    checkShapesAndTypes(inferred_type.sequence_type().elem_type(), existing_type.sequence_type().elem_type());
  } else if (inferred_value_case == TypeProto::kOptionalType && existing_value_case == TypeProto::kOptionalType) {
    checkShapesAndTypes(inferred_type.optional_type().elem_type(), existing_type.optional_type().elem_type());
  } else if (inferred_value_case == TypeProto::kMapType && existing_value_case == TypeProto::kMapType) {
    if (inferred_type.map_type().key_type() != existing_type.map_type().key_type()) {
      fail_type_inference(
          "key type mismatch from MapProto. existing=",
          Utils::DataTypeUtils::ToDataTypeString(existing_type.map_type().key_type()),
          " inferred=",
          Utils::DataTypeUtils::ToDataTypeString(inferred_type.map_type().key_type()));
    }
    checkShapesAndTypes(inferred_type.map_type().value_type(), existing_type.map_type().value_type());
  } else {
    fail_type_inference("type case unsupported. existing=", existing_value_case, " inferred=", inferred_value_case);
  }
}

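// Merges an inferred tensor type into an existing tensor type. The element type
// is filled in only if it is still UNDEFINED. For the shape, the inferred shape
// is copied wholesale when the existing type has none; otherwise each existing
// dimension is overwritten only if it is completely unknown (neither dim_value
// nor dim_param) or if the inferred dimension carries a concrete dim_value.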
void mergeShapesAndTypes(const TypeProto_Tensor& inferred_type, TypeProto_Tensor* existing_type) {
  if (existing_type->elem_type() == TensorProto::UNDEFINED) {
    existing_type->set_elem_type(inferred_type.elem_type());
  }

  if (!inferred_type.has_shape()) {
    return;
  }

  if (!existing_type->has_shape()) {
    *existing_type->mutable_shape() = inferred_type.shape();
    return;
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    auto* existing_dim = existing_type->mutable_shape()->mutable_dim(i);
    if ((!existing_dim->has_dim_value() && !existing_dim->has_dim_param()) || inferred_dim.has_dim_value()) {
      *existing_dim = inferred_dim;
    }
  }
}

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferred_type, TypeProto_SparseTensor* existing_type) {
  if (existing_type->elem_type() == TensorProto::UNDEFINED) {
    existing_type->set_elem_type(inferred_type.elem_type());
  }

  if (!inferred_type.has_shape()) {
    return;
  }

  if (!existing_type->has_shape()) {
    *existing_type->mutable_shape() = inferred_type.shape();
    return;
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    auto* existing_dim = existing_type->mutable_shape()->mutable_dim(i);
    if ((!existing_dim->has_dim_value() && !existing_dim->has_dim_param()) || inferred_dim.has_dim_value()) {
      *existing_dim = inferred_dim;
    }
  }
}

void mergeShapesAndTypes(const TypeProto& inferred_type, TypeProto* existing_type) {
  // Check before merge
  checkShapesAndTypes(inferred_type, *existing_type);
  const auto inferred_val_case = inferred_type.value_case();
  if (inferred_val_case == TypeProto::kTensorType) {
    mergeShapesAndTypes(inferred_type.tensor_type(), existing_type->mutable_tensor_type());
  } else if (inferred_val_case == TypeProto::kSparseTensorType) {
    mergeShapesAndTypes(inferred_type.sparse_tensor_type(), existing_type->mutable_sparse_tensor_type());
  } else if (inferred_val_case == TypeProto::kSequenceType) {
    mergeShapesAndTypes(
        inferred_type.sequence_type().elem_type(), existing_type->mutable_sequence_type()->mutable_elem_type());
  } else if (inferred_val_case == TypeProto::kOptionalType) {
    mergeShapesAndTypes(
        inferred_type.optional_type().elem_type(), existing_type->mutable_optional_type()->mutable_elem_type());
  } else if (inferred_val_case == TypeProto::kMapType) {
    mergeShapesAndTypes(inferred_type.map_type().value_type(), existing_type->mutable_map_type()->mutable_value_type());
  }
}

// TypeProto_Tensor or TypeProto_SparseTensor
template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferred_type, SymbolTable& symbol_table) {
  if (!inferred_type->has_shape()) {
    return;
  }
  for (int i = 0; i < inferred_type->shape().dim_size(); ++i) {
    // Assign a fresh symbol if the dim has neither dim_value nor dim_param.
    auto* dim = inferred_type->mutable_shape()->mutable_dim(i);
    if (!dim->has_dim_value() && !dim->has_dim_param()) {
      dim->set_dim_param(symbol_table.createNew("unk__"));
    }
  }
}

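// Walks an inferred TypeProto and assigns fresh symbolic dim_params (via
// GenerateSymbolicShape) to every unknown tensor/sparse-tensor dimension,
// recursing through sequence and optional element types. Other type cases
// cause a shape-inference failure.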
void MaterializeSymbolicShape(TypeProto* inferred_type, SymbolTable& symbol_table) {
  const auto inferred_val_case = inferred_type->value_case();
  if (inferred_val_case == TypeProto::ValueCase::VALUE_NOT_SET) {
    return;
  }

  if (inferred_val_case == TypeProto::kTensorType) {
    GenerateSymbolicShape(inferred_type->mutable_tensor_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kSparseTensorType) {
    GenerateSymbolicShape(inferred_type->mutable_sparse_tensor_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kSequenceType) {
    MaterializeSymbolicShape(inferred_type->mutable_sequence_type()->mutable_elem_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kOptionalType) {
    MaterializeSymbolicShape(inferred_type->mutable_optional_type()->mutable_elem_type(), symbol_table);
  } else {
    fail_shape_inference("type case unsupported for symbolic shape inference. inferred=", inferred_val_case);
  }
}

std::string GetModelLocalFunctionsMapIdentifier(const std::string& domain, const std::string& func_name) {
  return domain + ":" + func_name;
}

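// Stateful helper that carries out type/shape inference over a GraphProto or a
// FunctionProto body. It tracks known value types, statically known input data
// (from initializers and Constant nodes), model-local functions, and collected
// node-level inference errors, and applies the check/merge helpers above as it
// walks the nodes in order.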
class ShapeInferenceImplBase {
 public:
  void updateType(const std::string& name, TypeProto* inferred_type) {
    if (inferred_type->value_case() == TypeProto::ValueCase::VALUE_NOT_SET) {
      return;
    }

    if (symbol_table) {
      MaterializeSymbolicShape(inferred_type, *symbol_table);
    }

    // Find any pre-existing type and shape info. If there is such,
    // then check for compatibility with the inferred
    // information. Otherwise, initialize it in an empty state.
    auto iter = value_types_by_name.find(name);
    TypeProto* existing_type = nullptr;
    if (iter != value_types_by_name.end()) {
      existing_type = iter->second;
    } else {
      // Create a new value_info if no type has been defined for this name yet.
      auto vi = g.add_value_info(); // TODO: clean this up
      vi->set_name(name);
      existing_type = vi->mutable_type();
      // If this name appeared (e.g. as a graph output) without a defined type,
      // update that entry as well: assign the inferred type to it directly.
      iter = undefined_value_types_by_name.find(name);
      if (iter != undefined_value_types_by_name.end()) {
        *iter->second = *inferred_type;
      }
    }

    // TODO: clean this up by merging with the previous if-else.
    // Now we can merge the pre-existing and inferred info.
    mergeShapesAndTypes(*inferred_type, existing_type);

    // Make the merged info available to further inference.
    value_types_by_name[name] = existing_type;
  }

  void updateType(ValueInfoProto& valueInfo) {
    if (valueInfo.has_type()) {
      value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
    } else {
      undefined_value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
    }
  }

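  // Records statically known data produced by Constant nodes (tensor, sparse
  // tensor, or value_ints attributes) so later nodes can use it during
  // inference, and flags experimental ops so their inference errors are not
  // reported as failures.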
  void preprocess(const NodeProto& n) {
    if (checker::check_is_experimental_op(n)) {
      has_experimental_op = true;
    } else if (n.op_type() == "Constant" && n.output().size() == 1) {
      for (const auto& attr : n.attribute()) {
        if (attr.name() == "value") {
          if (attr.type() == AttributeProto::TENSOR && attr.has_t()) {
            input_data_by_name[n.output(0)] = &attr.t();
          } else if (attr.type() == AttributeProto::SPARSE_TENSOR && attr.has_sparse_tensor()) {
            input_sparse_data_by_name[n.output(0)] = &attr.sparse_tensor();
          }
        } else if (attr.type() == AttributeProto::INTS && attr.name() == "value_ints") {
          std::vector<int64_t> ints{attr.ints().begin(), attr.ints().end()};
          input_data_by_name_holder[n.output(0)] = ToTensor(ints);
          input_data_by_name[n.output(0)] = &input_data_by_name_holder[n.output(0)];
        }
      }
    }
  }

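  // Runs type/shape inference for a single node: resolves the opset version for
  // the node's domain, invokes the schema's inference function (or infers
  // through the schema-defined or model-local function body), merges the
  // inferred output types into the known value types, and optionally runs the
  // schema's data-propagation function.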
  void process(NodeProto& n) {
    // Resolve the opset version for the node's domain.
    auto dit = opset_imports.find(n.domain());
    if (dit == opset_imports.end()) {
      // Both "" and "ai.onnx" refer to the default ONNX domain.
      if (n.domain() == "") {
        dit = opset_imports.find("ai.onnx");
      }
      if (dit == opset_imports.end()) {
        fail_type_inference(
            "Cannot infer type and shape for node name ",
            n.name(),
            ". No opset import for domain ",
            n.domain(),
            " optype ",
            n.op_type());
      }
    }
    auto domain_version = dit->second;
    const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
    InferenceContextImpl ctx(
        n,
        value_types_by_name,
        input_data_by_name,
        input_sparse_data_by_name,
        generated_shape_data_by_name,
        &graph_inference_context);

    ONNX_TRY {
      if (schema) {
        if (schema->has_type_and_shape_inference_function()) {
          schema->GetTypeAndShapeInferenceFunction()(ctx);
        } else if (schema->HasFunction()) {
          InferShapeForFunctionNode(
              *(schema->GetFunction()),
              schema_registry,
              ctx,
              options,
              model_local_functions_map,
              symbol_table,
              generated_shape_data_by_name);
        } else {
          // Continue with inference for the remaining nodes.
          return;
        }
      } else if (model_local_functions_map.size() > 0) {
        auto iter = model_local_functions_map.find(GetModelLocalFunctionsMapIdentifier(n.domain(), n.op_type()));
        if (iter != model_local_functions_map.end()) {
          InferShapeForFunctionNode(
              *(iter->second),
              schema_registry,
              ctx,
              options,
              model_local_functions_map,
              symbol_table,
              generated_shape_data_by_name);
        } else {
          has_unsupported_op = true;
          return;
        }
      } else {
        has_unsupported_op = true;
        return;
      }
    }
    ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
      ONNX_HANDLE_EXCEPTION([&]() {
        // Unsupported and experimental operators are outside ONNX's coverage,
        // so inference errors for them are not recorded as failures.
        if (!has_unsupported_op && !has_experimental_op) {
          inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
        }
      });
      // Continue with inference for the remaining nodes.
      return;
    }

    ONNX_TRY {
      // Check type equality between inputs and outputs when requested.
      if (options.check_type && schema) {
        schema->CheckInputOutputType(ctx);
      }

      for (int i = 0; i < n.output_size(); ++i) {
        // Skip type and shape propagation for missing optional outputs.
        if (!n.output(i).empty())
          updateType(n.output(i), ctx.getOutputType(i));
      }

      preprocess(n);

      // If data propagation is enabled, propagate shape data if it exists.
      if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
        if (generated_shape_data_by_name == nullptr) {
          fail_shape_inference(
              "Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
        }
        DataPropagationContextImpl data_propagation_ctx(
            n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
        schema->GetDataPropagationFunction()(data_propagation_ctx);
      }
    }
    ONNX_CATCH(const std::runtime_error& err) {
      ONNX_HANDLE_EXCEPTION([&]() { fail_shape_inference(GetErrorWithNodeInfo(n, err)); });
    }
  }

  // TypeProto_Tensor or TypeProto_SparseTensor
  template <typename T>
  void processInitializer(
      const std::string& name,
      const T& tensorValue,
      TypeProto& initializer_type,
      std::unordered_map<std::string, const T*>& map) {
    map[name] = &tensorValue;
    auto iter = value_types_by_name.find(name);
    // If the name also appears as a graph input, check that the input and the
    // initializer are consistent, and keep the shape info from the input
    // (the input takes priority over the initializer).
    if (iter != value_types_by_name.end()) {
      checkShapesAndTypes(initializer_type, *iter->second);
      // CheckTensorShapesAndTypes(*initializer_tensor_type, *iter->second->mutable_tensor_type());
    }
    // With IR version >= 4, a tensor may exist only as an initializer without a
    // matching graph input, so shape inference should make use of the
    // initializer's shape and store it in value_types_by_name as well.
    else if (ir_version >= 4) {
      initializer_type_list.push_back(std::move(initializer_type));
      value_types_by_name[name] = &initializer_type_list.back();
    }
  }

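  // Runs inference over a whole graph: registers existing value_info, input and
  // output types, derives types from (sparse) initializers, processes every node
  // in order, and finally reports the accumulated node-level errors according to
  // options.error_mode.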
  void process(GraphProto& graph) {
    if (symbol_table) {
      TraverseGraphsToAddExistingSymbols(graph, *symbol_table);
    }
    for (auto& vi : *graph.mutable_value_info()) {
      updateType(vi);
    }
    for (auto& vi : *graph.mutable_input()) {
      updateType(vi);
    }
    for (auto& vi : *graph.mutable_output()) {
      updateType(vi);
    }
    for (const auto& tp : graph.initializer()) {
      TypeProto initializer_type;
      TypeProto_Tensor* initializer_tensor_type = initializer_type.mutable_tensor_type();
      initializer_tensor_type->set_elem_type(tp.data_type());
      // Set the shape according to the initializer's shape info.
      auto* shape = initializer_tensor_type->mutable_shape();
      for (int i = 0; i < tp.dims_size(); ++i) {
        shape->add_dim()->set_dim_value(tp.dims(i));
      }
      processInitializer(tp.name(), tp, initializer_type, input_data_by_name);
    }
    for (const auto& tp : graph.sparse_initializer()) {
      TypeProto initializer_type;
      auto* initializer_sparse_tensor_type = initializer_type.mutable_sparse_tensor_type();
      initializer_sparse_tensor_type->set_elem_type(tp.values().data_type());
      // Set the shape according to the initializer's shape info.
      auto* shape = initializer_sparse_tensor_type->mutable_shape();
      for (int i = 0; i < tp.dims_size(); ++i) {
        shape->add_dim()->set_dim_value(tp.dims(i));
      }
      processInitializer(tp.values().name(), tp, initializer_type, input_sparse_data_by_name);
    }
    for (auto& n : *graph.mutable_node()) {
      process(n);
    }
    // Throw a shape inference error if any occurred. Error mode currently supports only 0 and 1.
    // When set to 0, node-level shape inference errors are not thrown, to stay backward compatible
    // with release 1.7 and earlier. When set to 1, all such errors are thrown.
    // TODO: Add a more granular way of handling exceptions.
    if (options.error_mode > 0 && !inference_errors.empty()) {
      std::string full_errors = "Shape inference error(s): ";
      for (const std::string& error : inference_errors) {
        full_errors += error + "\n";
      }
      fail_shape_inference(full_errors);
    }
  }

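  // Processes a function-body node: attributes that reference a function
  // attribute (ref_attr_name) are replaced with the concrete attribute values
  // supplied by the calling node before running regular node inference.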
  void process(const NodeProto& n, std::unordered_map<std::string, const AttributeProto*> attr_map) {
    NodeProto copy_n(n);
    // Add attribute information into the temporary node
    copy_n.clear_attribute();
    for (const auto& attr : n.attribute()) {
      if (attr.has_ref_attr_name()) {
        if (attr_map.count(attr.ref_attr_name())) {
          auto copy_attr = *attr_map[attr.ref_attr_name()];
          copy_attr.set_name(attr.name());
          copy_n.add_attribute()->CopyFrom(copy_attr);
        }
      } else {
        copy_n.add_attribute()->CopyFrom(attr);
      }
    }
    process(copy_n);
  }

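  // Infers types for a function body invoked from a node: seeds the local
  // type/data maps from the calling context's inputs, binds the caller's
  // attribute values, processes each body node, and copies the inferred output
  // types back into the calling context.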
  void process(const FunctionProto& func_proto, InferenceContext& ctx) {
    // Build a temporary map from function inputs to their types.
    const auto num_func_inputs = func_proto.input_size();
    std::vector<TypeProto> types_cache(num_func_inputs);
    for (int i = 0; i < num_func_inputs; ++i) {
      if (ctx.getInputType(i) == nullptr) {
        fail_type_inference("Input ", i, " type is missing.");
      }
      types_cache[i] = *ctx.getInputType(i); // TODO: investigate whether we can remove the cache
      value_types_by_name[func_proto.input().Get(i)] = &types_cache[i];
    }

    // Create a temporary initializer value map.
    for (int i = 0; i < static_cast<int>(ctx.getNumInputs()) && i < num_func_inputs; ++i) {
      const TypeProto* type = ctx.getInputType(i);
      if (type->value_case() == TypeProto::kTensorType && ctx.getInputData(i) != nullptr) {
        input_data_by_name[func_proto.input().Get(i)] = ctx.getInputData(i);
      } else if (type->value_case() == TypeProto::kSparseTensorType && ctx.getInputSparseData(i) != nullptr) {
        input_sparse_data_by_name[func_proto.input().Get(i)] = ctx.getInputSparseData(i);
      }
    }

    std::unordered_map<std::string, const AttributeProto*> attr_map;
    for (auto& attr : func_proto.attribute()) {
      if (ctx.getAttribute(attr) != nullptr) {
        attr_map[attr] = ctx.getAttribute(attr);
      }
    }

    for (auto& n : func_proto.node()) {
      process(n, attr_map);
    }

    for (int i = 0; i < func_proto.output_size(); ++i) {
      const std::string& output_name = func_proto.output().Get(i);
      // Skip outputs for which no type was inferred.
      auto iter = value_types_by_name.find(output_name);
      if (iter != value_types_by_name.cend()) {
        // Copy the type info into ctx to pass it back to the main graph.
        auto type_proto = ctx.getOutputType(i);
        type_proto->CopyFrom(*(iter->second));
      }
    }
  }

 public:
  ShapeInferenceImplBase(
      GraphProto* g_in,
      const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
      const std::unordered_map<std::string, int>& opset_imports_in,
      const ShapeInferenceOptions& options_in,
      SymbolTable* symbol_table_in,
      const ModelLocalFunctionsMap& model_local_functions_map_in,
      const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
      std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name_in = nullptr,
      const int ir_version_in = IR_VERSION // defaults to the latest IR version
      )
      : g(*g_in),
        value_types_by_name(outer_scope_value_types_by_name_in),
        opset_imports(opset_imports_in),
        options(options_in),
        symbol_table(symbol_table_in),
        model_local_functions_map(model_local_functions_map_in),
        schema_registry(schema_registry_in),
        generated_shape_data_by_name(generated_shape_data_by_name_in),
        ir_version(ir_version_in),
        graph_inference_context{
            value_types_by_name,
            opset_imports,
            symbol_table,
            model_local_functions_map,
            schema_registry,
            generated_shape_data_by_name,
            ir_version} {
    if (options.enable_data_propagation && generated_shape_data_by_name == nullptr) {
      fail_shape_inference(
          "Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
    }
  }

 private:
  GraphProto& g;
  std::unordered_map<std::string, TypeProto*> value_types_by_name;
  const std::unordered_map<std::string, int>& opset_imports;

  const ShapeInferenceOptions& options;
  SymbolTable* symbol_table;
  const ModelLocalFunctionsMap& model_local_functions_map;
  const ISchemaRegistry* schema_registry;
  std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name;
  int ir_version;
  GraphInferenceContext graph_inference_context;

  std::unordered_map<std::string, TypeProto*> undefined_value_types_by_name;
  std::unordered_map<std::string, const TensorProto*> input_data_by_name;
  std::unordered_map<std::string, TensorProto> input_data_by_name_holder;
  std::unordered_map<std::string, const SparseTensorProto*> input_sparse_data_by_name;

  bool has_experimental_op = false;
  bool has_unsupported_op = false;

  std::vector<std::string> inference_errors;

  std::list<TypeProto> initializer_type_list;
};

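// Shared implementation behind the public InferShapes overloads: wires up a
// ShapeInferenceImplBase with the given options, symbol table, model-local
// functions, and schema registry, then processes the graph in place.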
static void InferShapesImpl(
    GraphProto* g,
    const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name,
    const std::unordered_map<std::string, int>& opset_imports,
    const ShapeInferenceOptions& options,
    SymbolTable* symbol_table,
    const ModelLocalFunctionsMap& model_local_functions_map,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr,
    const int ir_version = IR_VERSION // defaults to the latest IR version
) {
  std::unordered_map<std::string, TensorShapeProto> empty;
  if (generated_shape_data_by_name == nullptr) {
    generated_shape_data_by_name = &empty;
  }
  ShapeInferenceImplBase base(
      g,
      outer_scope_value_types_by_name,
      opset_imports,
      options,
      symbol_table,
      model_local_functions_map,
      schema_registry,
      generated_shape_data_by_name,
      ir_version);
  base.process(*g);
}

// Either ModelProto or FunctionProto
template <class T>
std::unordered_map<std::string, int> GetOpsetImportsFromProto(const T& proto) {
  std::unordered_map<std::string, int> opset_imports;
  for (const auto& opset_import : proto.opset_import()) {
    opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
  }
  return opset_imports;
}

void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions) {
  SymbolTableImpl symbol_table;
  InferShapesImpl(
      g,
      std::unordered_map<std::string, TypeProto*>(0),
      opset_imports,
      options,
      &symbol_table,
      model_local_functions,
      schema_registry);
}

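// Illustrative usage (a sketch only; how `model` gets populated is an
// assumption and not part of this file):
//
//   ModelProto model;
//   // ... fill `model`, e.g. by parsing a serialized .onnx file ...
//   ShapeInferenceOptions options;
//   InferShapes(model, OpSchemaRegistry::Instance(), options, nullptr);
//   // The graph's value_info and output types are updated in place with the
//   // inferred information.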
void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  auto opset_imports = GetOpsetImportsFromProto(m);
  SymbolTableImpl symbol_table;
  ModelLocalFunctionsMap model_local_functions_by_id;
  for (const auto& function_proto : m.functions()) {
    model_local_functions_by_id.insert(
        {GetModelLocalFunctionsMapIdentifier(function_proto.domain(), function_proto.name()), &function_proto});
  }
  InferShapesImpl(
      m.mutable_graph(),
      std::unordered_map<std::string, TypeProto*>(0),
      opset_imports,
      options,
      &symbol_table,
      model_local_functions_by_id,
      schema_registry,
      generated_shape_data_by_name,
      m.ir_version());
}

void InferShapes(
    const std::string& model_path,
    const std::string& save_path,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  ModelProto model;
  LoadProtoFromPath(model_path, model);
  InferShapes(model, schema_registry, options, generated_shape_data_by_name);
  // Save the inferred model to the given save path.
  // Use SerializeToString instead of SerializeToOstream due to LITE_PROTO.
  std::fstream output(save_path, std::ios::out | std::ios::trunc | std::ios::binary);
  std::string model_string;
  ONNX_TRY {
    model.SerializeToString(&model_string);
    output << model_string;
  }
  ONNX_CATCH(...) {
    fail_check("Unable to save inferred model to the target path:", save_path);
  }
}

// Infer types and shapes for a function body invoked from a node, writing the
// inferred output types back into the given InferenceContext.
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions_map,
    SymbolTable* symbol_table,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  GraphProto g;
  ShapeInferenceImplBase base(
      &g,
      {}, // outer_scope_value_types_by_name
      func_opset_imports,
      options,
      symbol_table,
      model_local_functions_map,
      schema_registry,
      generated_shape_data_by_name);
  base.process(func_proto, ctx);
}

void InferShapeForFunctionNode(
    const FunctionProto& function_proto,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions_map,
    SymbolTable* symbol_table,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  auto opset_imports = GetOpsetImportsFromProto(function_proto);
  InferShapeForFunctionNode(
      function_proto,
      opset_imports,
      schema_registry,
      ctx,
      options,
      model_local_functions_map,
      symbol_table,
      generated_shape_data_by_name);
}

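// Runs shape inference for a subgraph (e.g. the body of a control-flow node)
// given the types of the values wired in from the enclosing node, and returns
// the resulting types of the subgraph outputs.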
std::vector<const TypeProto*> GraphInferencerImpl::doInferencing(
    const std::vector<const TypeProto*>& input_types,
    const std::vector<const TensorProto*>& input_data) {
  SymbolTable* symbol_table = context_->symbol_table;
  int num_inputs = static_cast<int>(input_types.size());
  std::unordered_set<std::string> initializer_name_set;
  for (const auto& tp : g_->initializer()) {
    initializer_name_set.insert(tp.name());
  }

  if (context_->ir_version >= 4) {
    if (g_->input_size() != num_inputs) {
      fail_shape_inference("Graph has ", g_->input_size(), " inputs but ", num_inputs, " were provided");
    }
    for (int i = 0; i < g_->input_size(); ++i) {
      if (initializer_name_set.count(g_->input(i).name()) > 0) {
        fail_shape_inference(
            "Cannot use the same name as both a subgraph initializer and subgraph input: ", g_->input(i).name());
      }
    }
  } else {
    // IR version < 4 requires every initializer to also be an optional graph input,
    // so the number of graph inputs can be larger than the number of node inputs.
    if (num_inputs > g_->input_size()) {
      fail_shape_inference(
          "Graph has ",
          g_->input_size(),
          " inputs but ",
          num_inputs,
          " were provided. ",
          "The number of graph inputs cannot be smaller than the number of node inputs.");
    } else if (num_inputs < g_->input_size()) {
      for (int i = 0; i < g_->input_size(); ++i) {
        if (i < num_inputs && initializer_name_set.count(g_->input(i).name()) > 0) {
          fail_shape_inference("Graph initializer names must appear after the actual inputs: ", g_->input(i).name());
        } else if (i >= num_inputs && initializer_name_set.count(g_->input(i).name()) == 0) {
          // Any additional graph input beyond the provided ones must be backed by an initializer.
          fail_shape_inference("Cannot find missing input: ", g_->input(i).name(), " in initializers.");
        }
      }
    }
  }

  for (int i = 0, end = num_inputs; i < end; ++i) {
    const TypeProto* inferred_input = input_types[i];

    if (!inferred_input)
      continue;

    TypeProto* graph_input = g_->mutable_input(i)->mutable_type();
    // Even if the graph input has no defined type, the inferred type is assigned to it here.
    mergeShapesAndTypes(*inferred_input, graph_input);

    if (symbol_table) {
      MaterializeSymbolicShape(graph_input, *symbol_table);
    }
  }

  // Future: pass input_data into InferShapes either directly, or indirectly by
  // updating initializers that match subgraph inputs.
  (void)input_data;
  ShapeInferenceOptions options{};
  InferShapesImpl(
      g_,
      *context_->outer_scope_value_types_by_name, // never null
      context_->opset_imports,
      options,
      symbol_table,
      context_->model_local_functions,
      context_->schema_registry,
      context_->generated_shape_data_by_name);

  std::vector<const TypeProto*> graph_output_types;
  graph_output_types.reserve(g_->output().size());
  for (const ValueInfoProto& output : g_->output()) {
    graph_output_types.push_back(&output.type());
  }

  return graph_output_types;
}

std::string GetErrorWithNodeInfo(const NodeProto& n, std::runtime_error err) {
  std::string op_name = n.has_name() ? (", node name: " + n.name()) : "";
  return "(op_type:" + n.op_type() + op_name + "): " + err.what();
}

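// Registers all dim_param symbols already present in the graph (and,
// recursively, in any graph-valued attributes of its nodes) with the symbol
// table, so that newly generated symbols do not collide with existing ones.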
void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbol_table) {
  symbol_table.addFromGraph(g);
  for (const auto& n : g.node()) {
    for (auto& attr : n.attribute()) {
      if (attr.has_g()) {
        TraverseGraphsToAddExistingSymbols(attr.g(), symbol_table);
      }
    }
  }
}

} // namespace shape_inference
} // namespace ONNX_NAMESPACE