1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #pragma once |
6 | |
7 | #include "onnx/defs/function.h" |
8 | #include "onnx/defs/schema.h" |
9 | #include "onnx/proto_utils.h" |
10 | #include "onnx/string_utils.h" |
11 | |
12 | namespace ONNX_NAMESPACE { |
13 | namespace shape_inference { |
14 | |
15 | using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>; |
16 | |
17 | class SymbolTableImpl : public SymbolTable { |
18 | public: |
19 | SymbolTableImpl() : index_(0) {} |
20 | |
21 | void addFromGraph(const GraphProto& g) { |
22 | AddExistingSymbolicDims(g.input()); |
23 | AddExistingSymbolicDims(g.output()); |
24 | AddExistingSymbolicDims(g.value_info()); |
25 | } |
26 | // Creates a new unique symbol with the given prefix and adds it to the SymbolTable |
27 | // Returns the newly created symbol |
28 | std::string createNew(const std::string& symbol_prefix = "unk__" ) { |
29 | std::string newSymbol; |
30 | do { |
31 | newSymbol = symbol_prefix + std::to_string(index_++); |
32 | } while (existing_symbols.count(newSymbol) > 0); |
33 | existing_symbols.insert(newSymbol); |
34 | return newSymbol; |
35 | } |
36 | |
37 | private: |
38 | unsigned int index_; |
39 | std::unordered_set<std::string> existing_symbols; |
40 | |
41 | // TypeProto_Tensor or TypeProto_SparseTensor |
42 | template <typename TensorTypeProto> |
43 | void AddExistingSymbolicDims(const TensorTypeProto& tensorType) { |
44 | if (tensorType.has_shape()) { |
45 | for (int i = 0; i < tensorType.shape().dim_size(); ++i) { |
46 | if (tensorType.shape().dim(i).has_dim_param()) { |
47 | existing_symbols.insert(tensorType.shape().dim(i).dim_param()); |
48 | } |
49 | } |
50 | } |
51 | } |
52 | |
53 | void AddExistingSymbolicDims(const TypeProto& typeProto) { |
54 | const auto val_case = typeProto.value_case(); |
55 | switch (val_case) { |
56 | case TypeProto::kTensorType: |
57 | AddExistingSymbolicDims(typeProto.tensor_type()); |
58 | break; |
59 | case TypeProto::kSparseTensorType: |
60 | AddExistingSymbolicDims(typeProto.sparse_tensor_type()); |
61 | break; |
62 | case TypeProto::kSequenceType: |
63 | AddExistingSymbolicDims(typeProto.sequence_type().elem_type()); |
64 | break; |
65 | case TypeProto::kOptionalType: |
66 | AddExistingSymbolicDims(typeProto.optional_type().elem_type()); |
67 | break; |
68 | case TypeProto::kMapType: |
69 | AddExistingSymbolicDims(typeProto.map_type().value_type()); |
70 | break; |
71 | default: |
72 | break; |
73 | } |
74 | } |
75 | |
76 | void AddExistingSymbolicDims(const google::protobuf::RepeatedPtrField<ValueInfoProto>& protos) { |
77 | for (const auto& proto : protos) { |
78 | AddExistingSymbolicDims(proto.type()); |
79 | } |
80 | } |
81 | }; |
82 | |
83 | struct GraphInferenceContext { |
84 | GraphInferenceContext( |
85 | const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in, |
86 | const std::unordered_map<std::string, int> opset_imports_in, |
87 | SymbolTable* symbol_table_in = nullptr, |
88 | const ModelLocalFunctionsMap& model_local_functions_in = {}, |
89 | const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(), |
90 | std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name_in = nullptr, |
91 | const int ir_version_in = IR_VERSION) |
92 | : outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in}, |
93 | opset_imports{opset_imports_in}, |
94 | symbol_table{symbol_table_in}, |
95 | model_local_functions{model_local_functions_in}, |
96 | schema_registry{schema_registry_in}, |
97 | generated_shape_data_by_name{generated_shape_data_by_name_in}, |
98 | ir_version{ir_version_in} {} |
99 | |
100 | const std::unordered_map<std::string, TypeProto*>* outer_scope_value_types_by_name; |
101 | const std::unordered_map<std::string, int> opset_imports; |
102 | SymbolTable* symbol_table; |
103 | const ModelLocalFunctionsMap& model_local_functions; |
104 | const ISchemaRegistry* schema_registry; |
105 | std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name; |
106 | const int ir_version; |
107 | }; |
108 | |
// GraphInferencer for a subgraph (graph-valued attribute). Holds non-owning
// pointers to the subgraph and the enclosing GraphInferenceContext; both must
// outlive this object.
class GraphInferencerImpl : public GraphInferencer {
 public:
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context) : g_{&g}, context_{&context} {}

  // Runs type/shape inference over the subgraph given the input types and any
  // statically-known input data; returns the inferred output types.
  // Defined out of line (in the corresponding .cc file).
  std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) override;

 private:
  GraphProto* g_; // not owned
  GraphInferenceContext* context_; // not owned
};
121 | |
// InferenceContext implementation backed by a single NodeProto plus lookup
// tables prepared by the caller. The constructor builds index-aligned,
// per-input vectors (types, dense data, sparse data, propagated shape data)
// so a schema's inference function can query inputs by position. Also owns
// lazily-created GraphInferencer instances for graph-valued attributes.
struct InferenceContextImpl : public InferenceContext {
  InferenceContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      const std::unordered_map<std::string, const SparseTensorProto*>& inputSparseDataByName,
      std::unordered_map<std::string, TensorShapeProto>* generatedShapeData = nullptr,
      GraphInferenceContext* graphInferenceContext = nullptr)
      : graphInferenceContext_{graphInferenceContext} {
    // Index attributes by name; graph-valued attributes additionally get a
    // mutable GraphProto pointer for subgraph inferencing.
    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
      if (attr.has_g()) {
        // need a mutable GraphProto to run inferencing on this attribute
        graphProtoAttributesByName_[attr.name()] = attr.mutable_g();
      }
    }

    for (const auto& input : n.input()) {
      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      // input data can be in 1 of the 3 containers
      // inputDataByName - this is when input is TensorProto
      // inputSparseDataByName - this is when input is SparseTensorProto
      // generatedShapeData - this is when input was generated as part of partial data propagation
      // Each branch pushes exactly one entry onto every one of the three
      // parallel vectors (at most one of them non-null), keeping them
      // index-aligned with allInputTypes_.
      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
        allInputSparseData_.push_back(nullptr);
        allShapeInputData_.push_back(nullptr);
      } else {
        allInputData_.push_back(nullptr);
        const auto inputSparseDataIter = inputSparseDataByName.find(input);
        if (inputSparseDataIter != inputSparseDataByName.cend()) {
          allInputSparseData_.push_back(inputSparseDataIter->second);
          allShapeInputData_.push_back(nullptr);
        } else {
          allInputSparseData_.push_back(nullptr);
          if (generatedShapeData != nullptr) {
            const auto inputShapeDataIter = generatedShapeData->find(input);
            if (inputShapeDataIter != generatedShapeData->cend()) {
              allShapeInputData_.push_back(&inputShapeDataIter->second);
            } else {
              allShapeInputData_.push_back(nullptr);
            }
          } else {
            allShapeInputData_.push_back(nullptr);
          }
        }
      }
    }

    // One default-constructed TypeProto per declared output; inference
    // functions fill these in via getOutputType().
    allOutputTypes_.resize(n.output_size());
  }

  // Returns the node attribute with the given name, or nullptr if absent.
  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  // Returns the type of input `index`, or nullptr if it is not known.
  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  // Returns the statically-known dense tensor for input `index`, or nullptr.
  const TensorProto* getInputData(size_t index) const override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputData_[index];
  }

  // Returns shape data produced by partial data propagation for input
  // `index`, or nullptr if none was generated.
  const TensorShapeProto* getSymbolicInput(size_t index) const override {
    if (index >= allShapeInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }

    return allShapeInputData_[index];
  }

  // Returns the statically-known sparse tensor for input `index`, or nullptr.
  const SparseTensorProto* getInputSparseData(size_t index) const override {
    if (index >= allInputSparseData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputSparseData_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  // Returns the mutable TypeProto that inference writes output `index` into.
  TypeProto* getOutputType(size_t index) override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  // Returns (creating and caching on first use) a GraphInferencer for the
  // graph-valued attribute `attr_name`. Fails type inference if subgraph
  // inferencing was not enabled (no GraphInferenceContext was supplied) or
  // if the named attribute does not hold a graph.
  GraphInferencer* getGraphAttributeInferencer(const std::string& attr_name) override {
    if (!graphInferenceContext_) {
      fail_type_inference("GraphProto attribute inferencing is not enabled in this InferenceContextImpl instance.");
    }

    GraphInferencer* inferencer = nullptr;

    auto entry = graphAttributeInferencers_.find(attr_name);
    if (entry == graphAttributeInferencers_.cend()) {
      // create GraphInferencer instance
      auto attrNameToGraphProto = graphProtoAttributesByName_.find(attr_name);
      if (attrNameToGraphProto == graphProtoAttributesByName_.cend()) {
        fail_type_inference("Attribute ", attr_name, " does not contain a graph.");
      }

      std::unique_ptr<GraphInferencer> new_inferencer{
          new GraphInferencerImpl(*attrNameToGraphProto->second, *graphInferenceContext_)};

      inferencer = new_inferencer.get();
      graphAttributeInferencers_.emplace(attr_name, std::move(new_inferencer));
    } else {
      inferencer = entry->second.get();
    }

    return inferencer;
  }

  // Parallel per-input vectors, index-aligned with the node's input list.
  std::vector<const TensorProto*> allInputData_;
  std::vector<const SparseTensorProto*> allInputSparseData_;
  std::vector<const TensorShapeProto*> allShapeInputData_;
  // Non-owning pointers into the NodeProto passed to the constructor; the
  // node must outlive this context.
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
  std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  GraphInferenceContext* graphInferenceContext_; // not owned; may be nullptr

  // mutable as internal cache of GraphInferencer instances
  mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_;
};
273 | |
274 | struct DataPropagationContextImpl : public DataPropagationContext { |
275 | DataPropagationContextImpl( |
276 | NodeProto& n, |
277 | const std::unordered_map<std::string, TypeProto*>& valueTypesByName, |
278 | const std::unordered_map<std::string, const TensorProto*>& inputDataByName, |
279 | std::unordered_map<std::string, TensorShapeProto>& generatedShapeData) |
280 | : generatedShapeData_(generatedShapeData) { |
281 | size_t input_idx = 0; |
282 | |
283 | for (auto& attr : *n.mutable_attribute()) { |
284 | attributesByName_[attr.name()] = &attr; |
285 | } |
286 | |
287 | for (const auto& input : n.input()) { |
288 | inputIndexToNameMap_.insert({input_idx++, input}); |
289 | |
290 | auto valueTypesIter = valueTypesByName.find(input); |
291 | if (valueTypesIter != valueTypesByName.end()) { |
292 | allInputTypes_.push_back(valueTypesIter->second); |
293 | } else { |
294 | allInputTypes_.push_back(nullptr); |
295 | } |
296 | |
297 | const auto inputDataIter = inputDataByName.find(input); |
298 | if (inputDataIter != inputDataByName.cend()) { |
299 | allInputData_.push_back(inputDataIter->second); |
300 | } else { |
301 | allInputData_.push_back(nullptr); |
302 | } |
303 | } |
304 | |
305 | size_t output_idx = 0; |
306 | for (const auto& output : n.output()) { |
307 | outputIndexToNameMap_.insert({output_idx++, output}); |
308 | } |
309 | |
310 | allOutputTypes_.resize(n.output_size()); |
311 | } |
312 | |
313 | const AttributeProto* getAttribute(const std::string& name) const override { |
314 | auto iter = attributesByName_.find(name); |
315 | if (iter == attributesByName_.end()) { |
316 | return nullptr; |
317 | } else { |
318 | return iter->second; |
319 | } |
320 | } |
321 | |
322 | size_t getNumInputs() const override { |
323 | return allInputTypes_.size(); |
324 | } |
325 | |
326 | const TypeProto* getInputType(size_t index) const override { |
327 | if (index >= allInputTypes_.size()) { |
328 | ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds." ); |
329 | } |
330 | return allInputTypes_[index]; |
331 | } |
332 | |
333 | size_t getNumOutputs() const override { |
334 | return allOutputTypes_.size(); |
335 | } |
336 | |
337 | const TypeProto* getOutputType(size_t index) const override { |
338 | if (index >= allOutputTypes_.size()) { |
339 | ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds." ); |
340 | } |
341 | return &allOutputTypes_[index]; |
342 | } |
343 | |
344 | // Convert integer vector into TensorShapeProto |
345 | template <typename INTEGER> |
346 | void vectorToTensorShapeProto(const std::vector<INTEGER>& input_vals, TensorShapeProto& converted_tsp) const { |
347 | for (unsigned int i = 0; i < input_vals.size(); ++i) { |
348 | converted_tsp.mutable_dim()->Add()->set_dim_value(input_vals[i]); |
349 | } |
350 | } |
351 | |
352 | const TensorShapeProto* getInputData(size_t index) override { |
353 | if (index >= allInputData_.size()) { |
354 | ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds." ); |
355 | } |
356 | const std::string input_name = inputIndexToNameMap_.at(index); |
357 | // Gets it from previous data propagation |
358 | auto iter = generatedShapeData_.find(input_name); |
359 | if (iter != generatedShapeData_.end()) { |
360 | return &iter->second; |
361 | } |
362 | // Otherwise, gets it from initializer if it exists |
363 | const auto* input_data = allInputData_[index]; |
364 | // Only scalar (0D tensor) or 1D tensor can be converted for now |
365 | // TODO: It should support tensors with more dimension on demand |
366 | if (input_data != nullptr && (input_data->dims_size() == 0 || input_data->dims_size() == 1)) { |
367 | TensorShapeProto tsp; |
368 | |
369 | if (input_data->data_type() == TensorProto_DataType_INT64) { |
370 | vectorToTensorShapeProto(ParseData<int64_t>(input_data), tsp); |
371 | } else if (input_data->data_type() == TensorProto_DataType_INT32) { |
372 | vectorToTensorShapeProto(ParseData<int32_t>(input_data), tsp); |
373 | } else { |
374 | // Only supports integer type to form a shape |
375 | return nullptr; |
376 | } |
377 | |
378 | // Adds this TensorShapeProto from initializer into generatedShapeData |
379 | // for future use |
380 | auto result = generatedShapeData_.insert({input_name, std::move(tsp)}); |
381 | if (result.second) { |
382 | return &(result.first->second); |
383 | } |
384 | } |
385 | return nullptr; |
386 | } |
387 | |
388 | void addOutputData(size_t index, TensorShapeProto&& tsp) override { |
389 | if (index >= outputIndexToNameMap_.size()) { |
390 | ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds." ); |
391 | } |
392 | auto result = generatedShapeData_.insert({outputIndexToNameMap_.at(index), std::move(tsp)}); |
393 | if (!result.second) { |
394 | fail_shape_inference("Data for input " + ONNX_NAMESPACE::to_string(index) + " already exists." ); |
395 | } |
396 | } |
397 | |
398 | std::vector<const TensorProto*> allInputData_; |
399 | std::unordered_map<size_t, std::string> inputIndexToNameMap_; |
400 | std::unordered_map<size_t, std::string> outputIndexToNameMap_; |
401 | std::vector<const TypeProto*> allInputTypes_; |
402 | std::vector<TypeProto> allOutputTypes_; |
403 | std::unordered_map<std::string, TensorShapeProto>& generatedShapeData_; |
404 | std::unordered_map<std::string, const AttributeProto*> attributesByName_; |
405 | }; |
406 | |
// Verifies that an inferred type/shape is compatible with an existing one;
// fails shape inference on mismatch. Defined in the corresponding .cc file.
void checkShapesAndTypes(const TypeProto_Sequence& inferredType, const TypeProto_Sequence& existingType);

void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existingType);

// TensorTypeProto is TypeProto_Tensor or TypeProto_SparseTensor (see
// SymbolTableImpl). Presumably assigns fresh symbols from symbolTable to
// unknown dims of inferredType — confirm against the definition.
template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferredType, SymbolTable& symbolTable);

void MaterializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable);

// Merges an inferred type/shape into the existing one in place.
void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferredType, TypeProto_SparseTensor* existingType);

// NOTE(review): this overload merges a Sequence into a TypeProto_Tensor*,
// which is asymmetric with the overloads above — verify the parameter type
// against the out-of-line definition.
void mergeShapesAndTypes(const TypeProto_Sequence& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType);

///
/// ModelLocalFunctionsMap is a map of function id -> model local function proto
/// All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>
///
void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& in_model_functions = {});

// Infers shapes for the model in place.
void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr);

// Presumably loads the model from model_path, infers shapes, and writes the
// result to save_path when non-empty — confirm against the definition.
void InferShapes(
    const std::string& model_path,
    const std::string& save_path = "",
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr);

///
/// ModelLocalFunctionsMap is a map of function id -> model local function proto
/// All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>
///
void InferShapeForFunctionNode(
    const FunctionProto& func,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr);

///
/// ModelLocalFunctionsMap is a map of function id -> model local function proto
/// All the ONNX helper utilities expect the function id == <function_proto.domain>:<function_proto.name>
///
/// Overload taking the function's opset imports explicitly instead of
/// deriving them from the FunctionProto.
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr);

// Formats an error message that includes identifying information from node n.
std::string GetErrorWithNodeInfo(const NodeProto& n, std::runtime_error err);

// Walks g (presumably including nested subgraphs — confirm against the
// definition) and registers existing symbolic dims with symbolTable.
void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbolTable);
478 | |
479 | } // namespace shape_inference |
480 | } // namespace ONNX_NAMESPACE |
481 | |