/*
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace shape_inference {

using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>;
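
// Usage sketch (hypothetical names): keys follow the "<domain>:<name>" function-id
// convention documented for the InferShapes* entry points below.
//
//   FunctionProto my_fn;  // assume domain "custom.ns", name "MyFun"
//   ModelLocalFunctionsMap local_fns;
//   local_fns[my_fn.domain() + ":" + my_fn.name()] = &my_fn;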

class SymbolTableImpl : public SymbolTable {
 public:
  SymbolTableImpl() : index_(0) {}

  void addFromGraph(const GraphProto& g) {
    AddExistingSymbolicDims(g.input());
    AddExistingSymbolicDims(g.output());
    AddExistingSymbolicDims(g.value_info());
  }

  // Creates a new unique symbol with the given prefix, adds it to the SymbolTable,
  // and returns the newly created symbol.
  std::string createNew(const std::string& symbol_prefix = "unk__") {
    std::string newSymbol;
    do {
      newSymbol = symbol_prefix + std::to_string(index_++);
    } while (existing_symbols.count(newSymbol) > 0);
    existing_symbols.insert(newSymbol);
    return newSymbol;
  }

 private:
  unsigned int index_;
  std::unordered_set<std::string> existing_symbols;

  // TensorTypeProto is either TypeProto_Tensor or TypeProto_SparseTensor.
  template <typename TensorTypeProto>
  void AddExistingSymbolicDims(const TensorTypeProto& tensorType) {
    if (tensorType.has_shape()) {
      for (int i = 0; i < tensorType.shape().dim_size(); ++i) {
        if (tensorType.shape().dim(i).has_dim_param()) {
          existing_symbols.insert(tensorType.shape().dim(i).dim_param());
        }
      }
    }
  }

  void AddExistingSymbolicDims(const TypeProto& typeProto) {
    const auto val_case = typeProto.value_case();
    switch (val_case) {
      case TypeProto::kTensorType:
        AddExistingSymbolicDims(typeProto.tensor_type());
        break;
      case TypeProto::kSparseTensorType:
        AddExistingSymbolicDims(typeProto.sparse_tensor_type());
        break;
      case TypeProto::kSequenceType:
        AddExistingSymbolicDims(typeProto.sequence_type().elem_type());
        break;
      case TypeProto::kOptionalType:
        AddExistingSymbolicDims(typeProto.optional_type().elem_type());
        break;
      case TypeProto::kMapType:
        AddExistingSymbolicDims(typeProto.map_type().value_type());
        break;
      default:
        break;
    }
  }

  void AddExistingSymbolicDims(const google::protobuf::RepeatedPtrField<ValueInfoProto>& protos) {
    for (const auto& proto : protos) {
      AddExistingSymbolicDims(proto.type());
    }
  }
};
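
// Usage sketch (hypothetical driver code): seed the table with every dim_param
// already present in a graph, then mint fresh symbols that cannot collide with them.
//
//   SymbolTableImpl symbols;
//   symbols.addFromGraph(graph);                    // e.g. records "batch", "seq"
//   std::string s0 = symbols.createNew();           // "unk__0" (skips taken names)
//   std::string s1 = symbols.createNew("batch_");   // "batch_1", since index_ advanced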

struct GraphInferenceContext {
  GraphInferenceContext(
      const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
      const std::unordered_map<std::string, int> opset_imports_in,
      SymbolTable* symbol_table_in = nullptr,
      const ModelLocalFunctionsMap& model_local_functions_in = {},
      const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
      std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name_in = nullptr,
      const int ir_version_in = IR_VERSION)
      : outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in},
        opset_imports{opset_imports_in},
        symbol_table{symbol_table_in},
        model_local_functions{model_local_functions_in},
        schema_registry{schema_registry_in},
        generated_shape_data_by_name{generated_shape_data_by_name_in},
        ir_version{ir_version_in} {}

  const std::unordered_map<std::string, TypeProto*>* outer_scope_value_types_by_name;
  const std::unordered_map<std::string, int> opset_imports;
  SymbolTable* symbol_table;
  // Held by reference: the map passed to the constructor must outlive this context.
  // In particular, beware of binding the default-constructed temporary `{}`.
  const ModelLocalFunctionsMap& model_local_functions;
  const ISchemaRegistry* schema_registry;
  std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name;
  const int ir_version;
};
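
// Construction sketch (hypothetical names; opset 18 is an example value). The
// symbol table and the maps passed by reference must outlive the context.
//
//   std::unordered_map<std::string, TypeProto*> outer_types = /* ... */;
//   std::unordered_map<std::string, int> opsets{{"", 18}};
//   SymbolTableImpl symbols;
//   GraphInferenceContext ctx(outer_types, opsets, &symbols);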

class GraphInferencerImpl : public GraphInferencer {
 public:
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context) : g_{&g}, context_{&context} {}

  std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) override;

 private:
  GraphProto* g_;
  GraphInferenceContext* context_;
};
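
// Call sketch (hypothetical names): given a GraphInferenceContext `ctx` and a
// mutable subgraph, infer the subgraph's output types from its input types and
// any statically-known input data (nullptr where unknown).
//
//   GraphInferencerImpl inferencer(subgraph, ctx);
//   std::vector<const TypeProto*> input_types{&state_type};
//   std::vector<const TensorProto*> input_data{nullptr};
//   std::vector<const TypeProto*> output_types = inferencer.doInferencing(input_types, input_data);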

struct InferenceContextImpl : public InferenceContext {
  InferenceContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      const std::unordered_map<std::string, const SparseTensorProto*>& inputSparseDataByName,
      std::unordered_map<std::string, TensorShapeProto>* generatedShapeData = nullptr,
      GraphInferenceContext* graphInferenceContext = nullptr)
      : graphInferenceContext_{graphInferenceContext} {
    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
      if (attr.has_g()) {
        // A mutable GraphProto is needed to run inferencing on this attribute.
        graphProtoAttributesByName_[attr.name()] = attr.mutable_g();
      }
    }

    for (const auto& input : n.input()) {
      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      // Input data can live in one of three containers:
      //   inputDataByName - when the input is a TensorProto
      //   inputSparseDataByName - when the input is a SparseTensorProto
      //   generatedShapeData - when the input was produced by partial data propagation
      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
        allInputSparseData_.push_back(nullptr);
        allShapeInputData_.push_back(nullptr);
      } else {
        allInputData_.push_back(nullptr);
        const auto inputSparseDataIter = inputSparseDataByName.find(input);
        if (inputSparseDataIter != inputSparseDataByName.cend()) {
          allInputSparseData_.push_back(inputSparseDataIter->second);
          allShapeInputData_.push_back(nullptr);
        } else {
          allInputSparseData_.push_back(nullptr);
          if (generatedShapeData != nullptr) {
            const auto inputShapeDataIter = generatedShapeData->find(input);
            if (inputShapeDataIter != generatedShapeData->cend()) {
              allShapeInputData_.push_back(&inputShapeDataIter->second);
            } else {
              allShapeInputData_.push_back(nullptr);
            }
          } else {
            allShapeInputData_.push_back(nullptr);
          }
        }
      }
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  const TensorProto* getInputData(size_t index) const override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputData_[index];
  }

  const TensorShapeProto* getSymbolicInput(size_t index) const override {
    if (index >= allShapeInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allShapeInputData_[index];
  }

  const SparseTensorProto* getInputSparseData(size_t index) const override {
    if (index >= allInputSparseData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputSparseData_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  TypeProto* getOutputType(size_t index) override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  GraphInferencer* getGraphAttributeInferencer(const std::string& attr_name) override {
    if (!graphInferenceContext_) {
      fail_type_inference("GraphProto attribute inferencing is not enabled in this InferenceContextImpl instance.");
    }

    GraphInferencer* inferencer = nullptr;

    auto entry = graphAttributeInferencers_.find(attr_name);
    if (entry == graphAttributeInferencers_.cend()) {
      // Create and cache a new GraphInferencer instance for this attribute.
      auto attrNameToGraphProto = graphProtoAttributesByName_.find(attr_name);
      if (attrNameToGraphProto == graphProtoAttributesByName_.cend()) {
        fail_type_inference("Attribute ", attr_name, " does not contain a graph.");
      }

      std::unique_ptr<GraphInferencer> new_inferencer{
          new GraphInferencerImpl(*attrNameToGraphProto->second, *graphInferenceContext_)};

      inferencer = new_inferencer.get();
      graphAttributeInferencers_.emplace(attr_name, std::move(new_inferencer));
    } else {
      inferencer = entry->second.get();
    }

    return inferencer;
  }

  std::vector<const TensorProto*> allInputData_;
  std::vector<const SparseTensorProto*> allInputSparseData_;
  std::vector<const TensorShapeProto*> allShapeInputData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
  std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  GraphInferenceContext* graphInferenceContext_;

  // Mutable because it is an internal cache of lazily-created GraphInferencer instances.
  mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_;
};
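
// Usage sketch (hypothetical names): drive one node's type/shape inference through
// its registered OpSchema. The opset version (18 here) is an example value.
//
//   InferenceContextImpl node_ctx(node, value_types_by_name, input_data_by_name, input_sparse_data_by_name);
//   const auto* schema = OpSchemaRegistry::Schema(node.op_type(), 18, node.domain());
//   if (schema != nullptr) {
//     schema->GetTypeAndShapeInferenceFunction()(node_ctx);
//   }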

struct DataPropagationContextImpl : public DataPropagationContext {
  DataPropagationContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      std::unordered_map<std::string, TensorShapeProto>& generatedShapeData)
      : generatedShapeData_(generatedShapeData) {
    size_t input_idx = 0;

    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
    }

    for (const auto& input : n.input()) {
      inputIndexToNameMap_.insert({input_idx++, input});

      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
      } else {
        allInputData_.push_back(nullptr);
      }
    }

    size_t output_idx = 0;
    for (const auto& output : n.output()) {
      outputIndexToNameMap_.insert({output_idx++, output});
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  const TypeProto* getOutputType(size_t index) const override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  // Converts a vector of integers into a TensorShapeProto of concrete dim_values.
  template <typename INTEGER>
  void vectorToTensorShapeProto(const std::vector<INTEGER>& input_vals, TensorShapeProto& converted_tsp) const {
    for (size_t i = 0; i < input_vals.size(); ++i) {
      converted_tsp.mutable_dim()->Add()->set_dim_value(input_vals[i]);
    }
  }
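
  // Example (sketch): a 1-D INT64 tensor holding {2, 3, 5} is converted into a
  // TensorShapeProto with three dims whose dim_value fields are 2, 3 and 5, i.e.
  // the tensor's *contents* are reinterpreted as a static shape:
  //
  //   TensorShapeProto tsp;
  //   vectorToTensorShapeProto(std::vector<int64_t>{2, 3, 5}, tsp);
  //   // tsp.dim(1).dim_value() == 3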

  const TensorShapeProto* getInputData(size_t index) override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    const std::string input_name = inputIndexToNameMap_.at(index);
    // First, try the result of a previous data-propagation pass.
    auto iter = generatedShapeData_.find(input_name);
    if (iter != generatedShapeData_.end()) {
      return &iter->second;
    }
    // Otherwise, fall back to the initializer if one exists.
    const auto* input_data = allInputData_[index];
    // Only a scalar (0-D tensor) or a 1-D tensor can be converted for now.
    // TODO: support tensors with more dimensions on demand.
    if (input_data != nullptr && (input_data->dims_size() == 0 || input_data->dims_size() == 1)) {
      TensorShapeProto tsp;

      if (input_data->data_type() == TensorProto_DataType_INT64) {
        vectorToTensorShapeProto(ParseData<int64_t>(input_data), tsp);
      } else if (input_data->data_type() == TensorProto_DataType_INT32) {
        vectorToTensorShapeProto(ParseData<int32_t>(input_data), tsp);
      } else {
        // Only integer types are supported to form a shape.
        return nullptr;
      }

      // Cache this TensorShapeProto from the initializer in generatedShapeData
      // for future lookups.
      auto result = generatedShapeData_.insert({input_name, std::move(tsp)});
      if (result.second) {
        return &(result.first->second);
      }
    }
    return nullptr;
  }

  void addOutputData(size_t index, TensorShapeProto&& tsp) override {
    if (index >= outputIndexToNameMap_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    auto result = generatedShapeData_.insert({outputIndexToNameMap_.at(index), std::move(tsp)});
    if (!result.second) {
      fail_shape_inference("Data for output " + ONNX_NAMESPACE::to_string(index) + " already exists.");
    }
  }

  std::vector<const TensorProto*> allInputData_;
  std::unordered_map<size_t, std::string> inputIndexToNameMap_;
  std::unordered_map<size_t, std::string> outputIndexToNameMap_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  std::unordered_map<std::string, TensorShapeProto>& generatedShapeData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
};
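
// Usage sketch (hypothetical names): propagate statically-known shape data through
// one node via its schema's data-propagation function, accumulating results in
// `generated` for downstream nodes. Opset 18 is an example value.
//
//   std::unordered_map<std::string, TensorShapeProto> generated;
//   DataPropagationContextImpl dp_ctx(node, value_types_by_name, input_data_by_name, generated);
//   const auto* schema = OpSchemaRegistry::Schema(node.op_type(), 18, node.domain());
//   if (schema != nullptr) {
//     schema->GetDataPropagationFunction()(dp_ctx);
//   }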

void checkShapesAndTypes(const TypeProto_Sequence& inferredType, const TypeProto_Sequence& existingType);

void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existingType);

template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferredType, SymbolTable& symbolTable);

void MaterializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable);

void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferredType, TypeProto_SparseTensor* existingType);

void mergeShapesAndTypes(const TypeProto_Sequence& inferredType, TypeProto_Sequence* existingType);

void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType);
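
// Merge semantics sketch (hypothetical shapes): mergeShapesAndTypes refines
// `existingType` in place with information from `inferredType`, while
// checkShapesAndTypes reports incompatibilities between the two.
//
//   TypeProto_Tensor inferred;  // shape ["N", 4]
//   TypeProto_Tensor existing;  // shape [?, ?]
//   mergeShapesAndTypes(inferred, &existing);  // existing's shape becomes ["N", 4]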

///
/// ModelLocalFunctionsMap is a map of function id -> model-local FunctionProto.
/// All the ONNX helper utilities expect the function id to be
/// "<function_proto.domain>:<function_proto.name>".
///
void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& in_model_functions = {});
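
// Usage sketch (hypothetical names): infer shapes in place for a bare GraphProto;
// inferred value types are recorded in the graph's value_info. Opset 18 is an
// example value.
//
//   std::unordered_map<std::string, int> opset_imports{{"", 18}};
//   InferShapes(&graph, opset_imports);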

void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr);
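
// Usage sketch (hypothetical names): the ModelProto overload reads opset imports
// from the model itself and can optionally collect data-propagation results.
//
//   std::unordered_map<std::string, TensorShapeProto> generated;
//   InferShapes(model, OpSchemaRegistry::Instance(), {}, &generated);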

void InferShapes(
    const std::string& model_path,
    const std::string& save_path = "",
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr);
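
// Usage sketch (hypothetical paths): run inference on a serialized model and write
// the result back to disk.
//
//   InferShapes("model.onnx", "model.inferred.onnx");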

///
/// ModelLocalFunctionsMap is a map of function id -> model-local FunctionProto.
/// All the ONNX helper utilities expect the function id to be
/// "<function_proto.domain>:<function_proto.name>".
///
void InferShapeForFunctionNode(
    const FunctionProto& func,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr);

///
/// Same as above, but uses the given opset imports for the function body instead
/// of the ones declared in the FunctionProto itself.
///
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name = nullptr);
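
// Usage sketch (hypothetical names): when `node` invokes a model-local function,
// look it up by its "<domain>:<name>" id and infer through its body, reusing the
// node's InferenceContext.
//
//   auto it = local_functions.find(node.domain() + ":" + node.op_type());
//   if (it != local_functions.end()) {
//     InferShapeForFunctionNode(*it->second, schema_registry, node_ctx);
//   }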

std::string GetErrorWithNodeInfo(const NodeProto& n, std::runtime_error err);

void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbolTable);

} // namespace shape_inference
} // namespace ONNX_NAMESPACE