1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#include "onnx/defs/schema.h"
6#include <stdexcept>
7#include <unordered_set>
8#include "onnx/checker.h"
9#include "onnx/defs/operator_sets.h"
10#include "onnx/defs/operator_sets_preview.h"
11#include "onnx/defs/operator_sets_training.h"
12
13#ifdef ONNX_ML
14#include "onnx/defs/operator_sets_ml.h"
15#endif
16
17#include "onnx/common/assertions.h"
18#include "onnx/common/stl_backports.h"
19#include "onnx/defs/parser.h"
20
21namespace ONNX_NAMESPACE {
// Records which ONNX schemas have been loaded:
// -1 means ONNX schema hasn't been loaded yet
// 0 means all versions of ONNX schema have been loaded
// Any other positive integer means the ONNX schemas for that specified version have been loaded
int OpSchemaRegistry::loaded_schema_version = -1;

// Out-of-class definition of the static constexpr member. Before C++17 this is required whenever
// the constant is ODR-used (e.g. bound to a const reference); from C++17 on it is redundant but harmless.
constexpr int OpSchema::kUninitializedSinceVersion;
28
// Registers `schema` by constructing an OpSchemaRegistry::OpSchemaRegisterOnce helper (the actual
// registration presumably happens in that helper's constructor — see its definition).
// By default if opset_version_to_load=0, it registers all opset schema for all opset versions
// Otherwise, it only registers the latest schema according to opset_version_to_load
void RegisterSchema(OpSchema schema, int opset_version_to_load) {
  OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(schema, opset_version_to_load);
}
34
#ifndef NDEBUG
// Debug builds only (compiled when NDEBUG is not defined): returns the single tracker instance,
// which is created on first call as a function-local static.
DbgOperatorSetTracker& DbgOperatorSetTracker::Instance() {
  static DbgOperatorSetTracker instance;
  return instance;
}
#endif
41
// Accessors for OpSchema::FormalParameter. Each one returns the corresponding stored member.

// Name of the formal parameter.
const std::string& OpSchema::FormalParameter::GetName() const {
  return name_;
}

// Set of concrete data types accepted for this parameter (empty means "not restricted" as far as
// CheckInputOutputType is concerned).
const DataTypeSet& OpSchema::FormalParameter::GetTypes() const {
  return type_set_;
}

// Mutable access to the accepted type set; ParseAndSetTypes fills it from the type string.
DataTypeSet& OpSchema::FormalParameter::MutableTypes() {
  return type_set_;
}

// The type string: either the name of a type constraint (e.g. "T") or a concrete type string.
const std::string& OpSchema::FormalParameter::GetTypeStr() const {
  return type_str_;
}

// Human-readable description (empty when built with __ONNX_NO_DOC_STRINGS).
const std::string& OpSchema::FormalParameter::GetDescription() const {
  return description_;
}

// Single / Optional / Variadic option of the parameter.
OpSchema::FormalParameterOption OpSchema::FormalParameter::GetOption() const {
  return param_option_;
}

// Whether this parameter takes part in the "same type string => same concrete type" consistency check.
bool OpSchema::FormalParameter::GetIsHomogeneous() const {
  return is_homogeneous_;
}

// Minimum arity recorded for this parameter (presumably meaningful for variadic parameters — confirm).
int OpSchema::FormalParameter::GetMinArity() const {
  return min_arity_;
}

// Differentiability category recorded for this parameter.
OpSchema::DifferentiationCategory OpSchema::FormalParameter::GetDifferentiationCategory() const {
  return differentiation_category_;
}
77
// Returns a pointer to the single registry instance, created on first call as a function-local static.
OpSchemaRegistry* OpSchemaRegistry::Instance() {
  static OpSchemaRegistry instance;
  return &instance;
}
82
83void OpSchema::CheckInputOutputType(struct InferenceContext& ctx) const {
84 std::unordered_map<std::string, std::string> type_constraints;
85 // check all input types
86 for (size_t in_idx = 0; in_idx < ctx.getNumInputs(); ++in_idx) {
87 // If the last input is Variadic by definition, checker still needs to check the rest of actual input's type
88 const auto& param = (in_idx < inputs_.size()) ? inputs_[in_idx] : inputs_.back();
89 const auto& type_str = param.GetTypeStr();
90 const auto& param_type = ctx.getInputType(in_idx);
91 const auto& all_types = param.GetTypes();
92 if (nullptr == param_type || param_type->value_case() == TypeProto::VALUE_NOT_SET) {
93 continue;
94 } else if (!all_types.empty() && all_types.find(Utils::DataTypeUtils::ToType(*param_type)) == all_types.end()) {
95 fail_check(
96 param.GetName(),
97 " typestr: ",
98 type_str,
99 ", has unsupported type: ",
100 *Utils::DataTypeUtils::ToType(*param_type));
101 }
102 if (param.GetIsHomogeneous()) {
103 const auto& type_proto = Utils::DataTypeUtils::ToType(*param_type);
104 auto p = type_constraints.emplace(type_str, *type_proto);
105 if (!p.second) {
106 // failed to insert a new element due to a duplication, now check consistency
107 if (p.first->second != *type_proto) {
108 fail_check(param.GetName(), " has inconsistent type ", *Utils::DataTypeUtils::ToType(*param_type));
109 }
110 }
111 }
112 } // for inputs
113 // check all output types
114 for (size_t out_idx = 0; out_idx < ctx.getNumOutputs(); ++out_idx) {
115 // If the last output is Variadic by definition, checker still needs to check the rest of actual output's type
116 const auto& param = (out_idx < outputs_.size()) ? outputs_[out_idx] : outputs_.back();
117 const auto& type_str = param.GetTypeStr();
118 const auto& param_type = ctx.getOutputType(out_idx);
119 const auto& all_types = param.GetTypes();
120 bool output_type_found = true;
121 // infer type if necessary
122 if (param_type->value_case() == TypeProto::VALUE_NOT_SET) {
123 if (all_types.size() == 1) {
124 *param_type = Utils::DataTypeUtils::ToTypeProto(*all_types.begin());
125 } else if (type_constraints.find(type_str) != type_constraints.end()) {
126 auto data_type = Utils::DataTypeUtils::ToType(type_constraints[type_str]);
127 *param_type = Utils::DataTypeUtils::ToTypeProto(data_type);
128 } else {
129 output_type_found = false;
130 }
131 }
132 if (!output_type_found) {
133 continue;
134 }
135 if (!all_types.empty() && all_types.find(Utils::DataTypeUtils::ToType(*param_type)) == all_types.end()) {
136 fail_check(param.GetName(), " has unsupported type ", *Utils::DataTypeUtils::ToType(*param_type));
137 }
138 if (param.GetIsHomogeneous()) {
139 const auto& type_proto = Utils::DataTypeUtils::ToType(*param_type);
140 if (type_constraints.find(type_str) == type_constraints.end()) {
141 type_constraints[type_str] = *type_proto;
142 } else if (type_constraints[type_str] != *type_proto) {
143 fail_check(param.GetName(), " has inconsistent type ", *Utils::DataTypeUtils::ToType(*param_type));
144 }
145 } // else
146 } // for outputs
147}
148
149void OpSchema::Verify(const NodeProto& node) const {
150 if (deprecated_) {
151 fail_check("Operator '", name_, "' has been deprecated since version ", since_version_);
152 }
153
154 // Check the number of inputs.
155 if (node.input_size() < min_input_ || node.input_size() > max_input_) {
156 fail_check(
157 "Node (",
158 node.name(),
159 ") has input size ",
160 node.input_size(),
161 " not in range [min=",
162 min_input_,
163 ", max=",
164 max_input_,
165 "].");
166 }
167
168 if (!num_inputs_allowed_(node.input_size())) {
169 fail_check("Node (", node.name(), ") has input size ", node.input_size(), " not in allowed input sizes.");
170 }
171
172 // Check the number of outputs.
173 if (node.output_size() < min_output_ || node.output_size() > max_output_) {
174 fail_check(
175 "Node (",
176 node.name(),
177 ") has output size ",
178 node.output_size(),
179 " not in range [min=",
180 min_output_,
181 ", max=",
182 max_output_,
183 "].");
184 }
185
186 if (!num_outputs_allowed_(node.output_size())) {
187 fail_check("Node (", node.name(), "has output size ", node.output_size(), " not in allowed output sizes.");
188 }
189
190 // Check the values of inputs / outputs
191 for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) {
192 if (in_idx >= static_cast<int>(inputs_.size())) {
193 if (!inputs_.empty() && Variadic == inputs_.back().GetOption()) {
194 // The last input formal parameter should be variadic.
195 break;
196 } else {
197 fail_check(
198 "Node (",
199 node.name(),
200 ") has more inputs (",
201 node.input_size(),
202 ") than declared (",
203 inputs_.size(),
204 ") in op definition.");
205 }
206 }
207 if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) {
208 fail_check("Node (", node.name(), ")'s input ", in_idx, " is marked single but has an empty string in the graph");
209 }
210 }
211
212 for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) {
213 if (out_idx >= static_cast<int>(outputs_.size())) {
214 if (!outputs_.empty() && Variadic == outputs_.back().GetOption()) {
215 // The last output formal parameter should be variadic.
216 break;
217 } else {
218 fail_check(
219 "Node (",
220 node.name(),
221 ") has more outputs (",
222 node.output_size(),
223 ") than declared (",
224 outputs_.size(),
225 ") in op definition.");
226 }
227 }
228
229 if (node.output(out_idx).empty() && (Single == outputs_[out_idx].GetOption())) {
230 fail_check(
231 "Node (", node.name(), ")'s output ", out_idx, " is marked single but has an empty string in the graph");
232 }
233 }
234
235 // An internal symbol is defined as starting with two underscores. Attributes
236 // with names meeting this condition are considered implementation details
237 // and should be ignored for the purpose of schema checking.
238 auto isInternalSymbol = [](const std::string& sym) -> bool {
239 return sym.length() >= 2 && sym[0] == '_' && sym[1] == '_';
240 };
241
242 // Check attributes
243 std::unordered_set<std::string> seen_attr_names{};
244 for (const auto& attr_proto : node.attribute()) {
245 const auto& name = attr_proto.name();
246
247 if (!seen_attr_names.insert(name).second) {
248 fail_check("Attribute '", name, "' appeared multiple times.");
249 };
250
251 const auto& search = attributes_.find(name);
252 AttributeProto::AttributeType expected_type;
253 if (search != attributes_.end()) {
254 expected_type = search->second.type;
255 } else if (allows_unchecked_attributes_ || isInternalSymbol(name)) {
256 continue;
257 } else {
258 fail_check("Unrecognized attribute: ", name, " for operator ", node.op_type());
259 }
260
261 // Type would be UNDEFINED if not set
262 if (attr_proto.type() != expected_type) {
263 fail_check("Mismatched attribute type in '", node.name() + " : " + name, "'");
264 }
265
266 // ref_attr_name is only valid when non-empty
267 // we simply read default value if not present
268 if (!attr_proto.ref_attr_name().empty()) {
269 continue;
270 }
271
272 switch (expected_type) {
273 // if attr_proto().type() != UNDEFINED
274 // we consider primitive types to be set even
275 // if proto3 did not output default values into the stream
276 // in which case we will read the default
277 case AttributeProto::FLOAT:
278 case AttributeProto::INT:
279 case AttributeProto::STRING:
280 break;
281 case AttributeProto::TENSOR:
282 if (!attr_proto.has_t()) {
283 fail_check("Attribute '", name, "' is expected to have field 't'");
284 }
285 break;
286 case AttributeProto::SPARSE_TENSOR:
287 if (!attr_proto.has_sparse_tensor()) {
288 fail_check("Attribute '", name, "' is expected to have field 'sparse_tensor'");
289 }
290 break;
291 case AttributeProto::GRAPH:
292 if (!attr_proto.has_g()) {
293 fail_check("Attribute '", name, "' is expected to have field 'g'");
294 }
295 break;
296 case AttributeProto::TYPE_PROTO:
297 if (!attr_proto.has_tp()) {
298 fail_check("Attribute '", name, "' is expected to have field 'type_proto'");
299 }
300 break;
301 case AttributeProto::FLOATS:
302 if (!attr_proto.floats_size()) {
303 fail_check("Attribute '", name, "' is expected to have field 'floats'");
304 }
305 break;
306 case AttributeProto::INTS:
307 if (!attr_proto.ints_size()) {
308 fail_check("Attribute '", name, "' is expected to have field 'ints'");
309 }
310 break;
311 case AttributeProto::STRINGS:
312 if (!attr_proto.strings_size()) {
313 fail_check("Attribute '", name, "' is expected to have field 'strings'");
314 }
315 break;
316 case AttributeProto::TENSORS:
317 if (!attr_proto.tensors_size()) {
318 fail_check("Attribute '", name, "' is expected to have field 'tensors'");
319 }
320 break;
321 case AttributeProto::SPARSE_TENSORS:
322 // Not adding check ... we should likely delete the check in all other
323 // cases, which will not allow us to have an empty list as a valid value
324 // for an attribute and this seems undesirable.
325 break;
326 case AttributeProto::GRAPHS:
327 if (!attr_proto.graphs_size()) {
328 fail_check("Attribute '", name, "' is expected to have field 'graphs'");
329 }
330 break;
331 case AttributeProto::TYPE_PROTOS:
332 if (!attr_proto.type_protos_size()) {
333 fail_check("Attribute '", name, "' is expected to have field 'type_protos'");
334 }
335 break;
336 default:
337 fail_check("Attribute '", name, " has unknown expected type");
338 }
339 }
340 for (const auto& pair : attributes_) {
341 const auto& attr = pair.second;
342 if (!attr.required) {
343 continue;
344 }
345 if (!seen_attr_names.count(attr.name)) {
346 fail_check("Required attribute '", attr.name, "' is missing.");
347 }
348 }
349
350 // Phew. All verifications passed.
351}
352
353OpSchema& OpSchema::SinceVersion(OperatorSetVersion v) {
354 since_version_ = v;
355
356 // SinceVersion is called after FunctionBody and SetContextDependentFunctionBodyBuilder are called
357 // when defining a op.
358 // FunctionBody() and SetContextDependentFunctionBodyBuilder() use -1 as the default opset_version
359 // default opset_version is for a FunctionProto of the same opset_version as the op's since_version_.
360 // It is indexed with -1 so we need to reindex it with since_version_.
361 //
362 // FunctionProtos of non-default opset_versions are for models whose opset version is higher than the op's
363 // opset version such that ops used in the default function_proto are no longer valid. For example:
364 // A model of opset version 18 contains a LayerNormalization op.
365 // LayerNormalization is function op whese function body uses ReduceMean op.
366 // LayerNormalization's since_version is 17 thus it is good for the model of opset 18.
367 // however, if a runtime needs to inline LayerNormalization, the inlined model has a ReduceMean op.
368 // ReduceMean in opset 18 is different from opset 17.
369 // This requires us to define more than one function body
370 std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it =
371 opset_version_to_function_builder_.find(OpSchema::kUninitializedSinceVersion);
372
373 if (it != opset_version_to_function_builder_.cend()) {
374 opset_version_to_function_builder_[since_version_] = it->second;
375 opset_version_to_function_builder_.erase(it);
376 }
377
378 std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it_function_body =
379 opset_version_to_function_body_.find(OpSchema::kUninitializedSinceVersion);
380 if (it_function_body != opset_version_to_function_body_.cend()) {
381 opset_version_to_function_body_[since_version_] = it_function_body->second;
382 UpdateFunctionProtoOpsetImportVersion(*opset_version_to_function_body_[since_version_], since_version_);
383 opset_version_to_function_body_.erase(it_function_body);
384 }
385
386 return *this;
387}
388
// Marks this operator schema as deprecated; Verify() then rejects any node checked against it.
OpSchema& OpSchema::Deprecate() {
  deprecated_ = true;
  return *this;
}
393
394OpSchema& OpSchema::NumInputs(std::set<int> allowed_input_nums) {
395 num_inputs_allowed_ = [MOVE_CAPTURE_IF_CPP14(allowed_input_nums)](int n) -> bool {
396 return allowed_input_nums.count(n);
397 };
398 return *this;
399}
400
401OpSchema& OpSchema::NumOutputs(std::set<int> allowed_output_nums) {
402 num_outputs_allowed_ = [MOVE_CAPTURE_IF_CPP14(allowed_output_nums)](int n) -> bool {
403 return allowed_output_nums.count(n) > 0;
404 };
405 return *this;
406}
407
// Stores the type-and-shape inference function for this operator (taken by value and moved in).
OpSchema& OpSchema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction) {
  tensor_inference_function_ = std::move(inferenceFunction);
  return *this;
}

// Stores the partial data-propagation function for this operator (taken by value and moved in).
OpSchema& OpSchema::PartialDataPropagationFunction(DataPropagationFunction dataPropagationFunction) {
  data_propagation_function_ = std::move(dataPropagationFunction);
  return *this;
}

// Records the support level of this operator.
OpSchema& OpSchema::SetSupportLevel(SupportType support) {
  support_ = support;
  return *this;
}
422
// Functions to specify the name of the operator schema.
// The std::string overload takes its argument by value and moves it into name_.
OpSchema& OpSchema::SetName(std::string name) {
  name_ = std::move(name);
  return *this;
}

// C-string convenience overload. `name` must point to a NUL-terminated string
// (constructing std::string from nullptr is undefined behavior).
OpSchema& OpSchema::SetName(const char* name) {
  return SetName(std::string(name));
}

// Functions to specify the source-code location (file and line) of the operator schema.
OpSchema& OpSchema::SetLocation(std::string file, int line) {
  file_ = std::move(file);
  line_ = line;
  return *this;
}

// C-string convenience overload; `file` must be a valid NUL-terminated string.
OpSchema& OpSchema::SetLocation(const char* file, int line) {
  return SetLocation(std::string(file), line);
}

// Sets the operator-set domain this schema belongs to.
OpSchema& OpSchema::SetDomain(std::string domain) {
  domain_ = std::move(domain);
  return *this;
}

// C-string convenience overload; `domain` must be a valid NUL-terminated string.
OpSchema& OpSchema::SetDomain(const char* domain) {
  return SetDomain(std::string(domain));
}
452
453OpSchema& OpSchema::Attr(Attribute attr) {
454 auto name = attr.name; // copy name so we can move attr in the next line
455 attributes_.insert(std::make_pair(std::move(name), std::move(attr)));
456 return *this;
457}
458
459OpSchema& OpSchema::Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required) {
460 Attr(Attribute{std::move(name), std::move(description), type, required});
461 return *this;
462}
463
464OpSchema& OpSchema::Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required) {
465 return Attr(std::string(name), std::string(description), type, required);
466}
467
// Defines two OpSchema::Attr overloads (std::string and const char* name/description) that declare an
// attribute of scalar C++ type `type` with a default value. The default is stored in an AttributeProto
// via a.set_##field(). `attrtype` is the AttributeProto::AttributeType the caller must pass; any other
// `attr_type` fails with fail_schema("Attribute specification type mismatch.").
// (Comments cannot go inside the macro body: a `//` before a line-continuation would swallow it.)
#define ATTR_SETTER_WITH_SINGLE_VALUE(type, field, attrtype) \
  OpSchema& OpSchema::Attr( \
      std::string name, std::string description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    if (attrtype != attr_type) { \
      fail_schema("Attribute specification type mismatch."); \
    } \
    AttributeProto a; \
    a.set_name(name); \
    a.set_##field(default_value); \
    a.set_type(attr_type); \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this; \
  } \
  OpSchema& OpSchema::Attr( \
      const char* name, const char* description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    return Attr(std::string(name), std::string(description), attr_type, default_value); \
  }
485
// Defines an OpSchema::Attr overload (std::string name/description only) that declares a list attribute
// whose default is a std::vector<type> of scalar values. Each element is appended with a.add_##field(v).
// `attrtype` is the required AttributeProto::AttributeType; a mismatching `attr_type` fails via fail_schema.
#define ATTR_SETTER_WITH_LIST_VALUE(type, field, attrtype) \
  OpSchema& OpSchema::Attr( \
      std::string name, \
      std::string description, \
      AttributeProto::AttributeType attr_type, \
      const std::vector<type>& default_value) { \
    if (attrtype != attr_type) { \
      fail_schema("Attribute specification type mismatch."); \
    } \
    AttributeProto a; \
    a.set_name(name); \
    a.set_type(attr_type); \
    for (const auto& v : default_value) { \
      a.add_##field(v); \
    } \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this; \
  }
504
// Defines an OpSchema::Attr overload (std::string name/description only) that declares an attribute
// whose default is a single message value (TensorProto, GraphProto, TypeProto, ...). The default is
// copied into *a.mutable_##field(). `attrtype` is the required AttributeProto::AttributeType; a
// mismatching `attr_type` fails via fail_schema. The finished AttributeProto is moved (not copied)
// into the Attribute, like the sibling macros, which avoids a second deep copy of tensor/graph defaults.
#define ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(type, field, attrtype) \
  OpSchema& OpSchema::Attr( \
      std::string name, std::string description, AttributeProto::AttributeType attr_type, const type& default_value) { \
    if (attrtype != attr_type) { \
      fail_schema("Attribute specification type mismatch."); \
    } \
    AttributeProto a; \
    a.set_name(name); \
    *(a.mutable_##field()) = default_value; \
    a.set_type(attr_type); \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this; \
  }
518
// Defines an OpSchema::Attr overload (std::string name/description only) that declares a list attribute
// whose default is a std::vector<type> of strings or messages. Each element is copied into a freshly
// added slot via *a.add_##field() = v. `attrtype` is the required AttributeProto::AttributeType; a
// mismatching `attr_type` fails via fail_schema.
#define ATTR_SETTER_WITH_LIST_COMPLEXVALUE(type, field, attrtype) \
  OpSchema& OpSchema::Attr( \
      std::string name, \
      std::string description, \
      AttributeProto::AttributeType attr_type, \
      const std::vector<type>& default_value) { \
    if (attrtype != attr_type) { \
      fail_schema("Attribute specification type mismatch."); \
    } \
    AttributeProto a; \
    a.set_name(name); \
    a.set_type(attr_type); \
    for (const auto& v : default_value) { \
      *(a.add_##field()) = v; \
    } \
    Attr(Attribute(std::move(name), std::move(description), std::move(a))); \
    return *this; \
  }
537
// Instantiate the Attr(name, description, attr_type, default_value) overloads for every supported
// default-value type. Only the SINGLE_VALUE variants also generate const char* name overloads.
ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i, AttributeProto::INT)
ATTR_SETTER_WITH_SINGLE_VALUE(float, f, AttributeProto::FLOAT)
ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s, AttributeProto::STRING)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(TensorProto, t, AttributeProto::TENSOR)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(GraphProto, g, AttributeProto::GRAPH)
ATTR_SETTER_WITH_SINGLE_COMPLEXVALUE(TypeProto, tp, AttributeProto::TYPE_PROTO)
ATTR_SETTER_WITH_LIST_VALUE(int64_t, ints, AttributeProto::INTS)
ATTR_SETTER_WITH_LIST_VALUE(float, floats, AttributeProto::FLOATS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(std::string, strings, AttributeProto::STRINGS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(TensorProto, tensors, AttributeProto::TENSORS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(GraphProto, graphs, AttributeProto::GRAPHS)
ATTR_SETTER_WITH_LIST_COMPLEXVALUE(TypeProto, type_protos, AttributeProto::TYPE_PROTOS)
550
// Relaxes Verify(): attributes not declared on this schema are skipped instead of being rejected
// as "Unrecognized attribute".
OpSchema& OpSchema::AllowUncheckedAttributes() {
  allows_unchecked_attributes_ = true;
  return *this;
}
555
556OpSchema& OpSchema::Input(
557 int n,
558 std::string name,
559 const std::string& description,
560 std::string type_str,
561 OpSchema::FormalParameterOption param_option,
562 bool is_homogeneous,
563 int min_arity,
564 DifferentiationCategory differentiation_category) {
565 if (int(inputs_.size()) <= n) {
566 inputs_.resize(n + 1);
567 }
568 inputs_[n] = FormalParameter(
569 std::move(name),
570#ifndef __ONNX_NO_DOC_STRINGS
571 description,
572#else
573 std::string(),
574#endif
575 std::move(type_str),
576 param_option,
577 is_homogeneous,
578 min_arity,
579 differentiation_category);
580 return *this;
581}
582
// C-string convenience overload of Input(int, std::string, ...). When built with
// __ONNX_NO_DOC_STRINGS, `description` is not read at all and an empty string is forwarded instead.
OpSchema& OpSchema::Input(
    int n,
    const char* name,
    const char* description,
    const char* type_str,
    FormalParameterOption param_option,
    bool is_homogeneous,
    int min_arity,
    DifferentiationCategory differentiation_category) {
  return Input(
      n,
      std::string(name),
#ifndef __ONNX_NO_DOC_STRINGS
      std::string(description),
#else
      std::string(),
#endif
      std::string(type_str),
      param_option,
      is_homogeneous,
      min_arity,
      differentiation_category);
}
606
607OpSchema& OpSchema::Output(
608 int n,
609 std::string name,
610 const std::string& description,
611 std::string type_str,
612 OpSchema::FormalParameterOption param_option,
613 bool is_homogeneous,
614 int min_arity,
615 DifferentiationCategory differentiation_category) {
616 if (int(outputs_.size()) <= n) {
617 outputs_.resize(n + 1);
618 }
619 outputs_[n] = FormalParameter(
620 std::move(name),
621#ifndef __ONNX_NO_DOC_STRINGS
622 description,
623#else
624 std::string(),
625#endif
626 std::move(type_str),
627 param_option,
628 is_homogeneous,
629 min_arity,
630 differentiation_category);
631 return *this;
632}
633
// C-string convenience overload of Output(int, std::string, ...). When built with
// __ONNX_NO_DOC_STRINGS, `description` is not read at all and an empty string is forwarded instead.
OpSchema& OpSchema::Output(
    int n,
    const char* name,
    const char* description,
    const char* type_str,
    FormalParameterOption param_option,
    bool is_homogeneous,
    int min_arity,
    DifferentiationCategory differentiation_category) {
  return Output(
      n,
      std::string(name),
#ifndef __ONNX_NO_DOC_STRINGS
      std::string(description),
#else
      std::string(),
#endif
      std::string(type_str),
      param_option,
      is_homogeneous,
      min_arity,
      differentiation_category);
}
657
658OpSchema&
659OpSchema::TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description) {
660 if (type_constraints_.end() != type_constraints_.find(type_str)) {
661 fail_schema("Duplicate type constraint name");
662 }
663
664 DataTypeSet d;
665 for (const auto& t : constraints) {
666 d.insert(Utils::DataTypeUtils::ToType(t));
667 }
668 type_constraints_.insert(std::make_pair(type_str, std::make_pair(d, description)));
669 type_constraint_params_.push_back(
670 TypeConstraintParam(std::move(type_str), std::move(constraints), std::move(description)));
671 return *this;
672}
673
674OpSchema& OpSchema::TypeConstraint(
675 const char* type_str,
676 std::initializer_list<const char*> constraints,
677 const char* description) {
678 std::vector<std::string> constraints_vector;
679 constraints_vector.reserve(constraints.size());
680 for (auto iter = constraints.begin(); iter != constraints.end(); ++iter) {
681 constraints_vector.push_back(*iter);
682 }
683
684 return TypeConstraint(std::string(type_str), constraints_vector, std::string(description));
685}
686
687void OpSchema::ParseAndSetTypes(
688 /*out*/ std::vector<OpSchema::FormalParameter>* formal_parameters) {
689 for (auto& formal_parameter : *formal_parameters) {
690 auto& type = formal_parameter.GetTypeStr();
691 DataTypeSet allowed_types;
692 auto it = type_constraints_.find(type);
693 if (it != type_constraints_.end()) {
694 allowed_types = it->second.first;
695 } else {
696 allowed_types.emplace(Utils::DataTypeUtils::ToType(type));
697 }
698
699 formal_parameter.MutableTypes() = allowed_types;
700 }
701}
702
703OpSchema& OpSchema::SetContextDependentFunctionBodyBuilder(
704 ContextDependentFunctionBodyBuilder functionBuilder,
705 int opset_version) {
706 if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) {
707 opset_version_to_function_builder_[since_version_] = std::move(functionBuilder);
708 } else {
709 opset_version_to_function_builder_[opset_version] = std::move(functionBuilder);
710 }
711 return *this;
712}
713
714bool OpSchema::BuildContextDependentFunction(
715 const FunctionBodyBuildContext& ctx,
716 FunctionProto& function_proto,
717 int requested_opset_version) const {
718 if (requested_opset_version == OpSchema::kUninitializedSinceVersion)
719 requested_opset_version = since_version_;
720
721 std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it =
722 opset_version_to_function_builder_.upper_bound(requested_opset_version);
723 if (opset_version_to_function_builder_.empty() || it == opset_version_to_function_builder_.begin()) {
724 ONNX_THROW_EX(std::out_of_range(
725 std::string("Cannot find a function builder that satisfies the requested opset version: op_type = ") +
726 this->name_ + ", opset_version = " + std::to_string(requested_opset_version) + "."));
727 } else {
728 --it;
729 const ContextDependentFunctionBodyBuilder& body_builder = it->second;
730 if (!body_builder(ctx, *this, function_proto)) {
731 return false;
732 }
733 //// default opset import may have been added to function_proto by OpSchema::BuildFunction
734 //// we need to update its version with the specified opset_version
735 UpdateFunctionProtoOpsetImportVersion(function_proto, requested_opset_version);
736 ValidateReferencedOpsInFuncton(&function_proto, requested_opset_version, it->first);
737 return true;
738 }
739}
740
741// A function of a schema (either stored in opset_version_to_function_body_ or built with one of function builder
742// in opset_version_to_function_builder_) has predefined opset_imports. Before returning the function, we shall
743// update the predefined opset_imports so that it is consistent with the requested version.
744// Note that this call only update opset_import of the default domain.
745// TODO: extend this call to work for no-default domains.
746void OpSchema::UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int requested_opset_version) const {
747 bool opset_import_exist = false;
748 for (int i = 0; i < function_proto.opset_import_size(); i++) {
749 auto* schema_opset = function_proto.mutable_opset_import(i);
750 if (schema_opset->domain() == domain_) {
751 if (schema_opset->version() != requested_opset_version) {
752 schema_opset->set_version(requested_opset_version);
753 }
754 opset_import_exist = true;
755 }
756 }
757
758 if (!opset_import_exist) {
759 auto* schema_opset = function_proto.mutable_opset_import()->Add();
760 schema_opset->set_domain(domain_);
761 schema_opset->set_version(requested_opset_version);
762 }
763}
764
765OpSchema& OpSchema::FunctionBody(const char* func_body, int opset_version) {
766 if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) {
767 opset_version = since_version_;
768 }
769 std::shared_ptr<FunctionProto> function_proto(new FunctionProto());
770 OnnxParser parser(func_body);
771 auto status = parser.Parse(*function_proto->mutable_node());
772 if (!status.IsOK())
773 ONNX_THROW_EX(std::logic_error("Error parsing function body:" + status.ErrorMessage()));
774 if (!parser.EndOfInput())
775 ONNX_THROW_EX(std::logic_error("Extra unparsed input unexpected."));
776
777 // opset import may have been set
778 // we may need to update its version with the specified opset_version
779 UpdateFunctionProtoOpsetImportVersion(*function_proto, opset_version);
780
781 opset_version_to_function_body_.insert(std::make_pair(opset_version, function_proto));
782 return *this;
783}
784
785OpSchema& OpSchema::FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version) {
786 if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) {
787 opset_version = since_version_;
788 }
789 std::shared_ptr<FunctionProto> function_proto(new FunctionProto());
790 for (const auto& node : func_nodes) {
791 auto new_node = function_proto->add_node();
792 new_node->CopyFrom(node);
793 }
794
795 // opset import may have been set
796 // we may need to update its version with the specified opset_version
797 UpdateFunctionProtoOpsetImportVersion(*function_proto, opset_version);
798 opset_version_to_function_body_.insert(std::make_pair(opset_version, function_proto));
799 return *this;
800}
801
802OpSchema& OpSchema::FunctionBody(
803 const std::vector<NodeProto>& func_nodes,
804 const std::vector<OperatorSetIdProto>& relied_opsets,
805 int opset_version) {
806 if (opset_version == OpSchema::kUninitializedSinceVersion && since_version_ != OpSchema::kUninitializedSinceVersion) {
807 opset_version = since_version_;
808 }
809
810 std::shared_ptr<FunctionProto> function_proto(new FunctionProto());
811 for (auto& relied_opset : relied_opsets) {
812 *(function_proto->mutable_opset_import()->Add()) = relied_opset;
813 }
814
815 for (const auto& node : func_nodes) {
816 auto new_node = function_proto->add_node();
817 new_node->CopyFrom(node);
818 }
819 // opset import may have been set
820 // we may need to update its version with the specified opset_version
821 UpdateFunctionProtoOpsetImportVersion(*function_proto, opset_version);
822 opset_version_to_function_body_.insert(std::make_pair(opset_version, function_proto));
823 return *this;
824}
825
826const FunctionProto* OpSchema::GetFunction(int requested_opset_version, bool validate) const {
827 if (requested_opset_version == OpSchema::kUninitializedSinceVersion)
828 requested_opset_version = since_version_;
829 std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it =
830 opset_version_to_function_body_.upper_bound(requested_opset_version);
831 if (!opset_version_to_function_body_.empty() && it != opset_version_to_function_body_.begin()) {
832 --it;
833 int function_since_version = it->first;
834 const FunctionProto* function = it->second.get();
835 if (!validate || ValidateReferencedOpsInFuncton(function, requested_opset_version, function_since_version)) {
836 return function;
837 }
838 }
839 return nullptr;
840}
841
842// when requesting a function at loading time,
843// requested_opset_version does not have to be the same as function_since_version.
844// When they are not the same, it is necessary to verify that ops used to define the function
845// are not updated between function_since_version and requested_opset_version (include requested_opset_version).
846// this call only validate ops in the default domain.
847// TODO: validate ops in other domains.
848bool OpSchema::ValidateReferencedOpsInFuncton(
849 const FunctionProto* function,
850 int requested_opset_version,
851 int function_since_version,
852 std::set<std::string>* updated_ops) const {
853 bool all_ops_are_invalid = true;
854 if (requested_opset_version == function_since_version) {
855 return all_ops_are_invalid;
856 }
857 for (auto& node : function->node()) {
858 if (node.domain() == "" || node.domain() == "ai.onnx") {
859 const OpSchema* op1 =
860 OpSchemaRegistry::Instance()->GetSchema(node.op_type(), requested_opset_version, node.domain());
861 const OpSchema* op2 =
862 OpSchemaRegistry::Instance()->GetSchema(node.op_type(), function_since_version, node.domain());
863 if (op1 != op2) {
864 if (updated_ops) {
865 updated_ops->insert(node.op_type());
866 }
867 all_ops_are_invalid = false;
868 }
869 }
870 }
871
872 return all_ops_are_invalid;
873}
874
// Applies `populator` to this schema (if the std::function is non-empty) so shared schema-building
// code can be reused across operators. Returns *this for chaining.
OpSchema& OpSchema::FillUsing(const std::function<void(OpSchema&)>& populator) {
  if (populator) {
    populator(*this);
  }
  return *this;
}
881
882void OpSchema::BuildFunction(FunctionProto& function_body) const {
883 function_body.set_name(this->name_);
884 function_body.set_doc_string(this->doc_);
885 function_body.set_domain(this->domain_);
886 for (auto& i : inputs_) {
887 function_body.add_input(i.GetName());
888 }
889 for (auto& o : outputs_) {
890 function_body.add_output(o.GetName());
891 }
892 for (auto& a : attributes_) {
893 function_body.add_attribute(a.first);
894 }
895
896 // In a typical onnx function where the function and all the
897 // ops in function body belong to the same domain we implicitly add
898 // {domain_, since_version_} to funciton opset imports if it is not already added.
899 // This is simply for convienince. If any of the function body ops do not belong to same
900 // domain as function itself, then the function author needs to explicitly add all the relevant
901 // opset imports.
902 if (function_body.opset_import().size() == 0) {
903 auto* schema_opset = function_body.mutable_opset_import()->Add();
904 schema_opset->set_domain(domain_);
905 schema_opset->set_version(since_version_);
906 }
907}
908
909void OpSchema::Finalize() {
910#define ENFORCE(x) \
911 do { \
912 if (!(x)) \
913 ONNX_THROW_EX(std::logic_error("ONNX Schema " + name_ + ": failed validating the check: " + #x)); \
914 } while (0)
915
916 // Calculate min/max number of inputs.
917 // <Min number of inputs> = <number of "single" inputs> + <number of
918 // "optional" but not trailing inputs>. <Max number of inputs> = <number of
919 // all inputs or std::numeric_limits<int>::max() (if the last input is
920 // variadic).
921
922 // Flag indicates whether an optional input is trailing one (there's no single
923 // or variadic input behind).
924 for (size_t i = 0; i < inputs_.size(); ++i) {
925 switch (inputs_[i].GetOption()) {
926 case OpSchema::Single:
927 ++max_input_;
928 min_input_ = max_input_;
929 break;
930 case OpSchema::Optional:
931 ++max_input_;
932 break;
933 case OpSchema::Variadic:
934 // Only last input formal parameter could be variadic.
935 ENFORCE((inputs_.size() - 1) == i);
936 min_input_ = max_input_ + inputs_[i].GetMinArity();
937 max_input_ = std::numeric_limits<int>::max();
938 break;
939 }
940 }
941
942 // Calculate min/max number of outputs.
943 for (size_t i = 0; i < outputs_.size(); ++i) {
944 switch (outputs_[i].GetOption()) {
945 case OpSchema::Single:
946 ++max_output_;
947 min_output_ = max_output_;
948 break;
949 case OpSchema::Optional:
950 ++max_output_;
951 break;
952 case OpSchema::Variadic:
953 // Only last output formal parameter could be variadic.
954 ENFORCE((outputs_.size() - 1) == i);
955 min_output_ = max_output_ + outputs_[i].GetMinArity();
956 max_output_ = std::numeric_limits<int>::max();
957 break;
958 }
959 }
960
961 // all inputs and outputs have names
962 for (const auto& it : inputs_) {
963 ENFORCE(!(it.GetName().empty()));
964 }
965 for (const auto& it : outputs_) {
966 ENFORCE(!(it.GetName().empty()));
967 }
968
969 ParseAndSetTypes(&inputs_);
970 ParseAndSetTypes(&outputs_);
971
972 for (auto& func : opset_version_to_function_body_) {
973 BuildFunction(*func.second);
974 }
975}
976
977std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
978 if (!schema.attributes_.empty()) {
979 out << "Attributes:" << std::endl;
980 for (const auto& pair : schema.attributes_) {
981 out << " " << pair.second.name << " : " << pair.second.description << std::endl;
982 }
983 }
984 if (schema.max_input_ > 0) {
985 out << "Inputs:" << std::endl;
986 if (!schema.inputs_.empty()) {
987 for (size_t i = 0; i < schema.inputs_.size(); ++i) {
988 const auto& p = schema.inputs_[i];
989 const auto& name = p.GetName();
990 const auto& description = p.GetDescription();
991 const auto& type_str = p.GetTypeStr();
992 out << " " << i << ", " << (!name.empty() ? name : "(unnamed)") << " : "
993 << (!description.empty() ? description : "(no doc)") << " : "
994 << (!type_str.empty() ? type_str : "(no type)") << std::endl;
995 }
996 } else {
997 out << " (no explicit description available)" << std::endl;
998 }
999 }
1000 if (schema.max_output_ > 0) {
1001 out << "Outputs:" << std::endl;
1002 if (!schema.outputs_.empty()) {
1003 for (size_t i = 0; i < schema.outputs_.size(); ++i) {
1004 const auto& p = schema.outputs_[i];
1005 const auto& name = p.GetName();
1006 const auto& description = p.GetDescription();
1007 const auto& type_str = p.GetTypeStr();
1008 out << " " << i << ", " << (!name.empty() ? name : "(unnamed)") << " : "
1009 << (!description.empty() ? description : "(no doc)") << " : "
1010 << (!type_str.empty() ? type_str : "(no type)") << std::endl;
1011 }
1012 } else {
1013 out << " (no explicit description available)" << std::endl;
1014 }
1015 }
1016 out << std::endl;
1017 if (schema.doc()) {
1018 out << schema.doc();
1019 } else {
1020 out << "(no documentation yet)" << std::endl;
1021 }
1022 out << std::endl;
1023 if (schema.line_) {
1024 out << "Defined at " << schema.file_ << ":" << schema.line_ << std::endl;
1025 }
1026 return out;
1027}
1028
1029OpSchemaRegistry::DomainToVersionRange& OpSchemaRegistry::DomainToVersionRange::Instance() {
1030 static DomainToVersionRange domain_to_version_range;
1031 return domain_to_version_range;
1032};
1033
1034// Private method used by OpSchemaRegisterOnce and OpSchemaRegistry::map()
1035OpName_Domain_Version_Schema_Map& OpSchemaRegistry::GetMapWithoutEnsuringRegistration() {
1036 static OpName_Domain_Version_Schema_Map map;
1037 return map;
1038}
1039
1040OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() {
1041 auto& map = GetMapWithoutEnsuringRegistration();
1042
1043 // The following class is used to register operators the
1044 // first time this method is called, in a thread-safe fashion.
1045 class SchemasRegisterer {
1046 public:
1047 SchemasRegisterer() {
1048 // In debug builds, the number of schema registered in this constructor
1049 // is compared against the number of calls to schema registration macros.
1050#ifndef NDEBUG
1051 size_t dbg_initial_schema_count = GetRegisteredSchemaCount();
1052#endif
1053
1054 RegisterOnnxOperatorSetSchema();
1055
1056#ifdef ONNX_ML
1057 RegisterOnnxMLOperatorSetSchema();
1058#endif
1059
1060 // Invoke register of training operators.
1061 RegisterOnnxTrainingOperatorSetSchema();
1062
1063 // Invoke register of experimental operators.
1064 RegisterOnnxPreviewOperatorSetSchema();
1065
1066#ifndef NDEBUG
1067 size_t dbg_registered_schema_count = GetRegisteredSchemaCount() - dbg_initial_schema_count;
1068 // Check enabled only if schemas for all opset versions are loaded
1069 if (OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 0) {
1070 ONNX_ASSERTM(
1071 dbg_registered_schema_count == ONNX_DBG_GET_COUNT_IN_OPSETS(),
1072 "%u schema were exposed from operator sets and automatically placed into the static registry. "
1073 "%u were expected based on calls to registration macros. Operator set functions may need to be updated.",
1074 dbg_registered_schema_count,
1075 ONNX_DBG_GET_COUNT_IN_OPSETS());
1076 }
1077#endif
1078 }
1079
1080 private:
1081 static size_t GetRegisteredSchemaCount() {
1082 size_t count = 0;
1083 for (auto& x : GetMapWithoutEnsuringRegistration()) {
1084 for (auto& y : x.second) {
1085 count += y.second.size();
1086 }
1087 }
1088 return count;
1089 }
1090 };
1091
1092#ifndef __ONNX_DISABLE_STATIC_REGISTRATION
1093 static SchemasRegisterer schemasRegisterer;
1094#endif
1095
1096 return map;
1097}
1098
// Replaces every non-overlapping occurrence of `from` in `s` with `to`, scanning left to right,
// and returns the number of replacements made. Text inserted from `to` is never re-scanned, so
// a replacement that itself contains `from` (e.g. "a" -> "aa") terminates.
// An empty `from` matches nothing: `s` is left unchanged and 0 is returned. Without this check,
// an empty pattern combined with an empty replacement would never terminate.
size_t ReplaceAll(std::string& s, const char* from, const char* to) {
  const std::string::size_type lenFrom = std::strlen(from);
  if (lenFrom == 0) {
    return 0;
  }
  const std::string::size_type lenTo = std::strlen(to);
  size_t numReplaced = 0;
  for (std::string::size_type pos = s.find(from, 0, lenFrom); pos != std::string::npos;
       pos = s.find(from, pos + lenTo, lenFrom)) {
    s.replace(pos, lenFrom, to, lenTo);
    ++numReplaced;
  }
  return numReplaced;
}
1109
1110} // namespace ONNX_NAMESPACE
1111