/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/shape_inference/implementation.h"
#include <fstream>
#include <list>
#include "onnx/checker.h"
#include "onnx/common/file_utils.h"
#include "onnx/defs/data_type_utils.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace shape_inference {
namespace {

std::string GetValueCaseString(const TypeProto& type) {
  switch (type.value_case()) {
    case TypeProto::ValueCase::kTensorType:
      return "tensor_type";
    case TypeProto::ValueCase::kSequenceType:
      return "sequence_type";
    case TypeProto::ValueCase::kMapType:
      return "map_type";
    case TypeProto::ValueCase::kOptionalType:
      return "optional_type";
#ifdef ONNX_ML
    case TypeProto::ValueCase::kOpaqueType:
      return "opaque_type";
#endif
    case TypeProto::ValueCase::kSparseTensorType:
      return "sparse_tensor_type";
    case TypeProto::ValueCase::VALUE_NOT_SET:
      return "NOT_SET";
    default:
      return ONNX_NAMESPACE::to_string(type.value_case());
  }
}

std::string GetElemTypeString(const TypeProto_Tensor& type) {
#ifndef ONNX_USE_LITE_PROTO
  const std::string type_str = TensorProto::DataType_Name(static_cast<TensorProto_DataType>(type.elem_type()));
  if (!type_str.empty()) {
    return type_str;
  }
#endif
  return ONNX_NAMESPACE::to_string(type.elem_type());
}

std::string GetElemTypeString(const TypeProto_SparseTensor& type) {
#ifndef ONNX_USE_LITE_PROTO
  const std::string type_str = TensorProto::DataType_Name(static_cast<TensorProto_DataType>(type.elem_type()));
  if (!type_str.empty()) {
    return type_str;
  }
#endif
  return ONNX_NAMESPACE::to_string(type.elem_type());
}

} // namespace

template <class T>
void CheckTensorShapesAndTypes(const T& inferred_type, const T& existing_type) {
  if (inferred_type.elem_type() != TensorProto::UNDEFINED && existing_type.elem_type() != TensorProto::UNDEFINED &&
      existing_type.elem_type() != inferred_type.elem_type()) {
    std::stringstream ss;
    ss << "Inferred elem type differs from existing elem type: (" << GetElemTypeString(inferred_type) << ") vs ("
       << GetElemTypeString(existing_type) << ")";
    fail_type_inference(ss.str());
  }

  if (!inferred_type.has_shape() || !existing_type.has_shape()) {
    return;
  }

  if (inferred_type.shape().dim_size() != existing_type.shape().dim_size()) {
    std::stringstream ss;
    ss << "Inferred shape and existing shape differ in rank: (" << inferred_type.shape().dim_size() << ") vs ("
       << existing_type.shape().dim_size() << ")";
    fail_shape_inference(ss.str());
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    const auto& existing_dim = existing_type.shape().dim(i);
    if (inferred_dim.has_dim_value() && existing_dim.has_dim_value() &&
        inferred_dim.dim_value() != existing_dim.dim_value()) {
      std::stringstream ss;
      ss << "Inferred shape and existing shape differ in dimension " << i << ": (" << inferred_dim.dim_value()
         << ") vs (" << existing_dim.dim_value() << ")";
      fail_shape_inference(ss.str());
    }
  }
}

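// Checks that an inferred type is compatible with an existing (declared) type.
// It fails, for example, when merging an inferred tensor of shape {2, 3} into
// an existing {2, 4} ("differ in dimension 1"), or a float tensor into an
// int64 tensor. Dims with only a dim_param (or nothing) are compatible with
// any concrete dim_value; only two conflicting dim_values are rejected.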
void checkShapesAndTypes(const TypeProto& inferred_type, const TypeProto& existing_type) {
  const auto inferred_value_case = inferred_type.value_case();
  const auto existing_value_case = existing_type.value_case();
  if (inferred_value_case == TypeProto::ValueCase::VALUE_NOT_SET ||
      existing_value_case == TypeProto::ValueCase::VALUE_NOT_SET) {
    // nothing to check; the inferred type will be assigned to the undefined existing type
    return;
  }
  if (inferred_value_case != existing_value_case) {
    fail_type_inference(
        "type case mismatch. existing=",
        GetValueCaseString(existing_type),
        " inferred=",
        GetValueCaseString(inferred_type));
  }

  if (inferred_value_case == TypeProto::kTensorType && existing_value_case == TypeProto::kTensorType) {
    CheckTensorShapesAndTypes(inferred_type.tensor_type(), existing_type.tensor_type());
  } else if (
      inferred_value_case == TypeProto::kSparseTensorType && existing_value_case == TypeProto::kSparseTensorType) {
    CheckTensorShapesAndTypes(inferred_type.sparse_tensor_type(), existing_type.sparse_tensor_type());
  } else if (inferred_value_case == TypeProto::kSequenceType && existing_value_case == TypeProto::kSequenceType) {
    checkShapesAndTypes(inferred_type.sequence_type().elem_type(), existing_type.sequence_type().elem_type());
  } else if (inferred_value_case == TypeProto::kOptionalType && existing_value_case == TypeProto::kOptionalType) {
    checkShapesAndTypes(inferred_type.optional_type().elem_type(), existing_type.optional_type().elem_type());
  } else if (inferred_value_case == TypeProto::kMapType && existing_value_case == TypeProto::kMapType) {
    if (inferred_type.map_type().key_type() != existing_type.map_type().key_type()) {
      fail_type_inference(
          "key type mismatch from MapProto. existing=",
          Utils::DataTypeUtils::ToDataTypeString(existing_type.map_type().key_type()),
          " inferred=",
          Utils::DataTypeUtils::ToDataTypeString(inferred_type.map_type().key_type()));
    }
    checkShapesAndTypes(inferred_type.map_type().value_type(), existing_type.map_type().value_type());
  } else {
    fail_type_inference("type case unsupported. existing=", existing_value_case, " inferred=", inferred_value_case);
  }
}

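// Per-dimension merging rule (see the loops below): an existing dim is
// overwritten by the inferred dim when the existing dim carries no information
// (neither dim_value nor dim_param) or when the inferred dim has a concrete
// dim_value. For example, merging inferred {?, 3, 224} into existing
// {"batch", ?, ?} yields {"batch", 3, 224}.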
void mergeShapesAndTypes(const TypeProto_Tensor& inferred_type, TypeProto_Tensor* existing_type) {
  if (existing_type->elem_type() == TensorProto::UNDEFINED) {
    existing_type->set_elem_type(inferred_type.elem_type());
  }

  if (!inferred_type.has_shape()) {
    return;
  }

  if (!existing_type->has_shape()) {
    *existing_type->mutable_shape() = inferred_type.shape();
    return;
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    auto* existing_dim = existing_type->mutable_shape()->mutable_dim(i);
    if ((!existing_dim->has_dim_value() && !existing_dim->has_dim_param()) || inferred_dim.has_dim_value()) {
      *existing_dim = inferred_dim;
    }
  }
}

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferred_type, TypeProto_SparseTensor* existing_type) {
  if (existing_type->elem_type() == TensorProto::UNDEFINED) {
    existing_type->set_elem_type(inferred_type.elem_type());
  }

  if (!inferred_type.has_shape()) {
    return;
  }

  if (!existing_type->has_shape()) {
    *existing_type->mutable_shape() = inferred_type.shape();
    return;
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    auto* existing_dim = existing_type->mutable_shape()->mutable_dim(i);
    if ((!existing_dim->has_dim_value() && !existing_dim->has_dim_param()) || inferred_dim.has_dim_value()) {
      *existing_dim = inferred_dim;
    }
  }
}

void mergeShapesAndTypes(const TypeProto& inferred_type, TypeProto* existing_type) {
  // Check before merge
  checkShapesAndTypes(inferred_type, *existing_type);
  const auto inferred_val_case = inferred_type.value_case();
  if (inferred_val_case == TypeProto::kTensorType) {
    mergeShapesAndTypes(inferred_type.tensor_type(), existing_type->mutable_tensor_type());
  } else if (inferred_val_case == TypeProto::kSparseTensorType) {
    mergeShapesAndTypes(inferred_type.sparse_tensor_type(), existing_type->mutable_sparse_tensor_type());
  } else if (inferred_val_case == TypeProto::kSequenceType) {
    mergeShapesAndTypes(
        inferred_type.sequence_type().elem_type(), existing_type->mutable_sequence_type()->mutable_elem_type());
  } else if (inferred_val_case == TypeProto::kOptionalType) {
    mergeShapesAndTypes(
        inferred_type.optional_type().elem_type(), existing_type->mutable_optional_type()->mutable_elem_type());
  } else if (inferred_val_case == TypeProto::kMapType) {
    mergeShapesAndTypes(inferred_type.map_type().value_type(), existing_type->mutable_map_type()->mutable_value_type());
  }
}

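// GenerateSymbolicShape (below) gives every dim that has neither a dim_value
// nor a dim_param a fresh symbolic name drawn from the symbol table, e.g.
// names of the form "unk__0", "unk__1", ... (the numeric suffix is chosen by
// SymbolTable::createNew).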
// TypeProto_Tensor or TypeProto_SparseTensor
template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferred_type, SymbolTable& symbol_table) {
  if (!inferred_type->has_shape()) {
    return;
  }
  for (int i = 0; i < inferred_type->shape().dim_size(); ++i) {
    // set a symbol if the dim has neither a dim_value nor a dim_param
    auto* dim = inferred_type->mutable_shape()->mutable_dim(i);
    if (!dim->has_dim_value() && !dim->has_dim_param()) {
      dim->set_dim_param(symbol_table.createNew("unk__"));
    }
  }
}

void MaterializeSymbolicShape(TypeProto* inferred_type, SymbolTable& symbol_table) {
  const auto inferred_val_case = inferred_type->value_case();
  if (inferred_val_case == TypeProto::ValueCase::VALUE_NOT_SET) {
    return;
  }

  if (inferred_val_case == TypeProto::kTensorType) {
    GenerateSymbolicShape(inferred_type->mutable_tensor_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kSparseTensorType) {
    GenerateSymbolicShape(inferred_type->mutable_sparse_tensor_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kSequenceType) {
    MaterializeSymbolicShape(inferred_type->mutable_sequence_type()->mutable_elem_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kOptionalType) {
    MaterializeSymbolicShape(inferred_type->mutable_optional_type()->mutable_elem_type(), symbol_table);
  } else {
    fail_shape_inference("type case unsupported for symbolic shape inference. inferred=", inferred_val_case);
  }
}

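// Builds the lookup key used for the model-local function map, e.g.
// GetModelLocalFunctionsMapIdentifier("custom.domain", "Gelu") returns
// "custom.domain:Gelu".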
std::string GetModelLocalFunctionsMapIdentifier(const std::string& domain, const std::string& func_name) {
  return domain + ":" + func_name;
}

class ShapeInferenceImplBase {
 public:
  void updateType(const std::string& name, TypeProto* inferred_type) {
    if (inferred_type->value_case() == TypeProto::ValueCase::VALUE_NOT_SET) {
      return;
    }

    if (symbol_table) {
      MaterializeSymbolicShape(inferred_type, *symbol_table);
    }

    // Find any pre-existing type and shape info. If there is such,
    // then check for compatibility with the inferred
    // information. Otherwise, initialize it in an empty state.
    auto iter = value_types_by_name.find(name);
    TypeProto* existing_type = nullptr;
    if (iter != value_types_by_name.end()) {
      existing_type = iter->second;
    } else {
      // Create a new value_info if no defined type exists for this name
      auto vi = g.add_value_info(); // TODO: clean this up
      vi->set_name(name);
      existing_type = vi->mutable_type();
      // If this name was recorded with an undefined type (e.g. an output
      // declared without type information), also assign the inferred type to
      // that original entry.
      iter = undefined_value_types_by_name.find(name);
      if (iter != undefined_value_types_by_name.end()) {
        *iter->second = *inferred_type;
      }
    }

    // TODO: clean this up by merging with the previous if-else
    // Now we can merge pre-existing and inferred info
    mergeShapesAndTypes(*inferred_type, existing_type);

    // Make merged info available to further inference.
    value_types_by_name[name] = existing_type;
  }

  void updateType(ValueInfoProto& valueInfo) {
    if (valueInfo.has_type()) {
      value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
    } else {
      undefined_value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
    }
  }

  void preprocess(const NodeProto& n) {
    if (checker::check_is_experimental_op(n)) {
      has_experimental_op = true;
    } else if (n.op_type() == "Constant" && n.output().size() == 1) {
      for (const auto& attr : n.attribute()) {
        if (attr.name() == "value") {
          if (attr.type() == AttributeProto::TENSOR && attr.has_t()) {
            input_data_by_name[n.output(0)] = &attr.t();
          } else if (attr.type() == AttributeProto::SPARSE_TENSOR && attr.has_sparse_tensor()) {
            input_sparse_data_by_name[n.output(0)] = &attr.sparse_tensor();
          }
        } else if (attr.type() == AttributeProto::INTS && attr.name() == "value_ints") {
          std::vector<int64_t> ints{attr.ints().begin(), attr.ints().end()};
          input_data_by_name_holder[n.output(0)] = ToTensor(ints);
          input_data_by_name[n.output(0)] = &input_data_by_name_holder[n.output(0)];
        }
      }
    }
  }

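  // Runs type/shape inference for a single node: resolves the node's domain to
  // an opset version, looks up the op schema, and applies the schema's
  // inference function (or infers through the function body for function ops).
  // Nodes with no schema and no model-local function body are skipped and
  // flagged via has_unsupported_op rather than treated as errors.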
  void process(NodeProto& n) {
    // Resolve the opset version for the node's domain
    auto dit = opset_imports.find(n.domain());
    if (dit == opset_imports.end()) {
      // Both "" and "ai.onnx" refer to the default ONNX domain
      if (n.domain() == "") {
        dit = opset_imports.find("ai.onnx");
      }
      if (dit == opset_imports.end()) {
        fail_type_inference(
            "Cannot infer type and shape for node name ",
            n.name(),
            ". No opset import for domain ",
            n.domain(),
            " optype ",
            n.op_type());
      }
    }
    auto domain_version = dit->second;
    const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
    InferenceContextImpl ctx(
        n,
        value_types_by_name,
        input_data_by_name,
        input_sparse_data_by_name,
        generated_shape_data_by_name,
        &graph_inference_context);

    ONNX_TRY {
      if (schema) {
        if (schema->has_type_and_shape_inference_function()) {
          schema->GetTypeAndShapeInferenceFunction()(ctx);
        } else if (schema->HasFunction()) {
          InferShapeForFunctionNode(
              *(schema->GetFunction()),
              schema_registry,
              ctx,
              options,
              model_local_functions_map,
              symbol_table,
              generated_shape_data_by_name);
        } else {
          // Continue with inference for remaining nodes
          return;
        }
      } else if (model_local_functions_map.size() > 0) {
        auto iter = model_local_functions_map.find(GetModelLocalFunctionsMapIdentifier(n.domain(), n.op_type()));
        if (iter != model_local_functions_map.end()) {
          InferShapeForFunctionNode(
              *(iter->second),
              schema_registry,
              ctx,
              options,
              model_local_functions_map,
              symbol_table,
              generated_shape_data_by_name);
        } else {
          has_unsupported_op = true;
          return;
        }
      } else {
        has_unsupported_op = true;
        return;
      }
    }
    ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
      ONNX_HANDLE_EXCEPTION([&]() {
        // Inference for unsupported and experimental operators is best-effort,
        // so failures for them are not recorded as errors.
        if (!has_unsupported_op && !has_experimental_op) {
          inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
        }
      });
      // Continue with inference for remaining nodes
      return;
    }

    ONNX_TRY {
      // check type equality for inputs and outputs
      if (options.check_type && schema) {
        schema->CheckInputOutputType(ctx);
      }

      for (int i = 0; i < n.output_size(); ++i) {
        // skip type and shape propagation for missing optional outputs.
        if (!n.output(i).empty())
          updateType(n.output(i), ctx.getOutputType(i));
      }

      preprocess(n);

      // If data propagation is enabled, propagate shape data if it exists.
      if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
        if (generated_shape_data_by_name == nullptr) {
          fail_shape_inference(
              "Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
        }
        DataPropagationContextImpl data_propagation_ctx(
            n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
        schema->GetDataPropagationFunction()(data_propagation_ctx);
      }
    }
    ONNX_CATCH(const std::runtime_error& err) {
      ONNX_HANDLE_EXCEPTION([&]() { fail_shape_inference(GetErrorWithNodeInfo(n, err)); });
    }
  }

  // TypeProto_Tensor or TypeProto_SparseTensor
  template <typename T>
  void processInitializer(
      const std::string& name,
      const T& tensorValue,
      TypeProto& initializer_type,
      std::unordered_map<std::string, const T*>& map) {
    map[name] = &tensorValue;
    auto iter = value_types_by_name.find(name);
    // If the value also appears as a graph input, check that the input and the
    // initializer are in sync; shape info from the input takes priority over
    // the initializer.
    if (iter != value_types_by_name.end()) {
      checkShapesAndTypes(initializer_type, *iter->second);
    }
    // Support IR >= 4: a tensor may exist only as an initializer without a
    // matching graph input, so shape inference must use the initializer's
    // shape. Store the initializer's type info in value_info as well.
    else if (ir_version >= 4) {
      initializer_type_list.push_back(std::move(initializer_type));
      value_types_by_name[name] = &initializer_type_list.back();
    }
  }

  void process(GraphProto& graph) {
    if (symbol_table) {
      TraverseGraphsToAddExistingSymbols(graph, *symbol_table);
    }
    for (auto& vi : *graph.mutable_value_info()) {
      updateType(vi);
    }
    for (auto& vi : *graph.mutable_input()) {
      updateType(vi);
    }
    for (auto& vi : *graph.mutable_output()) {
      updateType(vi);
    }
    for (const auto& tp : graph.initializer()) {
      TypeProto initializer_type;
      TypeProto_Tensor* initializer_tensor_type = initializer_type.mutable_tensor_type();
      initializer_tensor_type->set_elem_type(tp.data_type());
      // set the shape according to the initializer shape info
      auto* shape = initializer_tensor_type->mutable_shape();
      for (int i = 0; i < tp.dims_size(); ++i) {
        shape->add_dim()->set_dim_value(tp.dims(i));
      }
      processInitializer(tp.name(), tp, initializer_type, input_data_by_name);
    }
    for (const auto& tp : graph.sparse_initializer()) {
      TypeProto initializer_type;
      auto* initializer_sparse_tensor_type = initializer_type.mutable_sparse_tensor_type();
      initializer_sparse_tensor_type->set_elem_type(tp.values().data_type());
      // set the shape according to the initializer shape info
      auto* shape = initializer_sparse_tensor_type->mutable_shape();
      for (int i = 0; i < tp.dims_size(); ++i) {
        shape->add_dim()->set_dim_value(tp.dims(i));
      }
      processInitializer(tp.values().name(), tp, initializer_type, input_sparse_data_by_name);
    }
    for (auto& n : *graph.mutable_node()) {
      process(n);
    }
    // Throw shape inference errors, if any. Error mode currently supports only
    // 0 and 1. When set to 0, node-level shape inference errors are not
    // thrown, for backward compatibility with release 1.7 and earlier. When
    // set to 1, all exceptions are thrown.
    // TODO: Add a more granular way for exception handling.
    if (options.error_mode > 0 && !inference_errors.empty()) {
      std::string full_errors = "Shape inference error(s): ";
      for (const std::string& error : inference_errors) {
        full_errors += error + "\n";
      }
      fail_shape_inference(full_errors);
    }
  }

  void process(const NodeProto& n, std::unordered_map<std::string, const AttributeProto*> attr_map) {
    NodeProto copy_n(n);
    // Add attribute information into the temporary node
    copy_n.clear_attribute();
    for (const auto& attr : n.attribute()) {
      if (attr.has_ref_attr_name()) {
        if (attr_map.count(attr.ref_attr_name())) {
          auto copy_attr = *attr_map[attr.ref_attr_name()];
          copy_attr.set_name(attr.name());
          copy_n.add_attribute()->CopyFrom(copy_attr);
        }
      } else {
        copy_n.add_attribute()->CopyFrom(attr);
      }
    }
    process(copy_n);
  }

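  // Infers types/shapes through a function body: binds the caller's input
  // types (and any statically known input values) to the function's formal
  // inputs, resolves attribute references against the call site, processes
  // each node in the body, and finally copies the inferred output types back
  // into the caller's InferenceContext.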
  void process(const FunctionProto& func_proto, InferenceContext& ctx) {
    // Build a temporary tensor-shape map from the call-site input types
    const auto num_func_inputs = func_proto.input_size();
    std::vector<TypeProto> types_cache(num_func_inputs);
    for (int i = 0; i < num_func_inputs; ++i) {
      if (ctx.getInputType(i) == nullptr) {
        fail_type_inference("Input ", i, " type is missing.");
      }
      types_cache[i] = *ctx.getInputType(i); // TODO: investigate whether we can remove the cache
      value_types_by_name[func_proto.input().Get(i)] = &types_cache[i];
    }

    // Create a temporary initializer value map
    for (int i = 0; i < static_cast<int>(ctx.getNumInputs()) && i < num_func_inputs; ++i) {
      const TypeProto* type = ctx.getInputType(i);
      if (type->value_case() == TypeProto::kTensorType && ctx.getInputData(i) != nullptr) {
        input_data_by_name[func_proto.input().Get(i)] = ctx.getInputData(i);
      } else if (type->value_case() == TypeProto::kSparseTensorType && ctx.getInputSparseData(i) != nullptr) {
        input_sparse_data_by_name[func_proto.input().Get(i)] = ctx.getInputSparseData(i);
      }
    }

    std::unordered_map<std::string, const AttributeProto*> attr_map;
    for (auto& attr : func_proto.attribute()) {
      if (ctx.getAttribute(attr) != nullptr) {
        attr_map[attr] = ctx.getAttribute(attr);
      }
    }

    for (auto& n : func_proto.node()) {
      process(n, attr_map);
    }

    for (int i = 0; i < func_proto.output_size(); ++i) {
      const std::string& output_name = func_proto.output().Get(i);
      // Skip if no type was inferred for the tensor
      auto iter = value_types_by_name.find(output_name);
      if (iter != value_types_by_name.cend()) {
        // Copy the type info to ctx to pass it back to the main graph
        auto type_proto = ctx.getOutputType(i);
        type_proto->CopyFrom(*(iter->second));
      }
    }
  }

 public:
  ShapeInferenceImplBase(
      GraphProto* g_in,
      const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
      const std::unordered_map<std::string, int>& opset_imports_in,
      const ShapeInferenceOptions& options_in,
      SymbolTable* symbol_table_in,
      const ModelLocalFunctionsMap& model_local_functions_map_in,
      const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
      std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name_in = nullptr,
      const int ir_version_in = IR_VERSION // defaults to the latest IR version
      )
      : g(*g_in),
        value_types_by_name(outer_scope_value_types_by_name_in),
        opset_imports(opset_imports_in),
        options(options_in),
        symbol_table(symbol_table_in),
        model_local_functions_map(model_local_functions_map_in),
        schema_registry(schema_registry_in),
        generated_shape_data_by_name(generated_shape_data_by_name_in),
        ir_version(ir_version_in),
        graph_inference_context{
            value_types_by_name,
            opset_imports,
            symbol_table,
            model_local_functions_map,
            schema_registry,
            generated_shape_data_by_name,
            ir_version} {
    if (options.enable_data_propagation && generated_shape_data_by_name == nullptr) {
      fail_shape_inference(
          "Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
    }
  }

 private:
  GraphProto& g;
  std::unordered_map<std::string, TypeProto*> value_types_by_name;
  const std::unordered_map<std::string, int>& opset_imports;

  const ShapeInferenceOptions& options;
  SymbolTable* symbol_table;
  const ModelLocalFunctionsMap& model_local_functions_map;
  const ISchemaRegistry* schema_registry;
  std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name;
  int ir_version;
  GraphInferenceContext graph_inference_context;

  std::unordered_map<std::string, TypeProto*> undefined_value_types_by_name;
  std::unordered_map<std::string, const TensorProto*> input_data_by_name;
  std::unordered_map<std::string, TensorProto> input_data_by_name_holder;
  std::unordered_map<std::string, const SparseTensorProto*> input_sparse_data_by_name;

  bool has_experimental_op = false;
  bool has_unsupported_op = false;

  std::vector<std::string> inference_errors;

  std::list<TypeProto> initializer_type_list;
};

static void InferShapesImpl(
    GraphProto* g,
    const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name,
    const std::unordered_map<std::string, int>& opset_imports,
    const ShapeInferenceOptions& options,
    SymbolTable* symbol_table,
    const ModelLocalFunctionsMap& model_local_functions_map,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr,
    const int ir_version = IR_VERSION // defaults to the latest IR version
) {
  std::unordered_map<std::string, TensorShapeProto> empty;
  if (generated_shape_data_by_name == nullptr) {
    generated_shape_data_by_name = &empty;
  }
  ShapeInferenceImplBase base(
      g,
      outer_scope_value_types_by_name,
      opset_imports,
      options,
      symbol_table,
      model_local_functions_map,
      schema_registry,
      generated_shape_data_by_name,
      ir_version);
  base.process(*g);
}

// Either ModelProto or FunctionProto
template <class T>
std::unordered_map<std::string, int> GetOpsetImportsFromProto(const T& proto) {
  std::unordered_map<std::string, int> opset_imports;
  for (const auto& opset_import : proto.opset_import()) {
    opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
  }
  return opset_imports;
}

void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions) {
  SymbolTableImpl symbol_table;
  InferShapesImpl(
      g,
      std::unordered_map<std::string, TypeProto*>(0),
      opset_imports,
      options,
      &symbol_table,
      model_local_functions,
      schema_registry);
}

void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  auto opset_imports = GetOpsetImportsFromProto(m);
  SymbolTableImpl symbol_table;
  ModelLocalFunctionsMap model_local_functions_by_id;
  for (const auto& function_proto : m.functions()) {
    model_local_functions_by_id.insert(
        {GetModelLocalFunctionsMapIdentifier(function_proto.domain(), function_proto.name()), &function_proto});
  }
  InferShapesImpl(
      m.mutable_graph(),
      std::unordered_map<std::string, TypeProto*>(0),
      opset_imports,
      options,
      &symbol_table,
      model_local_functions_by_id,
      schema_registry,
      generated_shape_data_by_name,
      m.ir_version());
}

void InferShapes(
    const std::string& model_path,
    const std::string& save_path,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  ModelProto model;
  LoadProtoFromPath(model_path, model);
  InferShapes(model, schema_registry, options, generated_shape_data_by_name);
  // Save the inferred model to save_path.
  // Use SerializeToString instead of SerializeToOstream due to LITE_PROTO.
  std::fstream output(save_path, std::ios::out | std::ios::trunc | std::ios::binary);
  std::string model_string;
  ONNX_TRY {
    model.SerializeToString(&model_string);
    output << model_string;
  }
  ONNX_CATCH(...) {
    fail_check("Unable to save inferred model to the target path: ", save_path);
  }
}
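// A minimal caller-side sketch of the file-based overload above (hypothetical
// usage, not part of this translation unit; the file names are placeholders
// and default-constructed options are assumed):
//
//   ShapeInferenceOptions options{};
//   InferShapes("model.onnx", "model.inferred.onnx",
//               OpSchemaRegistry::Instance(), options, nullptr);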

// Infer shapes for a function body
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions_map,
    SymbolTable* symbol_table,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  GraphProto g;
  ShapeInferenceImplBase base(
      &g,
      {}, // outer_scope_value_types_by_name
      func_opset_imports,
      options,
      symbol_table,
      model_local_functions_map,
      schema_registry,
      generated_shape_data_by_name);
  base.process(func_proto, ctx);
}

void InferShapeForFunctionNode(
    const FunctionProto& function_proto,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions_map,
    SymbolTable* symbol_table,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  auto opset_imports = GetOpsetImportsFromProto(function_proto);
  InferShapeForFunctionNode(
      function_proto,
      opset_imports,
      schema_registry,
      ctx,
      options,
      model_local_functions_map,
      symbol_table,
      generated_shape_data_by_name);
}

std::vector<const TypeProto*> GraphInferencerImpl::doInferencing(
    const std::vector<const TypeProto*>& input_types,
    const std::vector<const TensorProto*>& input_data) {
  SymbolTable* symbol_table = context_->symbol_table;
  int num_inputs = static_cast<int>(input_types.size());
  std::unordered_set<std::string> initializer_name_set;
  for (const auto& tp : g_->initializer()) {
    initializer_name_set.insert(tp.name());
  }

  if (context_->ir_version >= 4) {
    if (g_->input_size() != num_inputs) {
      fail_shape_inference("Graph has ", g_->input_size(), " inputs but ", num_inputs, " were provided");
    }
    for (int i = 0; i < g_->input_size(); ++i) {
      if (initializer_name_set.count(g_->input(i).name()) > 0) {
        fail_shape_inference(
            "Cannot use the same name as both a subgraph initializer and subgraph input: ", g_->input(i).name());
      }
    }
  } else {
    // IR < 4 requires every initializer to also be declared as an (optional)
    // graph input, so the graph may declare more inputs than the node provides.
    if (num_inputs > g_->input_size()) {
      fail_shape_inference(
          "Graph has ",
          g_->input_size(),
          " inputs but ",
          num_inputs,
          " were provided. ",
          "The number of graph inputs cannot be smaller than the number of node inputs.");
    } else if (num_inputs < g_->input_size()) {
      for (int i = 0; i < g_->input_size(); ++i) {
        if (i < num_inputs && initializer_name_set.count(g_->input(i).name()) > 0) {
          fail_shape_inference("Graph initializer names must appear after the actual inputs: ", g_->input(i).name());
        } else if (i >= num_inputs && initializer_name_set.count(g_->input(i).name()) == 0) {
          // The additional graph inputs must be covered by initializers
          fail_shape_inference("Cannot find missing input: ", g_->input(i).name(), " in initializers.");
        }
      }
    }
  }

  for (int i = 0, end = num_inputs; i < end; ++i) {
    const TypeProto* inferred_input = input_types[i];

    if (!inferred_input)
      continue;

    TypeProto* graph_input = g_->mutable_input(i)->mutable_type();
    // Even if graph_input has no defined type yet, mergeShapesAndTypes will
    // assign the inferred type to it.
    mergeShapesAndTypes(*inferred_input, graph_input);

    if (symbol_table) {
      MaterializeSymbolicShape(graph_input, *symbol_table);
    }
  }

  // future: pass input_data into InferShapes either directly, or indirectly by
  // updating initializers that match subgraph inputs.
  (void)input_data;
  ShapeInferenceOptions options{};
  InferShapesImpl(
      g_,
      *context_->outer_scope_value_types_by_name, // never null
      context_->opset_imports,
      options,
      symbol_table,
      context_->model_local_functions,
      context_->schema_registry,
      context_->generated_shape_data_by_name);

  std::vector<const TypeProto*> graph_output_types;
  graph_output_types.reserve(g_->output().size());
  for (const ValueInfoProto& output : g_->output()) {
    graph_output_types.push_back(&output.type());
  }

  return graph_output_types;
}

std::string GetErrorWithNodeInfo(const NodeProto& n, std::runtime_error err) {
  std::string op_name = n.has_name() ? (", node name: " + n.name()) : "";
  return "(op_type:" + n.op_type() + op_name + "): " + err.what();
}

void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbol_table) {
  symbol_table.addFromGraph(g);
  for (const auto& n : g.node()) {
    for (auto& attr : n.attribute()) {
      if (attr.has_g()) {
        TraverseGraphsToAddExistingSymbols(attr.g(), symbol_table);
      }
    }
  }
}

} // namespace shape_inference
} // namespace ONNX_NAMESPACE