1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include "onnx/defs/schema.h" |
6 | #include <stdexcept> |
7 | #include <unordered_set> |
8 | #include "onnx/checker.h" |
9 | #include "onnx/defs/operator_sets.h" |
10 | #include "onnx/defs/operator_sets_preview.h" |
11 | #include "onnx/defs/operator_sets_training.h" |
12 | |
13 | #ifdef ONNX_ML |
14 | #include "onnx/defs/operator_sets_ml.h" |
15 | #endif |
16 | |
17 | #include "onnx/common/assertions.h" |
18 | #include "onnx/common/stl_backports.h" |
19 | #include "onnx/defs/parser.h" |
20 | |
21 | namespace ONNX_NAMESPACE { |
// Records which ONNX schemas have been loaded into the registry:
//   -1 means ONNX schema hasn't been loaded yet
//    0 means all versions of ONNX schema have been loaded
//   any other positive integer means the ONNX schemas for that specific version have been loaded
int OpSchemaRegistry::loaded_schema_version = -1;

// Out-of-line definition of the static constexpr member. Before C++17 this is needed because the
// member is odr-used in this file (e.g. bound to the `const key_type&` parameter of std::map::find).
constexpr int OpSchema::kUninitializedSinceVersion;
28 | |
29 | // By default if opset_version_to_load=0, it registers all opset schema for all opset versions |
30 | // Otherwise, it only registers the latest schema according to opset_version_to_load |
31 | void RegisterSchema(OpSchema schema, int opset_version_to_load) { |
32 | OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(schema, opset_version_to_load); |
33 | } |
34 | |
35 | #ifndef NDEBUG |
36 | DbgOperatorSetTracker& DbgOperatorSetTracker::Instance() { |
37 | static DbgOperatorSetTracker instance; |
38 | return instance; |
39 | } |
40 | #endif |
41 | |
42 | const std::string& OpSchema::FormalParameter::GetName() const { |
43 | return name_; |
44 | } |
45 | |
46 | const DataTypeSet& OpSchema::FormalParameter::GetTypes() const { |
47 | return type_set_; |
48 | } |
49 | |
50 | DataTypeSet& OpSchema::FormalParameter::MutableTypes() { |
51 | return type_set_; |
52 | } |
53 | |
54 | const std::string& OpSchema::FormalParameter::GetTypeStr() const { |
55 | return type_str_; |
56 | } |
57 | |
58 | const std::string& OpSchema::FormalParameter::GetDescription() const { |
59 | return description_; |
60 | } |
61 | |
62 | OpSchema::FormalParameterOption OpSchema::FormalParameter::GetOption() const { |
63 | return param_option_; |
64 | } |
65 | |
66 | bool OpSchema::FormalParameter::GetIsHomogeneous() const { |
67 | return is_homogeneous_; |
68 | } |
69 | |
70 | int OpSchema::FormalParameter::GetMinArity() const { |
71 | return min_arity_; |
72 | } |
73 | |
74 | OpSchema::DifferentiationCategory OpSchema::FormalParameter::GetDifferentiationCategory() const { |
75 | return differentiation_category_; |
76 | } |
77 | |
78 | OpSchemaRegistry* OpSchemaRegistry::Instance() { |
79 | static OpSchemaRegistry instance; |
80 | return &instance; |
81 | } |
82 | |
83 | void OpSchema::CheckInputOutputType(struct InferenceContext& ctx) const { |
84 | std::unordered_map<std::string, std::string> type_constraints; |
85 | // check all input types |
86 | for (size_t in_idx = 0; in_idx < ctx.getNumInputs(); ++in_idx) { |
87 | // If the last input is Variadic by definition, checker still needs to check the rest of actual input's type |
88 | const auto& param = (in_idx < inputs_.size()) ? inputs_[in_idx] : inputs_.back(); |
89 | const auto& type_str = param.GetTypeStr(); |
90 | const auto& param_type = ctx.getInputType(in_idx); |
91 | const auto& all_types = param.GetTypes(); |
92 | if (nullptr == param_type || param_type->value_case() == TypeProto::VALUE_NOT_SET) { |
93 | continue; |
94 | } else if (!all_types.empty() && all_types.find(Utils::DataTypeUtils::ToType(*param_type)) == all_types.end()) { |
95 | fail_check( |
96 | param.GetName(), |
97 | " typestr: " , |
98 | type_str, |
99 | ", has unsupported type: " , |
100 | *Utils::DataTypeUtils::ToType(*param_type)); |
101 | } |
102 | if (param.GetIsHomogeneous()) { |
103 | const auto& type_proto = Utils::DataTypeUtils::ToType(*param_type); |
104 | auto p = type_constraints.emplace(type_str, *type_proto); |
105 | if (!p.second) { |
106 | // failed to insert a new element due to a duplication, now check consistency |
107 | if (p.first->second != *type_proto) { |
108 | fail_check(param.GetName(), " has inconsistent type " , *Utils::DataTypeUtils::ToType(*param_type)); |
109 | } |
110 | } |
111 | } |
112 | } // for inputs |
113 | // check all output types |
114 | for (size_t out_idx = 0; out_idx < ctx.getNumOutputs(); ++out_idx) { |
115 | // If the last output is Variadic by definition, checker still needs to check the rest of actual output's type |
116 | const auto& param = (out_idx < outputs_.size()) ? outputs_[out_idx] : outputs_.back(); |
117 | const auto& type_str = param.GetTypeStr(); |
118 | const auto& param_type = ctx.getOutputType(out_idx); |
119 | const auto& all_types = param.GetTypes(); |
120 | bool output_type_found = true; |
121 | // infer type if necessary |
122 | if (param_type->value_case() == TypeProto::VALUE_NOT_SET) { |
123 | if (all_types.size() == 1) { |
124 | *param_type = Utils::DataTypeUtils::ToTypeProto(*all_types.begin()); |
125 | } else if (type_constraints.find(type_str) != type_constraints.end()) { |
126 | auto data_type = Utils::DataTypeUtils::ToType(type_constraints[type_str]); |
127 | *param_type = Utils::DataTypeUtils::ToTypeProto(data_type); |
128 | } else { |
129 | output_type_found = false; |
130 | } |
131 | } |
132 | if (!output_type_found) { |
133 | continue; |
134 | } |
135 | if (!all_types.empty() && all_types.find(Utils::DataTypeUtils::ToType(*param_type)) == all_types.end()) { |
136 | fail_check(param.GetName(), " has unsupported type " , *Utils::DataTypeUtils::ToType(*param_type)); |
137 | } |
138 | if (param.GetIsHomogeneous()) { |
139 | const auto& type_proto = Utils::DataTypeUtils::ToType(*param_type); |
140 | if (type_constraints.find(type_str) == type_constraints.end()) { |
141 | type_constraints[type_str] = *type_proto; |
142 | } else if (type_constraints[type_str] != *type_proto) { |
143 | fail_check(param.GetName(), " has inconsistent type " , *Utils::DataTypeUtils::ToType(*param_type)); |
144 | } |
145 | } // else |
146 | } // for outputs |
147 | } |
148 | |
149 | void OpSchema::Verify(const NodeProto& node) const { |
150 | if (deprecated_) { |
151 | fail_check("Operator '" , name_, "' has been deprecated since version " , since_version_); |
152 | } |
153 | |
154 | // Check the number of inputs. |
155 | if (node.input_size() < min_input_ || node.input_size() > max_input_) { |
156 | fail_check( |
157 | "Node (" , |
158 | node.name(), |
159 | ") has input size " , |
160 | node.input_size(), |
161 | " not in range [min=" , |
162 | min_input_, |
163 | ", max=" , |
164 | max_input_, |
165 | "]." ); |
166 | } |
167 | |
168 | if (!num_inputs_allowed_(node.input_size())) { |
169 | fail_check("Node (" , node.name(), ") has input size " , node.input_size(), " not in allowed input sizes." ); |
170 | } |
171 | |
172 | // Check the number of outputs. |
173 | if (node.output_size() < min_output_ || node.output_size() > max_output_) { |
174 | fail_check( |
175 | "Node (" , |
176 | node.name(), |
177 | ") has output size " , |
178 | node.output_size(), |
179 | " not in range [min=" , |
180 | min_output_, |
181 | ", max=" , |
182 | max_output_, |
183 | "]." ); |
184 | } |
185 | |
186 | if (!num_outputs_allowed_(node.output_size())) { |
187 | fail_check("Node (" , node.name(), "has output size " , node.output_size(), " not in allowed output sizes." ); |
188 | } |
189 | |
190 | // Check the values of inputs / outputs |
191 | for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) { |
192 | if (in_idx >= static_cast<int>(inputs_.size())) { |
193 | if (!inputs_.empty() && Variadic == inputs_.back().GetOption()) { |
194 | // The last input formal parameter should be variadic. |
195 | break; |
196 | } else { |
197 | fail_check( |
198 | "Node (" , |
199 | node.name(), |
200 | ") has more inputs (" , |
201 | node.input_size(), |
202 | ") than declared (" , |
203 | inputs_.size(), |
204 | ") in op definition." ); |
205 | } |
206 | } |
207 | if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) { |
208 | fail_check("Node (" , node.name(), ")'s input " , in_idx, " is marked single but has an empty string in the graph" ); |
209 | } |
210 | } |
211 | |
212 | for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) { |
213 | if (out_idx >= static_cast<int>(outputs_.size())) { |
214 | if (!outputs_.empty() && Variadic == outputs_.back().GetOption()) { |
215 | // The last output formal parameter should be variadic. |
216 | break; |
217 | } else { |
218 | fail_check( |
219 | "Node (" , |
220 | node.name(), |
221 | ") has more outputs (" , |
222 | node.output_size(), |
223 | ") than declared (" , |
224 | outputs_.size(), |
225 | ") in op definition." ); |
226 | } |
227 | } |
228 | |
229 | if (node.output(out_idx).empty() && (Single == outputs_[out_idx].GetOption())) { |
230 | fail_check( |
231 | "Node (" , node.name(), ")'s output " , out_idx, " is marked single but has an empty string in the graph" ); |
232 | } |
233 | } |
234 | |
235 | // An internal symbol is defined as starting with two underscores. Attributes |
236 | // with names meeting this condition are considered implementation details |
237 | // and should be ignored for the purpose of schema checking. |
238 | auto isInternalSymbol = [](const std::string& sym) -> bool { |
239 | return sym.length() >= 2 && sym[0] == '_' && sym[1] == '_'; |
240 | }; |
241 | |
242 | // Check attributes |
243 | std::unordered_set<std::string> seen_attr_names{}; |
244 | for (const auto& attr_proto : node.attribute()) { |
245 | const auto& name = attr_proto.name(); |
246 | |
247 | if (!seen_attr_names.insert(name).second) { |
248 | fail_check("Attribute '" , name, "' appeared multiple times." ); |
249 | }; |
250 | |
251 | const auto& search = attributes_.find(name); |
252 | AttributeProto::AttributeType expected_type; |
253 | if (search != attributes_.end()) { |
254 | expected_type = search->second.type; |
255 | } else if (allows_unchecked_attributes_ || isInternalSymbol(name)) { |
256 | continue; |
257 | } else { |
258 | fail_check("Unrecognized attribute: " , name, " for operator " , node.op_type()); |
259 | } |
260 | |
261 | // Type would be UNDEFINED if not set |
262 | if (attr_proto.type() != expected_type) { |
263 | fail_check("Mismatched attribute type in '" , node.name() + " : " + name, "'" ); |
264 | } |
265 | |
266 | // ref_attr_name is only valid when non-empty |
267 | // we simply read default value if not present |
268 | if (!attr_proto.ref_attr_name().empty()) { |
269 | continue; |
270 | } |
271 | |
272 | switch (expected_type) { |
273 | // if attr_proto().type() != UNDEFINED |
274 | // we consider primitive types to be set even |
275 | // if proto3 did not output default values into the stream |
276 | // in which case we will read the default |
277 | case AttributeProto::FLOAT: |
278 | case AttributeProto::INT: |
279 | case AttributeProto::STRING: |
280 | break; |
281 | case AttributeProto::TENSOR: |
282 | if (!attr_proto.has_t()) { |
283 | fail_check("Attribute '" , name, "' is expected to have field 't'" ); |
284 | } |
285 | break; |
286 | case AttributeProto::SPARSE_TENSOR: |
287 | if (!attr_proto.has_sparse_tensor()) { |
288 | fail_check("Attribute '" , name, "' is expected to have field 'sparse_tensor'" ); |
289 | } |
290 | break; |
291 | case AttributeProto::GRAPH: |
292 | if (!attr_proto.has_g()) { |
293 | fail_check("Attribute '" , name, "' is expected to have field 'g'" ); |
294 | } |
295 | break; |
296 | case AttributeProto::TYPE_PROTO: |
297 | if (!attr_proto.has_tp()) { |
298 | fail_check("Attribute '" , name, "' is expected to have field 'type_proto'" ); |
299 | } |
300 | break; |
301 | case AttributeProto::FLOATS: |
302 | if (!attr_proto.floats_size()) { |
303 | fail_check("Attribute '" , name, "' is expected to have field 'floats'" ); |
304 | } |
305 | break; |
306 | case AttributeProto::INTS: |
307 | if (!attr_proto.ints_size()) { |
308 | fail_check("Attribute '" , name, "' is expected to have field 'ints'" ); |
309 | } |
310 | break; |
311 | case AttributeProto::STRINGS: |
312 | if (!attr_proto.strings_size()) { |
313 | fail_check("Attribute '" , name, "' is expected to have field 'strings'" ); |
314 | } |
315 | break; |
316 | case AttributeProto::TENSORS: |
317 | if (!attr_proto.tensors_size()) { |
318 | fail_check("Attribute '" , name, "' is expected to have field 'tensors'" ); |
319 | } |
320 | break; |
321 | case AttributeProto::SPARSE_TENSORS: |
322 | // Not adding check ... we should likely delete the check in all other |
323 | // cases, which will not allow us to have an empty list as a valid value |
324 | // for an attribute and this seems undesirable. |
325 | break; |
326 | case AttributeProto::GRAPHS: |
327 | if (!attr_proto.graphs_size()) { |
328 | fail_check("Attribute '" , name, "' is expected to have field 'graphs'" ); |
329 | } |
330 | break; |
331 | case AttributeProto::TYPE_PROTOS: |
332 | if (!attr_proto.type_protos_size()) { |
333 | fail_check("Attribute '" , name, "' is expected to have field 'type_protos'" ); |
334 | } |
335 | break; |
336 | default: |
337 | fail_check("Attribute '" , name, " has unknown expected type" ); |
338 | } |
339 | } |
340 | for (const auto& pair : attributes_) { |
341 | const auto& attr = pair.second; |
342 | if (!attr.required) { |
343 | continue; |
344 | } |
345 | if (!seen_attr_names.count(attr.name)) { |
346 | fail_check("Required attribute '" , attr.name, "' is missing." ); |
347 | } |
348 | } |
349 | |
350 | // Phew. All verifications passed. |
351 | } |
352 | |
353 | OpSchema& OpSchema::SinceVersion(OperatorSetVersion v) { |
354 | since_version_ = v; |
355 | |
356 | // SinceVersion is called after FunctionBody and SetContextDependentFunctionBodyBuilder are called |
357 | // when defining a op. |
358 | // FunctionBody() and SetContextDependentFunctionBodyBuilder() use -1 as the default opset_version |
359 | // default opset_version is for a FunctionProto of the same opset_version as the op's since_version_. |
360 | // It is indexed with -1 so we need to reindex it with since_version_. |
361 | // |
362 | // FunctionProtos of non-default opset_versions are for models whose opset version is higher than the op's |
363 | // opset version such that ops used in the default function_proto are no longer valid. For example: |
364 | // A model of opset version 18 contains a LayerNormalization op. |
365 | // LayerNormalization is function op whese function body uses ReduceMean op. |
366 | // LayerNormalization's since_version is 17 thus it is good for the model of opset 18. |
367 | // however, if a runtime needs to inline LayerNormalization, the inlined model has a ReduceMean op. |
368 | // ReduceMean in opset 18 is different from opset 17. |
369 | // This requires us to define more than one function body |
370 | std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it = |
371 | opset_version_to_function_builder_.find(OpSchema::kUninitializedSinceVersion); |
372 | |
373 | if (it != opset_version_to_function_builder_.cend()) { |
374 | opset_version_to_function_builder_[since_version_] = it->second; |
375 | opset_version_to_function_builder_.erase(it); |
376 | } |
377 | |
378 | std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it_function_body = |
379 | opset_version_to_function_body_.find(OpSchema::kUninitializedSinceVersion); |
380 | if (it_function_body != opset_version_to_function_body_.cend()) { |
381 | opset_version_to_function_body_[since_version_] = it_function_body->second; |
382 | UpdateFunctionProtoOpsetImportVersion(*opset_version_to_function_body_[since_version_], since_version_); |
383 | opset_version_to_function_body_.erase(it_function_body); |
384 | } |
385 | |
386 | return *this; |
387 | } |
388 | |
389 | OpSchema& OpSchema::Deprecate() { |
390 | deprecated_ = true; |
391 | return *this; |
392 | } |
393 | |
394 | OpSchema& OpSchema::NumInputs(std::set<int> allowed_input_nums) { |
395 | num_inputs_allowed_ = [MOVE_CAPTURE_IF_CPP14(allowed_input_nums)](int n) -> bool { |
396 | return allowed_input_nums.count(n); |
397 | }; |
398 | return *this; |
399 | } |
400 | |
401 | OpSchema& OpSchema::NumOutputs(std::set<int> allowed_output_nums) { |
402 | num_outputs_allowed_ = [MOVE_CAPTURE_IF_CPP14(allowed_output_nums)](int n) -> bool { |
403 | return allowed_output_nums.count(n) > 0; |
404 | }; |
405 | return *this; |
406 | } |
407 | |
408 | OpSchema& OpSchema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction) { |
409 | tensor_inference_function_ = std::move(inferenceFunction); |
410 | return *this; |
411 | } |
412 | |
413 | OpSchema& OpSchema::PartialDataPropagationFunction(DataPropagationFunction dataPropagationFunction) { |
414 | data_propagation_function_ = std::move(dataPropagationFunction); |
415 | return *this; |
416 | } |
417 | |
418 | OpSchema& OpSchema::SetSupportLevel(SupportType support) { |
419 | support_ = support; |
420 | return *this; |
421 | } |
422 | |
423 | // Functions to specify name for the operator schema. |
424 | OpSchema& OpSchema::SetName(std::string name) { |
425 | name_ = std::move(name); |
426 | return *this; |
427 | } |
428 | |
429 | OpSchema& OpSchema::SetName(const char* name) { |
430 | return SetName(std::string(name)); |
431 | } |
432 | |
433 | // Functions to specify code location for the operator schema. |
434 | OpSchema& OpSchema::SetLocation(std::string file, int line) { |
435 | file_ = std::move(file); |
436 | line_ = line; |
437 | return *this; |
438 | } |
439 | |
440 | OpSchema& OpSchema::SetLocation(const char* file, int line) { |
441 | return SetLocation(std::string(file), line); |
442 | } |
443 | |
444 | OpSchema& OpSchema::SetDomain(std::string domain) { |
445 | domain_ = std::move(domain); |
446 | return *this; |
447 | } |
448 | |
449 | OpSchema& OpSchema::SetDomain(const char* domain) { |
450 | return SetDomain(std::string(domain)); |
451 | } |
452 | |
453 | OpSchema& OpSchema::Attr(Attribute attr) { |
454 | auto name = attr.name; // copy name so we can move attr in the next line |
455 | attributes_.insert(std::make_pair(std::move(name), std::move(attr))); |
456 | return *this; |
457 | } |
458 | |
459 | OpSchema& OpSchema::Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required) { |
460 | Attr(Attribute{std::move(name), std::move(description), type, required}); |
461 | return *this; |
462 | } |
463 | |
464 | OpSchema& OpSchema::Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required) { |
465 | return Attr(std::string(name), std::string(description), type, required); |
466 | } |
467 | |
// Defines two Attr() overloads that declare an attribute with a single scalar/string default of
// C++ type `type`:
//   * (std::string name, std::string description, attr_type, default_value): fails with a schema
//     error if attr_type != attrtype; otherwise builds an AttributeProto with the name, the default
//     stored via set_<field>(), and the type, and registers it with Attr(Attribute).
//   * (const char*, const char*, attr_type, default_value): converts the C strings and forwards to
//     the overload above.
#define ATTR_SETTER_WITH_SINGLE_VALUE(type, field, attrtype) \
  OpSchema& OpSchema::Attr( \
      std::string name, std::string description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    if (attrtype != attr_type) { \
      fail_schema("Attribute specification type mismatch."); \
    } \
    AttributeProto a; \
    a.set_name(name); \
    a.set_##field(default_value); \
    a.set_type(attr_type); \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this; \
  } \
  OpSchema& OpSchema::Attr( \
      const char* name, const char* description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    return Attr(std::string(name), std::string(description), attr_type, default_value); \
  }
485 | |
// Defines Attr(std::string name, std::string description, attr_type, const std::vector<type>&)
// for list attributes whose elements are appended with add_<field>(v) (used for ints/floats).
// Fails with a schema error if attr_type != attrtype.
#define ATTR_SETTER_WITH_LIST_VALUE(type, field, attrtype) \
  OpSchema& OpSchema::Attr( \
      std::string name, \
      std::string description, \
      AttributeProto::AttributeType attr_type, \
      const std::vector<type>& default_value) { \
    if (attrtype != attr_type) { \
      fail_schema("Attribute specification type mismatch."); \
    } \
    AttributeProto a; \
    a.set_name(name); \
    a.set_type(attr_type); \
    for (const auto& v : default_value) { \
      a.add_##field(v); \
    } \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this; \
  }
504 | |
// Defines Attr(std::string name, std::string description, attr_type, const type& default_value)
// for a single message-typed default (TensorProto, GraphProto, TypeProto): the default is copied
// into the mutable_<field>() sub-message of an AttributeProto. Fails with a schema error if
// attr_type != attrtype.
// The finished AttributeProto is moved (not copied) into the Attribute, matching the other
// ATTR_SETTER_* macros; this avoids a second deep copy of potentially large tensors/graphs.
#define ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(type, field, attrtype) \
  OpSchema& OpSchema::Attr( \
      std::string name, std::string description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    if (attrtype != attr_type) { \
      fail_schema("Attribute specification type mismatch."); \
    } \
    AttributeProto a; \
    a.set_name(name); \
    *(a.mutable_##field()) = default_value; \
    a.set_type(attr_type); \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this; \
  }
518 | |
// Defines Attr(std::string name, std::string description, attr_type, const std::vector<type>&)
// for list attributes whose elements are assigned into a freshly added element,
// `*(a.add_<field>()) = v` (used for strings and message types). Fails with a schema error if
// attr_type != attrtype.
#define ATTR_SETTER_WITH_LIST_COMPLEXVALUE(type, field, attrtype) \
  OpSchema& OpSchema::Attr( \
      std::string name, \
      std::string description, \
      AttributeProto::AttributeType attr_type, \
      const std::vector<type>& default_value) { \
    if (attrtype != attr_type) { \
      fail_schema("Attribute specification type mismatch."); \
    } \
    AttributeProto a; \
    a.set_name(name); \
    a.set_type(attr_type); \
    for (const auto& v : default_value) { \
      *(a.add_##field()) = v; \
    } \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this; \
  }
537 | |
// Instantiate the typed Attr(...) overloads with default values, one per attribute kind.
// Scalars/strings use set_<field>, single messages use mutable_<field>, numeric lists use
// add_<field>(v), and string/message lists assign into add_<field>().
ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i, AttributeProto::INT)
ATTR_SETTER_WITH_SINGLE_VALUE(float, f, AttributeProto::FLOAT)
ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s, AttributeProto::STRING)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(TensorProto, t, AttributeProto::TENSOR)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(GraphProto, g, AttributeProto::GRAPH)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(TypeProto, tp, AttributeProto::TYPE_PROTO)
ATTR_SETTER_WITH_LIST_VALUE(int64_t, ints, AttributeProto::INTS)
ATTR_SETTER_WITH_LIST_VALUE(float, floats, AttributeProto::FLOATS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(std::string, strings, AttributeProto::STRINGS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(TensorProto, tensors, AttributeProto::TENSORS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(GraphProto, graphs, AttributeProto::GRAPHS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(TypeProto, type_protos, AttributeProto::TYPE_PROTOS)
550 | |
551 | OpSchema& OpSchema::AllowUncheckedAttributes() { |
552 | allows_unchecked_attributes_ = true; |
553 | return *this; |
554 | } |
555 | |
556 | OpSchema& OpSchema::Input( |
557 | int n, |
558 | std::string name, |
559 | const std::string& description, |
560 | std::string type_str, |
561 | OpSchema::FormalParameterOption param_option, |
562 | bool is_homogeneous, |
563 | int min_arity, |
564 | DifferentiationCategory differentiation_category) { |
565 | if (int(inputs_.size()) <= n) { |
566 | inputs_.resize(n + 1); |
567 | } |
568 | inputs_[n] = FormalParameter( |
569 | std::move(name), |
570 | #ifndef __ONNX_NO_DOC_STRINGS |
571 | description, |
572 | #else |
573 | std::string(), |
574 | #endif |
575 | std::move(type_str), |
576 | param_option, |
577 | is_homogeneous, |
578 | min_arity, |
579 | differentiation_category); |
580 | return *this; |
581 | } |
582 | |
583 | OpSchema& OpSchema::Input( |
584 | int n, |
585 | const char* name, |
586 | const char* description, |
587 | const char* type_str, |
588 | FormalParameterOption param_option, |
589 | bool is_homogeneous, |
590 | int min_arity, |
591 | DifferentiationCategory differentiation_category) { |
592 | return Input( |
593 | n, |
594 | std::string(name), |
595 | #ifndef __ONNX_NO_DOC_STRINGS |
596 | std::string(description), |
597 | #else |
598 | std::string(), |
599 | #endif |
600 | std::string(type_str), |
601 | param_option, |
602 | is_homogeneous, |
603 | min_arity, |
604 | differentiation_category); |
605 | } |
606 | |
607 | OpSchema& OpSchema::Output( |
608 | int n, |
609 | std::string name, |
610 | const std::string& description, |
611 | std::string type_str, |
612 | OpSchema::FormalParameterOption param_option, |
613 | bool is_homogeneous, |
614 | int min_arity, |
615 | DifferentiationCategory differentiation_category) { |
616 | if (int(outputs_.size()) <= n) { |
617 | outputs_.resize(n + 1); |
618 | } |
619 | outputs_[n] = FormalParameter( |
620 | std::move(name), |
621 | #ifndef __ONNX_NO_DOC_STRINGS |
622 | description, |
623 | #else |
624 | std::string(), |
625 | #endif |
626 | std::move(type_str), |
627 | param_option, |
628 | is_homogeneous, |
629 | min_arity, |
630 | differentiation_category); |
631 | return *this; |
632 | } |
633 | |
634 | OpSchema& OpSchema::Output( |
635 | int n, |
636 | const char* name, |
637 | const char* description, |
638 | const char* type_str, |
639 | FormalParameterOption param_option, |
640 | bool is_homogeneous, |
641 | int min_arity, |
642 | DifferentiationCategory differentiation_category) { |
643 | return Output( |
644 | n, |
645 | std::string(name), |
646 | #ifndef __ONNX_NO_DOC_STRINGS |
647 | std::string(description), |
648 | #else |
649 | std::string(), |
650 | #endif |
651 | std::string(type_str), |
652 | param_option, |
653 | is_homogeneous, |
654 | min_arity, |
655 | differentiation_category); |
656 | } |
657 | |
658 | OpSchema& |
659 | OpSchema::TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description) { |
660 | if (type_constraints_.end() != type_constraints_.find(type_str)) { |
661 | fail_schema("Duplicate type constraint name" ); |
662 | } |
663 | |
664 | DataTypeSet d; |
665 | for (const auto& t : constraints) { |
666 | d.insert(Utils::DataTypeUtils::ToType(t)); |
667 | } |
668 | type_constraints_.insert(std::make_pair(type_str, std::make_pair(d, description))); |
669 | type_constraint_params_.push_back( |
670 | TypeConstraintParam(std::move(type_str), std::move(constraints), std::move(description))); |
671 | return *this; |
672 | } |
673 | |
674 | OpSchema& OpSchema::TypeConstraint( |
675 | const char* type_str, |
676 | std::initializer_list<const char*> constraints, |
677 | const char* description) { |
678 | std::vector<std::string> constraints_vector; |
679 | constraints_vector.reserve(constraints.size()); |
680 | for (auto iter = constraints.begin(); iter != constraints.end(); ++iter) { |
681 | constraints_vector.push_back(*iter); |
682 | } |
683 | |
684 | return TypeConstraint(std::string(type_str), constraints_vector, std::string(description)); |
685 | } |
686 | |
687 | void OpSchema::ParseAndSetTypes( |
688 | /*out*/ std::vector<OpSchema::FormalParameter>* formal_parameters) { |
689 | for (auto& formal_parameter : *formal_parameters) { |
690 | auto& type = formal_parameter.GetTypeStr(); |
691 | DataTypeSet allowed_types; |
692 | auto it = type_constraints_.find(type); |
693 | if (it != type_constraints_.end()) { |
694 | allowed_types = it->second.first; |
695 | } else { |
696 | allowed_types.emplace(Utils::DataTypeUtils::ToType(type)); |
697 | } |
698 | |
699 | formal_parameter.MutableTypes() = allowed_types; |
700 | } |
701 | } |
702 | |
703 | OpSchema& OpSchema::SetContextDependentFunctionBodyBuilder( |
704 | ContextDependentFunctionBodyBuilder functionBuilder, |
705 | int opset_version) { |
706 | if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) { |
707 | opset_version_to_function_builder_[since_version_] = std::move(functionBuilder); |
708 | } else { |
709 | opset_version_to_function_builder_[opset_version] = std::move(functionBuilder); |
710 | } |
711 | return *this; |
712 | } |
713 | |
714 | bool OpSchema::BuildContextDependentFunction( |
715 | const FunctionBodyBuildContext& ctx, |
716 | FunctionProto& function_proto, |
717 | int requested_opset_version) const { |
718 | if (requested_opset_version == OpSchema::kUninitializedSinceVersion) |
719 | requested_opset_version = since_version_; |
720 | |
721 | std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it = |
722 | opset_version_to_function_builder_.upper_bound(requested_opset_version); |
723 | if (opset_version_to_function_builder_.empty() || it == opset_version_to_function_builder_.begin()) { |
724 | ONNX_THROW_EX(std::out_of_range( |
725 | std::string("Cannot find a function builder that satisfies the requested opset version: op_type = " ) + |
726 | this->name_ + ", opset_version = " + std::to_string(requested_opset_version) + "." )); |
727 | } else { |
728 | --it; |
729 | const ContextDependentFunctionBodyBuilder& body_builder = it->second; |
730 | if (!body_builder(ctx, *this, function_proto)) { |
731 | return false; |
732 | } |
733 | //// default opset import may have been added to function_proto by OpSchema::BuildFunction |
734 | //// we need to update its version with the specified opset_version |
735 | UpdateFunctionProtoOpsetImportVersion(function_proto, requested_opset_version); |
736 | ValidateReferencedOpsInFuncton(&function_proto, requested_opset_version, it->first); |
737 | return true; |
738 | } |
739 | } |
740 | |
741 | // A function of a schema (either stored in opset_version_to_function_body_ or built with one of function builder |
742 | // in opset_version_to_function_builder_) has predefined opset_imports. Before returning the function, we shall |
743 | // update the predefined opset_imports so that it is consistent with the requested version. |
744 | // Note that this call only update opset_import of the default domain. |
745 | // TODO: extend this call to work for no-default domains. |
746 | void OpSchema::UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int requested_opset_version) const { |
747 | bool opset_import_exist = false; |
748 | for (int i = 0; i < function_proto.opset_import_size(); i++) { |
749 | auto* schema_opset = function_proto.mutable_opset_import(i); |
750 | if (schema_opset->domain() == domain_) { |
751 | if (schema_opset->version() != requested_opset_version) { |
752 | schema_opset->set_version(requested_opset_version); |
753 | } |
754 | opset_import_exist = true; |
755 | } |
756 | } |
757 | |
758 | if (!opset_import_exist) { |
759 | auto* schema_opset = function_proto.mutable_opset_import()->Add(); |
760 | schema_opset->set_domain(domain_); |
761 | schema_opset->set_version(requested_opset_version); |
762 | } |
763 | } |
764 | |
765 | OpSchema& OpSchema::FunctionBody(const char* func_body, int opset_version) { |
766 | if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) { |
767 | opset_version = since_version_; |
768 | } |
769 | std::shared_ptr<FunctionProto> function_proto(new FunctionProto()); |
770 | OnnxParser parser(func_body); |
771 | auto status = parser.Parse(*function_proto->mutable_node()); |
772 | if (!status.IsOK()) |
773 | ONNX_THROW_EX(std::logic_error("Error parsing function body:" + status.ErrorMessage())); |
774 | if (!parser.EndOfInput()) |
775 | ONNX_THROW_EX(std::logic_error("Extra unparsed input unexpected." )); |
776 | |
777 | // opset import may have been set |
778 | // we may need to update its version with the specified opset_version |
779 | UpdateFunctionProtoOpsetImportVersion(*function_proto, opset_version); |
780 | |
781 | opset_version_to_function_body_.insert(std::make_pair(opset_version, function_proto)); |
782 | return *this; |
783 | } |
784 | |
785 | OpSchema& OpSchema::FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version) { |
786 | if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) { |
787 | opset_version = since_version_; |
788 | } |
789 | std::shared_ptr<FunctionProto> function_proto(new FunctionProto()); |
790 | for (const auto& node : func_nodes) { |
791 | auto new_node = function_proto->add_node(); |
792 | new_node->CopyFrom(node); |
793 | } |
794 | |
795 | // opset import may have been set |
796 | // we may need to update its version with the specified opset_version |
797 | UpdateFunctionProtoOpsetImportVersion(*function_proto, opset_version); |
798 | opset_version_to_function_body_.insert(std::make_pair(opset_version, function_proto)); |
799 | return *this; |
800 | } |
801 | |
802 | OpSchema& OpSchema::FunctionBody( |
803 | const std::vector<NodeProto>& func_nodes, |
804 | const std::vector<OperatorSetIdProto>& relied_opsets, |
805 | int opset_version) { |
806 | if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) { |
807 | opset_version = since_version_; |
808 | } |
809 | |
810 | std::shared_ptr<FunctionProto> function_proto(new FunctionProto()); |
811 | for (auto& relied_opset : relied_opsets) { |
812 | *(function_proto->mutable_opset_import()->Add()) = relied_opset; |
813 | } |
814 | |
815 | for (const auto& node : func_nodes) { |
816 | auto new_node = function_proto->add_node(); |
817 | new_node->CopyFrom(node); |
818 | } |
819 | // opset import may have been set |
820 | // we may need to update its version with the specified opset_version |
821 | UpdateFunctionProtoOpsetImportVersion(*function_proto, opset_version); |
822 | opset_version_to_function_body_.insert(std::make_pair(opset_version, function_proto)); |
823 | return *this; |
824 | } |
825 | |
826 | const FunctionProto* OpSchema::GetFunction(int requested_opset_version, bool validate) const { |
827 | if (requested_opset_version == OpSchema::kUninitializedSinceVersion) |
828 | requested_opset_version = since_version_; |
829 | std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it = |
830 | opset_version_to_function_body_.upper_bound(requested_opset_version); |
831 | if (!opset_version_to_function_body_.empty() && it != opset_version_to_function_body_.begin()) { |
832 | --it; |
833 | int function_since_version = it->first; |
834 | const FunctionProto* function = it->second.get(); |
835 | if (!validate || ValidateReferencedOpsInFuncton(function, requested_opset_version, function_since_version)) { |
836 | return function; |
837 | } |
838 | } |
839 | return nullptr; |
840 | } |
841 | |
842 | // when requesting a function at loading time, |
843 | // requested_opset_version does not have to be the same as function_since_version. |
844 | // When they are not the same, it is necessary to verify that ops used to define the function |
845 | // are not updated between function_since_version and requested_opset_version (include requested_opset_version). |
846 | // this call only validate ops in the default domain. |
847 | // TODO: validate ops in other domains. |
848 | bool OpSchema::ValidateReferencedOpsInFuncton( |
849 | const FunctionProto* function, |
850 | int requested_opset_version, |
851 | int function_since_version, |
852 | std::set<std::string>* updated_ops) const { |
853 | bool all_ops_are_invalid = true; |
854 | if (requested_opset_version == function_since_version) { |
855 | return all_ops_are_invalid; |
856 | } |
857 | for (auto& node : function->node()) { |
858 | if (node.domain() == "" || node.domain() == "ai.onnx" ) { |
859 | const OpSchema* op1 = |
860 | OpSchemaRegistry::Instance()->GetSchema(node.op_type(), requested_opset_version, node.domain()); |
861 | const OpSchema* op2 = |
862 | OpSchemaRegistry::Instance()->GetSchema(node.op_type(), function_since_version, node.domain()); |
863 | if (op1 != op2) { |
864 | if (updated_ops) { |
865 | updated_ops->insert(node.op_type()); |
866 | } |
867 | all_ops_are_invalid = false; |
868 | } |
869 | } |
870 | } |
871 | |
872 | return all_ops_are_invalid; |
873 | } |
874 | |
875 | OpSchema& OpSchema::FillUsing(const std::function<void(OpSchema&)>& populator) { |
876 | if (populator) { |
877 | populator(*this); |
878 | } |
879 | return *this; |
880 | } |
881 | |
882 | void OpSchema::BuildFunction(FunctionProto& function_body) const { |
883 | function_body.set_name(this->name_); |
884 | function_body.set_doc_string(this->doc_); |
885 | function_body.set_domain(this->domain_); |
886 | for (auto& i : inputs_) { |
887 | function_body.add_input(i.GetName()); |
888 | } |
889 | for (auto& o : outputs_) { |
890 | function_body.add_output(o.GetName()); |
891 | } |
892 | for (auto& a : attributes_) { |
893 | function_body.add_attribute(a.first); |
894 | } |
895 | |
896 | // In a typical onnx function where the function and all the |
897 | // ops in function body belong to the same domain we implicitly add |
898 | // {domain_, since_version_} to funciton opset imports if it is not already added. |
899 | // This is simply for convienince. If any of the function body ops do not belong to same |
900 | // domain as function itself, then the function author needs to explicitly add all the relevant |
901 | // opset imports. |
902 | if (function_body.opset_import().size() == 0) { |
903 | auto* schema_opset = function_body.mutable_opset_import()->Add(); |
904 | schema_opset->set_domain(domain_); |
905 | schema_opset->set_version(since_version_); |
906 | } |
907 | } |
908 | |
909 | void OpSchema::Finalize() { |
910 | #define ENFORCE(x) \ |
911 | do { \ |
912 | if (!(x)) \ |
913 | ONNX_THROW_EX(std::logic_error("ONNX Schema " + name_ + ": failed validating the check: " + #x)); \ |
914 | } while (0) |
915 | |
916 | // Calculate min/max number of inputs. |
917 | // <Min number of inputs> = <number of "single" inputs> + <number of |
918 | // "optional" but not trailing inputs>. <Max number of inputs> = <number of |
919 | // all inputs or std::numeric_limits<int>::max() (if the last input is |
920 | // variadic). |
921 | |
922 | // Flag indicates whether an optional input is trailing one (there's no single |
923 | // or variadic input behind). |
924 | for (size_t i = 0; i < inputs_.size(); ++i) { |
925 | switch (inputs_[i].GetOption()) { |
926 | case OpSchema::Single: |
927 | ++max_input_; |
928 | min_input_ = max_input_; |
929 | break; |
930 | case OpSchema::Optional: |
931 | ++max_input_; |
932 | break; |
933 | case OpSchema::Variadic: |
934 | // Only last input formal parameter could be variadic. |
935 | ENFORCE((inputs_.size() - 1) == i); |
936 | min_input_ = max_input_ + inputs_[i].GetMinArity(); |
937 | max_input_ = std::numeric_limits<int>::max(); |
938 | break; |
939 | } |
940 | } |
941 | |
942 | // Calculate min/max number of outputs. |
943 | for (size_t i = 0; i < outputs_.size(); ++i) { |
944 | switch (outputs_[i].GetOption()) { |
945 | case OpSchema::Single: |
946 | ++max_output_; |
947 | min_output_ = max_output_; |
948 | break; |
949 | case OpSchema::Optional: |
950 | ++max_output_; |
951 | break; |
952 | case OpSchema::Variadic: |
953 | // Only last output formal parameter could be variadic. |
954 | ENFORCE((outputs_.size() - 1) == i); |
955 | min_output_ = max_output_ + outputs_[i].GetMinArity(); |
956 | max_output_ = std::numeric_limits<int>::max(); |
957 | break; |
958 | } |
959 | } |
960 | |
961 | // all inputs and outputs have names |
962 | for (const auto& it : inputs_) { |
963 | ENFORCE(!(it.GetName().empty())); |
964 | } |
965 | for (const auto& it : outputs_) { |
966 | ENFORCE(!(it.GetName().empty())); |
967 | } |
968 | |
969 | ParseAndSetTypes(&inputs_); |
970 | ParseAndSetTypes(&outputs_); |
971 | |
972 | for (auto& func : opset_version_to_function_body_) { |
973 | BuildFunction(*func.second); |
974 | } |
975 | } |
976 | |
977 | std::ostream& operator<<(std::ostream& out, const OpSchema& schema) { |
978 | if (!schema.attributes_.empty()) { |
979 | out << "Attributes:" << std::endl; |
980 | for (const auto& pair : schema.attributes_) { |
981 | out << " " << pair.second.name << " : " << pair.second.description << std::endl; |
982 | } |
983 | } |
984 | if (schema.max_input_ > 0) { |
985 | out << "Inputs:" << std::endl; |
986 | if (!schema.inputs_.empty()) { |
987 | for (size_t i = 0; i < schema.inputs_.size(); ++i) { |
988 | const auto& p = schema.inputs_[i]; |
989 | const auto& name = p.GetName(); |
990 | const auto& description = p.GetDescription(); |
991 | const auto& type_str = p.GetTypeStr(); |
992 | out << " " << i << ", " << (!name.empty() ? name : "(unnamed)" ) << " : " |
993 | << (!description.empty() ? description : "(no doc)" ) << " : " |
994 | << (!type_str.empty() ? type_str : "(no type)" ) << std::endl; |
995 | } |
996 | } else { |
997 | out << " (no explicit description available)" << std::endl; |
998 | } |
999 | } |
1000 | if (schema.max_output_ > 0) { |
1001 | out << "Outputs:" << std::endl; |
1002 | if (!schema.outputs_.empty()) { |
1003 | for (size_t i = 0; i < schema.outputs_.size(); ++i) { |
1004 | const auto& p = schema.outputs_[i]; |
1005 | const auto& name = p.GetName(); |
1006 | const auto& description = p.GetDescription(); |
1007 | const auto& type_str = p.GetTypeStr(); |
1008 | out << " " << i << ", " << (!name.empty() ? name : "(unnamed)" ) << " : " |
1009 | << (!description.empty() ? description : "(no doc)" ) << " : " |
1010 | << (!type_str.empty() ? type_str : "(no type)" ) << std::endl; |
1011 | } |
1012 | } else { |
1013 | out << " (no explicit description available)" << std::endl; |
1014 | } |
1015 | } |
1016 | out << std::endl; |
1017 | if (schema.doc()) { |
1018 | out << schema.doc(); |
1019 | } else { |
1020 | out << "(no documentation yet)" << std::endl; |
1021 | } |
1022 | out << std::endl; |
1023 | if (schema.line_) { |
1024 | out << "Defined at " << schema.file_ << ":" << schema.line_ << std::endl; |
1025 | } |
1026 | return out; |
1027 | } |
1028 | |
1029 | OpSchemaRegistry::DomainToVersionRange& OpSchemaRegistry::DomainToVersionRange::Instance() { |
1030 | static DomainToVersionRange domain_to_version_range; |
1031 | return domain_to_version_range; |
1032 | }; |
1033 | |
1034 | // Private method used by OpSchemaRegisterOnce and OpSchemaRegistry::map() |
1035 | OpName_Domain_Version_Schema_Map& OpSchemaRegistry::GetMapWithoutEnsuringRegistration() { |
1036 | static OpName_Domain_Version_Schema_Map map; |
1037 | return map; |
1038 | } |
1039 | |
1040 | OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() { |
1041 | auto& map = GetMapWithoutEnsuringRegistration(); |
1042 | |
1043 | // The following class is used to register operators the |
1044 | // first time this method is called, in a thread-safe fashion. |
1045 | class SchemasRegisterer { |
1046 | public: |
1047 | SchemasRegisterer() { |
1048 | // In debug builds, the number of schema registered in this constructor |
1049 | // is compared against the number of calls to schema registration macros. |
1050 | #ifndef NDEBUG |
1051 | size_t dbg_initial_schema_count = GetRegisteredSchemaCount(); |
1052 | #endif |
1053 | |
1054 | RegisterOnnxOperatorSetSchema(); |
1055 | |
1056 | #ifdef ONNX_ML |
1057 | RegisterOnnxMLOperatorSetSchema(); |
1058 | #endif |
1059 | |
1060 | // Invoke register of training operators. |
1061 | RegisterOnnxTrainingOperatorSetSchema(); |
1062 | |
1063 | // Invoke register of experimental operators. |
1064 | RegisterOnnxPreviewOperatorSetSchema(); |
1065 | |
1066 | #ifndef NDEBUG |
1067 | size_t dbg_registered_schema_count = GetRegisteredSchemaCount() - dbg_initial_schema_count; |
1068 | // Check enabled only if schemas for all opset versions are loaded |
1069 | if (OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 0) { |
1070 | ONNX_ASSERTM( |
1071 | dbg_registered_schema_count == ONNX_DBG_GET_COUNT_IN_OPSETS(), |
1072 | "%u schema were exposed from operator sets and automatically placed into the static registry. " |
1073 | "%u were expected based on calls to registration macros. Operator set functions may need to be updated." , |
1074 | dbg_registered_schema_count, |
1075 | ONNX_DBG_GET_COUNT_IN_OPSETS()); |
1076 | } |
1077 | #endif |
1078 | } |
1079 | |
1080 | private: |
1081 | static size_t GetRegisteredSchemaCount() { |
1082 | size_t count = 0; |
1083 | for (auto& x : GetMapWithoutEnsuringRegistration()) { |
1084 | for (auto& y : x.second) { |
1085 | count += y.second.size(); |
1086 | } |
1087 | } |
1088 | return count; |
1089 | } |
1090 | }; |
1091 | |
1092 | #ifndef __ONNX_DISABLE_STATIC_REGISTRATION |
1093 | static SchemasRegisterer schemasRegisterer; |
1094 | #endif |
1095 | |
1096 | return map; |
1097 | } |
1098 | |
// Replaces every non-overlapping occurrence of `from` in `s` with `to`, scanning left to right,
// and returns the number of replacements made.
//
// Scanning resumes just after each inserted `to`, so `to` may itself contain `from` without
// causing repeated expansion. An empty `from` matches nothing and leaves `s` unchanged:
// std::string::find("") succeeds at every position, so without the guard the loop would
// never terminate.
size_t ReplaceAll(std::string& s, const char* from, const char* to) {
  const std::string::size_type lenFrom = std::strlen(from);
  if (lenFrom == 0) {
    return 0;
  }
  const std::string::size_type lenTo = std::strlen(to);
  size_t numReplaced = 0;
  for (std::string::size_type pos = s.find(from); pos != std::string::npos; pos = s.find(from, pos + lenTo)) {
    s.replace(pos, lenFrom, to);
    ++numReplaced;
  }
  return numReplaced;
}
1109 | |
1110 | } // namespace ONNX_NAMESPACE |
1111 | |