1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #pragma once |
6 | |
#include <climits>
#include <cstdint>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <ostream>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
21 | |
22 | #include "onnx/common/common.h" |
23 | #include "onnx/common/constants.h" |
24 | #include "onnx/defs/shape_inference.h" |
25 | |
26 | namespace ONNX_NAMESPACE { |
27 | |
28 | struct FunctionBodyBuildContext { |
29 | virtual const AttributeProto* getAttribute(const std::string& name) const = 0; |
30 | virtual bool hasInput(int inputIndex) const = 0; |
31 | virtual bool hasOutput(int inputIndex) const = 0; |
32 | // getInputType(i) should return null for missing optional inputs, or if |
33 | // type-inference could not infer the input-type (erroneous model). |
34 | virtual const TypeProto* getInputType(int inputIndex) const = 0; |
35 | virtual ~FunctionBodyBuildContext() {} |
36 | }; |
37 | |
38 | struct FunctionBodyBuildContextImpl : public FunctionBodyBuildContext { |
39 | // Input_types: use a default TypeProto for missing types. We use a different convention |
40 | // here (from FunctionBodyBuildContext) to simplify python interoperability. |
41 | // The default value for input_types is included only for backward compatibility. |
42 | // It can be used for functions that do not depend on the type-context, but |
43 | // will not be sufficient for functions that do use the type-context. |
44 | FunctionBodyBuildContextImpl(const NodeProto& node_proto, const std::vector<TypeProto>& input_types = {}) |
45 | : node_proto_(node_proto), input_types_(input_types) { |
46 | for (auto& attr : node_proto.attribute()) { |
47 | attributesByName_[attr.name()] = &attr; |
48 | } |
49 | } |
50 | |
51 | const AttributeProto* getAttribute(const std::string& name) const override { |
52 | auto iter = attributesByName_.find(name); |
53 | if (iter == attributesByName_.end()) { |
54 | return nullptr; |
55 | } else { |
56 | return iter->second; |
57 | } |
58 | } |
59 | |
60 | bool hasInput(int inputIndex) const override { |
61 | if (inputIndex >= node_proto_.input_size()) |
62 | return false; |
63 | return node_proto_.input(inputIndex) != "" ; |
64 | } |
65 | |
66 | bool hasOutput(int inputIndex) const override { |
67 | if (inputIndex >= node_proto_.output_size()) |
68 | return false; |
69 | return node_proto_.output(inputIndex) != "" ; |
70 | } |
71 | |
72 | const TypeProto* getInputType(int inputIndex) const override { |
73 | if (inputIndex < 0) |
74 | return nullptr; |
75 | size_t j = static_cast<size_t>(inputIndex); |
76 | if (j >= input_types_.size()) |
77 | return nullptr; |
78 | // Convert default value (no variant set) into null. |
79 | if (input_types_[j].value_case() == TypeProto::ValueCase::VALUE_NOT_SET) |
80 | return nullptr; |
81 | return &input_types_[j]; |
82 | } |
83 | |
84 | std::unordered_map<std::string, const AttributeProto*> attributesByName_; |
85 | |
86 | NodeProto node_proto_; |
87 | std::vector<TypeProto> input_types_; |
88 | }; |
89 | |
// Predicate deciding whether a registered function body applies for a given
// build context.
using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>;

class OpSchema;
// Callback that populates a FunctionProto body for an op, given the build
// context and the op's schema; returns true on success.
using ContextDependentFunctionBodyBuilder =
    std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>;
95 | |
96 | class SchemaError final : public std::runtime_error { |
97 | public: |
98 | using std::runtime_error::runtime_error; |
99 | |
100 | SchemaError(const std::string& message) : std::runtime_error(message) {} |
101 | |
102 | const char* what() const noexcept override { |
103 | if (!expanded_message_.empty()) { |
104 | return expanded_message_.c_str(); |
105 | } |
106 | return std::runtime_error::what(); |
107 | } |
108 | |
109 | void AppendContext(const std::string& context) { |
110 | expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: " , context); |
111 | } |
112 | |
113 | private: |
114 | std::string expanded_message_; |
115 | }; |
116 | |
// Throws a SchemaError whose message is built from the variadic arguments.
// NOTE: the expansion already ends with ';' — callers should not rely on this
// macro behaving like a single statement inside an unbraced if/else.
#define fail_schema(...) ONNX_THROW_EX(ONNX_NAMESPACE::SchemaError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

// Version number of an operator set (e.g. 13 for opset 13).
using OperatorSetVersion = int;

// Set of data types permitted for a formal parameter / type parameter.
using DataTypeSet = std::unordered_set<DataType>;

// Type constraint map. Key is type string. Value is data type set and
// description.
using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;
126 | |
127 | /** |
128 | * @brief A class to record the schema of an op. |
129 | * |
130 | * OpSchema records the common interface of an op specified by its name. |
131 | * |
132 | * To register an OpSchema, one can use the macro ONNX_OPERATOR_SCHEMA(name) and |
133 | * then append the various functions in the class. For example, for an op |
134 | * that takes in two inputs, one output, and the first input and output |
135 | * could be in-place, can be written as |
136 | * |
137 | * ONNX_OPERATOR_SCHEMA(name) |
138 | * .NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}}); |
139 | * |
140 | * To manufacture methods that may be used to register an OpSchema |
141 | * non-statically, the following may be used: |
142 | * |
143 | * ONNX_OPERATOR_SET_SCHEMA(name, version, OpSchema() |
144 | * .NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}})); |
145 | */ |
146 | class OpSchema final { |
147 | public: |
  // Sentinel meaning "SinceVersion() has not been set yet".
  static constexpr int kUninitializedSinceVersion = -1;
  // Formal parameter options.
  enum FormalParameterOption : uint8_t {
    // The formal parameter is single and not optional.
    // Number of supplied actual parameters must be 1.
    Single = 0,
    // The formal parameter is single and optional.
    // Number of supplied actual parameters may be 0 or 1.
    Optional = 1,
    // The formal parameter is variadic.
    // Number of supplied actual parameters must be N or more, where
    // the minimum value N is indicated separately (default value 1).
    Variadic = 2,
  };
  // Whether a formal parameter may serve as a differentiable input/output of
  // the Gradient operator.
  enum DifferentiationCategory : uint8_t {
    // Whether this formal parameter is differentiable or not cannot
    // be statically determined. It also covers variadic formal
    // parameters which contain both of differentiable and
    // non-differentiable variables.
    Unknown = 0,
    // This formal parameter is differentiable. That is, this formal
    // parameter can be differentiable input of Gradient operator.
    Differentiable = 1,
    // This formal parameter is not differentiable. That is, this formal
    // parameter can not be differentiable input of Gradient operator.
    NonDifferentiable = 2
  };
175 | |
  // Formal parameter representation, including input/output name, typeStr,
  // description, and type constraints.
  class FormalParameter final {
   public:
    // Constructor.
    FormalParameter() = default;

    // Constructor with an explicit set of allowed data types.
    explicit FormalParameter(
        std::string name,
        DataTypeSet allowed_type_set,
        std::string type_str,
        const std::string& description,
        FormalParameterOption param_option = Single,
        bool is_homogeneous = true,
        int min_arity = 1,
        DifferentiationCategory differentiation_category = Unknown)
        : name_(std::move(name)),
          type_set_(std::move(allowed_type_set)),
          type_str_(std::move(type_str)),
#ifndef __ONNX_NO_DOC_STRINGS
          description_(description),
#endif
          param_option_(param_option),
          is_homogeneous_(is_homogeneous),
          min_arity_(min_arity),
          differentiation_category_(differentiation_category) {
#ifdef __ONNX_NO_DOC_STRINGS
      // Doc strings disabled: swallow the description to avoid a warning.
      ONNX_UNUSED_PARAMETER(description);
#endif
    }

    // Constructor without a type set; presumably the allowed types are filled
    // in later by OpSchema (a friend) via MutableTypes() — confirm in the
    // accompanying .cc file.
    explicit FormalParameter(
        std::string name,
        const std::string& description,
        std::string type_str,
        FormalParameterOption param_option = Single,
        bool is_homogeneous = true,
        int min_arity = 1,
        DifferentiationCategory differentiation_category = Unknown)
        : name_(std::move(name)),
          type_str_(std::move(type_str)),
#ifndef __ONNX_NO_DOC_STRINGS
          description_(description),
#endif
          param_option_(param_option),
          is_homogeneous_(is_homogeneous),
          min_arity_(min_arity),
          differentiation_category_(differentiation_category) {
#ifdef __ONNX_NO_DOC_STRINGS
      // Doc strings disabled: swallow the description to avoid a warning.
      ONNX_UNUSED_PARAMETER(description);
#endif
    }

    // Get formal parameter name.
    const std::string& GetName() const;

    // Get allowed data types.
    const DataTypeSet& GetTypes() const;

    // Get formal parameter type string.
    const std::string& GetTypeStr() const;

    // Get formal parameter description.
    const std::string& GetDescription() const;

    // Get the parameter option, it could be Single, Optional or Variadic.
    FormalParameterOption GetOption() const;

    // Get whether a variadic parameter requires all to be of same type
    bool GetIsHomogeneous() const;

    // Get minimum arity. Applicable only in the Variadic case.
    int GetMinArity() const;

    // Get the differentiation property of this formal parameter.
    DifferentiationCategory GetDifferentiationCategory() const;

   private:
    friend class OpSchema;

    // Mutable access to the allowed-type set; used by OpSchema (friend).
    DataTypeSet& MutableTypes();

    // Formal parameter name.
    std::string name_;

    // A set of data types supported for <*this> formal parameter.
    // It should contain at least one element if this formal parameter is good.
    DataTypeSet type_set_;

    // The <parameter type> string specified when registering an op.
    // It could be a supported data type or a type constraint key, which
    // maps to a set of supported data types.
    std::string type_str_;

    // Formal parameter description.
    std::string description_;

    // Formal parameter option.
    FormalParameterOption param_option_;

    // For variadic parameters, a flag indicating if all parameters must be of
    // same type
    bool is_homogeneous_;

    // Minimum number of parameters expected. Applicable only for Variadic.
    int min_arity_;

    // True if this parameter can be an differentiable inputs of Gradient.
    // Otherwise, using this parameter as an differentiable inputs of Gradient
    // is prohibited.
    DifferentiationCategory differentiation_category_;
  };
288 | |
  enum class SupportType : uint8_t {
    COMMON, // Supported by all frameworks that support this IR.
    EXPERIMENTAL, // This OP is experimental and can be changed or removed in
                  // the future.
  };

  // Default constructor: placeholder name/file, line 0, COMMON support.
  OpSchema() : OpSchema("unknown", "unknown", 0) {}
  // name: operator name; file/line: source location where the schema is registered.
  OpSchema(std::string name, std::string file, int line)
      : name_(std::move(name)), file_(std::move(file)), line_(line), support_(SupportType::COMMON) {}
298 | |
  /**
   * @brief Returns the file that the op schema is registered from.
   */
  const std::string& file() const {
    return file_;
  }

  /**
   * @brief Returns the line in file that the op schema is registered from.
   */
  int line() const {
    return line_;
  }

  /**
   * @brief Returns the support level of the op schema.
   */
  SupportType support_level() const {
    return support_;
  }

  /**
   * @brief Returns the docstring of the op schema, or nullptr when no
   * docstring has been set.
   */
  const char* doc() const {
    return doc_.empty() ? nullptr : doc_.c_str();
  }

  // Check if input and output types fall into valid set and match each other
  void CheckInputOutputType(struct InferenceContext&) const;

  /**
   * @brief Verifies if a NodeProto matches the pattern specified in
   * the schema.
   */
  void Verify(const NodeProto& node) const;
335 | |
  // Functions to set the property of the operator schemas.
  // Sets the number of inputs, either a fixed number or a min and a max.

  /**
   * The earliest operator set version which this operator was
   * present in. If an operator has had no BC-breaking changes,
   * this is simply the first operator set the operator was a member
   * of; if it has had BC-breaking changes, then for the semantics
   * /as described/ in the OpSchema entry, this version describes
   * the operator set which introduced the BC-breaking change.
   *
   * For example, suppose op Foo was added in v3, and had a BC-breaking
   * change in v6. Then there will be an op schema entry for Foo with
   * SinceVersion(3), and another, updated op schema entry for Foo
   * with SinceVersion(6).
   */
  OpSchema& SinceVersion(OperatorSetVersion n); // aka int

  /**
   * Marks this op as deprecated as of its since_version. This will cause the
   * Schema() lookup functions to return nullptr when the version is in the
   * deprecated range.
   */
  OpSchema& Deprecate();

  // True when Deprecate() has been called on this schema.
  bool Deprecated() const {
    return deprecated_;
  }
364 | |
  /**
   * @brief Input could be one of the values specified in allowed_input_nums.
   */
  OpSchema& NumInputs(std::set<int> allowed_input_nums);

  /**
   * @brief Output could be one of the values specified in allowed_output_nums.
   */
  OpSchema& NumOutputs(std::set<int> allowed_output_nums);

  // Shape Inference
  //
  // Note that signatures are defined to allow for forward-declaring
  // any structs used from ir.h
  OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
  // Returns the registered inference function, or a no-op placeholder when
  // none has been registered.
  InferenceFunction GetTypeAndShapeInferenceFunction() const {
    return tensor_inference_function_ ? tensor_inference_function_ : dummyInferenceFunction;
  }

  OpSchema& PartialDataPropagationFunction(DataPropagationFunction dataProgationFunction);
  // Returns the registered data-propagation function, or a no-op placeholder
  // when none has been registered.
  DataPropagationFunction GetDataPropagationFunction() const {
    return data_propagation_function_ ? data_propagation_function_ : dummyDataPropagationFunction;
  }
388 | |
  // Set the support level for the op schema.
  OpSchema& SetSupportLevel(SupportType supportType);

  // Functions to do documentation for the operator schema.
  // This may be disabled (via __ONNX_NO_DOC_STRINGS) to save memory.
  OpSchema& SetDoc(const char* doc) {
#ifndef __ONNX_NO_DOC_STRINGS
    SetDoc(std::string(doc));
#else
    ONNX_UNUSED_PARAMETER(doc);
#endif

    return *this;
  }

  OpSchema& SetDoc(const std::string& doc) {
#ifndef __ONNX_NO_DOC_STRINGS
    doc_ = doc;
#else
    ONNX_UNUSED_PARAMETER(doc);
#endif
    return *this;
  }

  // Functions to specify name for the operator schema.
  OpSchema& SetName(const char* name);
  OpSchema& SetName(std::string name);

  // Functions to specify code location for the operator schema.
  OpSchema& SetLocation(const char* file, int line);
  OpSchema& SetLocation(std::string file, int line);

  // Functions to specify domain for the operator schema.
  // Default domain value (ONNX_DOMAIN) means it's ONNX domain.
  OpSchema& SetDomain(const char* domain);
  OpSchema& SetDomain(std::string domain);
425 | |
  // Metadata describing one attribute of the operator.
  struct Attribute final {
    // Constructs an attribute with no default value; `required_` controls
    // whether Verify() demands the attribute be present.
    Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)
        : name(std::move(name_)),
          description(std::move(description_)),
          type(type_),
          required(required_),
          default_value() {}

    // Constructs an optional attribute whose type is derived from the default
    // value. Note: `type` is initialized before `default_value` (member
    // declaration order), so reading default_value_.type() here precedes the
    // move below.
    Attribute(std::string name_, std::string description_, AttributeProto default_value_)
        : name(std::move(name_)),
          description(std::move(description_)),
          type(default_value_.type()),
          required(false),
          default_value(std::move(default_value_)) {}

    const std::string name;
    const std::string description;
    AttributeProto::AttributeType type;
    bool required;
    // Meaningful only when required == false.
    AttributeProto default_value;
  };
447 | |
  // Registers an attribute from a fully-populated Attribute record.
  OpSchema& Attr(Attribute attr);

// Register "optional" attribute with default value.
// Declares three overloads per TypeName: STL-string, C-string (to reduce
// binary size at call sites), and vector-of-TypeName defaults.
#define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName) \
  OpSchema& Attr( \
      std::string name, std::string description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
  /* non-STL wrapper to reduce binary size */ \
  OpSchema& Attr( \
      const char* name, const char* description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
  OpSchema& Attr( \
      std::string name, \
      std::string description, \
      AttributeProto::AttributeType type, \
      const std::vector<TypeName>& defaultValue);

  ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t)
  ATTR_SETTER_WITH_DEFAULT_VALUE(float)
  ATTR_SETTER_WITH_DEFAULT_VALUE(std::string)
  ATTR_SETTER_WITH_DEFAULT_VALUE(TensorProto)
  ATTR_SETTER_WITH_DEFAULT_VALUE(GraphProto)
  ATTR_SETTER_WITH_DEFAULT_VALUE(TypeProto)

  // Register "required" attribute without default value.
  OpSchema& Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true);

  // Non-STL wrapper to reduce binary size
  OpSchema& Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required = true);

  // Permits attributes that are not declared in this schema.
  OpSchema& AllowUncheckedAttributes();
477 | |
  // Type constraint: a named type variable (e.g. "T") together with the type
  // strings it may bind to and a human-readable description.
  struct TypeConstraintParam final {
    TypeConstraintParam(
        std::string type_param_str_,
        std::vector<std::string> allowed_type_strs_,
        std::string description_)
        : type_param_str(std::move(type_param_str_)),
          allowed_type_strs(std::move(allowed_type_strs_)),
          description(std::move(description_)) {}

    // Type parameter string, for example, "T", "T1", etc.
    std::string type_param_str;
    // Allowed type strings for <*this> type parameter, for example,
    // "tensor(float)".
    std::vector<std::string> allowed_type_strs;
    // Type parameter description.
    std::string description;
  };
496 | |
  // Grammar for type strings used in Input(), Output().
  // <type> ::= <data_type> |
  //            tensor(<data_type>) |
  //            seq(<type>) |
  //            map(<data_type>, <type>) |
  //            <type_parameter>
  // <data_type> :: = float | int32 | string | bool | uint8
  //                  | int8 | uint16 | int16 | int64 | float16 | double
  // <type_parameter> ::= any type parameter string, say "T".
  //
  // NOTE: 1) <type_parameter> will always be together with a type constraints
  // specification.
  // 2) <type> ::= <data_type> means the data is scalar (zero dimension).
  //
  // Example:
  // ONNX_OPERATOR_SET_SCHEMA(Sum, 1, OpSchema()
  //     .Input(0, "input_a", "the first input", "T")
  //     .Input(1, "input_b", "the second input", "T")
  //     .Output(0, "sum", "the sum of two numbers", "T")
  //     .TypeConstraint("T", {"float", "double", "int32"}, "allowed data types for
  //     sum."))
  //
  // Optional = true means that the input might have empty input value
  // (represented as "") in the graph even though the later inputs have values.
  // It's useful for complex situation when there are several independent
  // optional inputs.
  OpSchema& Input(
      int n,
      std::string name,
      const std::string& description,
      std::string type_str,
      FormalParameterOption param_option = Single,
      bool is_homogeneous = true,
      int min_arity = 1,
      DifferentiationCategory differentiation_category = Unknown);

  // Non-STL wrapper to reduce binary size
  OpSchema& Input(
      int n,
      const char* name,
      const char* description,
      const char* type_str,
      FormalParameterOption param_option = Single,
      bool is_homogeneous = true,
      int min_arity = 1,
      DifferentiationCategory differentiation_category = Unknown);

  OpSchema& Output(
      int n,
      std::string name,
      const std::string& description,
      std::string type_str,
      FormalParameterOption param_option = Single,
      bool is_homogeneous = true,
      int min_arity = 1,
      DifferentiationCategory differentiation_category = Unknown);

  // Non-STL wrapper to reduce binary size
  OpSchema& Output(
      int n,
      const char* name,
      const char* description,
      const char* type_str,
      FormalParameterOption param_option = Single,
      bool is_homogeneous = true,
      int min_arity = 1,
      DifferentiationCategory differentiation_category = Unknown);

  // Declares a type-constraint variable (see TypeConstraintParam).
  OpSchema& TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description);

  // Non-STL wrapper to reduce binary size
  OpSchema&
  TypeConstraint(const char* type_str, std::initializer_list<const char*> constraints, const char* description);
570 | |
571 | // Convenience members for types |
572 | |
573 | // All high-precision numeric types. |
574 | static const std::vector<std::string>& numeric_types_for_math_reduction_with_bfloat() { |
575 | static const std::vector<std::string> numeric_types_for_math_reduction_with_bfloat = { |
576 | "tensor(uint32)" , |
577 | "tensor(uint64)" , |
578 | "tensor(int32)" , |
579 | "tensor(int64)" , |
580 | "tensor(float16)" , |
581 | "tensor(float)" , |
582 | "tensor(double)" , |
583 | "tensor(bfloat16)" }; |
584 | return numeric_types_for_math_reduction_with_bfloat; |
585 | } |
586 | |
587 | static const std::vector<std::string>& numeric_types_for_math_reduction() { |
588 | static const std::vector<std::string> numeric_types_for_math_reduction = { |
589 | "tensor(uint32)" , |
590 | "tensor(uint64)" , |
591 | "tensor(int32)" , |
592 | "tensor(int64)" , |
593 | "tensor(float16)" , |
594 | "tensor(float)" , |
595 | "tensor(double)" }; |
596 | return numeric_types_for_math_reduction; |
597 | } |
598 | |
599 | static const std::vector<std::string>& all_numeric_types_with_bfloat() { |
600 | static const std::vector<std::string> all_numeric_types_with_bfloat = { |
601 | "tensor(uint8)" , |
602 | "tensor(uint16)" , |
603 | "tensor(uint32)" , |
604 | "tensor(uint64)" , |
605 | "tensor(int8)" , |
606 | "tensor(int16)" , |
607 | "tensor(int32)" , |
608 | "tensor(int64)" , |
609 | "tensor(float16)" , |
610 | "tensor(float)" , |
611 | "tensor(double)" , |
612 | "tensor(bfloat16)" }; |
613 | return all_numeric_types_with_bfloat; |
614 | } |
615 | |
616 | static const std::vector<std::string>& all_numeric_types() { |
617 | static const std::vector<std::string> all_numeric_types = { |
618 | "tensor(uint8)" , |
619 | "tensor(uint16)" , |
620 | "tensor(uint32)" , |
621 | "tensor(uint64)" , |
622 | "tensor(int8)" , |
623 | "tensor(int16)" , |
624 | "tensor(int32)" , |
625 | "tensor(int64)" , |
626 | "tensor(float16)" , |
627 | "tensor(float)" , |
628 | "tensor(double)" }; |
629 | return all_numeric_types; |
630 | } |
631 | |
632 | static const std::vector<std::string>& all_numeric_sequence_types() { |
633 | static const std::vector<std::string> all_numeric_sequence_types = { |
634 | "seq(tensor(uint8))" , |
635 | "seq(tensor(uint16))" , |
636 | "seq(tensor(uint32))" , |
637 | "seq(tensor(uint64))" , |
638 | "seq(tensor(int8))" , |
639 | "seq(tensor(int16))" , |
640 | "seq(tensor(int32))" , |
641 | "seq(tensor(int64))" , |
642 | "seq(tensor(float16))" , |
643 | "seq(tensor(float))" , |
644 | "seq(tensor(double))" }; |
645 | return all_numeric_sequence_types; |
646 | } |
647 | |
648 | static const std::vector<std::string>& all_tensor_types() { |
649 | static const std::vector<std::string> all_tensor_types = { |
650 | "tensor(uint8)" , |
651 | "tensor(uint16)" , |
652 | "tensor(uint32)" , |
653 | "tensor(uint64)" , |
654 | "tensor(int8)" , |
655 | "tensor(int16)" , |
656 | "tensor(int32)" , |
657 | "tensor(int64)" , |
658 | "tensor(float16)" , |
659 | "tensor(float)" , |
660 | "tensor(double)" , |
661 | "tensor(string)" , |
662 | "tensor(bool)" , |
663 | "tensor(complex64)" , |
664 | "tensor(complex128)" }; |
665 | return all_tensor_types; |
666 | } |
667 | |
668 | static const std::vector<std::string>& all_tensor_types_with_bfloat() { |
669 | static const std::vector<std::string> all_tensor_types_with_bfloat = { |
670 | "tensor(uint8)" , |
671 | "tensor(uint16)" , |
672 | "tensor(uint32)" , |
673 | "tensor(uint64)" , |
674 | "tensor(int8)" , |
675 | "tensor(int16)" , |
676 | "tensor(int32)" , |
677 | "tensor(int64)" , |
678 | "tensor(bfloat16)" , |
679 | "tensor(float16)" , |
680 | "tensor(float)" , |
681 | "tensor(double)" , |
682 | "tensor(string)" , |
683 | "tensor(bool)" , |
684 | "tensor(complex64)" , |
685 | "tensor(complex128)" }; |
686 | return all_tensor_types_with_bfloat; |
687 | } |
688 | |
689 | static const std::vector<std::string>& all_tensor_sequence_types() { |
690 | static const std::vector<std::string> all_tensor_sequence_types = { |
691 | "seq(tensor(uint8))" , |
692 | "seq(tensor(uint16))" , |
693 | "seq(tensor(uint32))" , |
694 | "seq(tensor(uint64))" , |
695 | "seq(tensor(int8))" , |
696 | "seq(tensor(int16))" , |
697 | "seq(tensor(int32))" , |
698 | "seq(tensor(int64))" , |
699 | "seq(tensor(float16))" , |
700 | "seq(tensor(float))" , |
701 | "seq(tensor(double))" , |
702 | "seq(tensor(string))" , |
703 | "seq(tensor(bool))" , |
704 | "seq(tensor(complex64))" , |
705 | "seq(tensor(complex128))" }; |
706 | return all_tensor_sequence_types; |
707 | } |
708 | |
709 | static const std::vector<std::string>& all_tensor_sequence_types_with_bfloat() { |
710 | static const std::vector<std::string> all_tensor_sequence_types_with_bfloat = { |
711 | "seq(tensor(uint8))" , |
712 | "seq(tensor(uint16))" , |
713 | "seq(tensor(uint32))" , |
714 | "seq(tensor(uint64))" , |
715 | "seq(tensor(int8))" , |
716 | "seq(tensor(int16))" , |
717 | "seq(tensor(int32))" , |
718 | "seq(tensor(int64))" , |
719 | "seq(tensor(bfloat16))" , |
720 | "seq(tensor(float16))" , |
721 | "seq(tensor(float))" , |
722 | "seq(tensor(double))" , |
723 | "seq(tensor(string))" , |
724 | "seq(tensor(bool))" , |
725 | "seq(tensor(complex64))" , |
726 | "seq(tensor(complex128))" }; |
727 | return all_tensor_sequence_types_with_bfloat; |
728 | } |
729 | |
730 | static const std::vector<std::string>& all_optional_types() { |
731 | static const std::vector<std::string> all_optional_types = { |
732 | "optional(seq(tensor(uint8)))" , "optional(seq(tensor(uint16)))" , "optional(seq(tensor(uint32)))" , |
733 | "optional(seq(tensor(uint64)))" , "optional(seq(tensor(int8)))" , "optional(seq(tensor(int16)))" , |
734 | "optional(seq(tensor(int32)))" , "optional(seq(tensor(int64)))" , "optional(seq(tensor(float16)))" , |
735 | "optional(seq(tensor(float)))" , "optional(seq(tensor(double)))" , "optional(seq(tensor(string)))" , |
736 | "optional(seq(tensor(bool)))" , "optional(seq(tensor(complex64)))" , "optional(seq(tensor(complex128)))" , |
737 | "optional(tensor(uint8))" , "optional(tensor(uint16))" , "optional(tensor(uint32))" , |
738 | "optional(tensor(uint64))" , "optional(tensor(int8))" , "optional(tensor(int16))" , |
739 | "optional(tensor(int32))" , "optional(tensor(int64))" , "optional(tensor(float16))" , |
740 | "optional(tensor(float))" , "optional(tensor(double))" , "optional(tensor(string))" , |
741 | "optional(tensor(bool))" , "optional(tensor(complex64))" , "optional(tensor(complex128))" }; |
742 | return all_optional_types; |
743 | } |
744 | |
745 | static const std::vector<std::string>& all_optional_types_with_bfloat() { |
746 | static const std::vector<std::string> all_optional_types = { |
747 | "optional(seq(tensor(uint8)))" , "optional(seq(tensor(uint16)))" , "optional(seq(tensor(uint32)))" , |
748 | "optional(seq(tensor(uint64)))" , "optional(seq(tensor(int8)))" , "optional(seq(tensor(int16)))" , |
749 | "optional(seq(tensor(int32)))" , "optional(seq(tensor(int64)))" , "optional(seq(tensor(bfloat16)))" , |
750 | "optional(seq(tensor(float16)))" , "optional(seq(tensor(float)))" , "optional(seq(tensor(double)))" , |
751 | "optional(seq(tensor(string)))" , "optional(seq(tensor(bool)))" , "optional(seq(tensor(complex64)))" , |
752 | "optional(seq(tensor(complex128)))" , "optional(tensor(uint8))" , "optional(tensor(uint16))" , |
753 | "optional(tensor(uint32))" , "optional(tensor(uint64))" , "optional(tensor(int8))" , |
754 | "optional(tensor(int16))" , "optional(tensor(int32))" , "optional(tensor(int64))" , |
755 | "optional(tensor(bfloat16))" , "optional(tensor(float16))" , "optional(tensor(float))" , |
756 | "optional(tensor(double))" , "optional(tensor(string))" , "optional(tensor(bool))" , |
757 | "optional(tensor(complex64))" , "optional(tensor(complex128))" }; |
758 | return all_optional_types; |
759 | } |
760 | |
  // Calls the passed function with `this` as an argument. Useful for
  // adding docs for templated/macro ops.
  OpSchema& FillUsing(const std::function<void(OpSchema&)>& populator);
764 | |
  friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);

  // Operator domain; the default domain (ONNX_DOMAIN) is the ONNX domain.
  const std::string& domain() const {
    return domain_;
  }

  // Declared attributes, keyed by attribute name.
  const std::map<std::string, Attribute>& attributes() const {
    return attributes_;
  }

  // Get input formal parameters.
  const std::vector<FormalParameter>& inputs() const {
    return inputs_;
  }

  // Get output formal parameters.
  const std::vector<FormalParameter>& outputs() const {
    return outputs_;
  }

  // Declared type-constraint parameters (e.g. "T").
  const std::vector<TypeConstraintParam>& typeConstraintParams() const {
    return type_constraint_params_;
  }

  // Operator name.
  const std::string& Name() const {
    return name_;
  }

  // Opset version in which this schema took effect (see SinceVersion setter).
  OperatorSetVersion SinceVersion() const {
    return since_version_;
  }

  // snake_case spelling of SinceVersion().
  int since_version() const {
    return since_version_;
  }

  // snake_case spelling of Deprecated().
  bool deprecated() const {
    return deprecated_;
  }

  // Bounds on the number of inputs/outputs this operator accepts.
  int min_input() const {
    return min_input_;
  }
  int max_input() const {
    return max_input_;
  }
  int min_output() const {
    return min_output_;
  }
  int max_output() const {
    return max_output_;
  }

  // True when a type/shape inference function has been registered.
  bool has_type_and_shape_inference_function() const {
    return tensor_inference_function_ ? true : false;
  }

  // True when a partial data-propagation function has been registered.
  bool has_data_propagation_function() const {
    return data_propagation_function_ ? true : false;
  }
825 | |
826 | std::vector<int> function_opset_versions() const { |
827 | std::vector<int> opset_versions; |
828 | std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it = opset_version_to_function_body_.cbegin(); |
829 | for (; it != opset_version_to_function_body_.cend(); ++it) { |
830 | opset_versions.push_back(it->first); |
831 | } |
832 | return opset_versions; |
833 | } |
834 | |
  // True if at least one fixed (non-context-dependent) function body is registered.
  bool HasFunction() const {
    return !opset_version_to_function_body_.empty();
  }
838 | |
  // Registers a function body, given as a list of nodes, for this op.
  // opset_version selects which opset version the body targets
  // (kUninitializedSinceVersion: default resolution happens in the implementation).
  OpSchema& FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version = kUninitializedSinceVersion);

  // Same as above, but also records the opset imports the body's nodes rely on.
  OpSchema& FunctionBody(
      const std::vector<NodeProto>& func_nodes,
      const std::vector<OperatorSetIdProto>& opsets,
      int opset_version = kUninitializedSinceVersion);

  // Same as above, but the body is supplied as a string (see the .cc for the accepted format).
  OpSchema& FunctionBody(const char* func_body, int opset_version = kUninitializedSinceVersion);
847 | |
  // since_version_ of an OpSchema tells the last opset version when an op is defined.
  // When the op's definition is changed, a new OpSchema (of the same op_type) is created
  // with a newer since_version_, reflecting the opset version at the time of change.
  // For a function op, operators used to define its function body may change
  // while there is no change to the function op definition itself.
  // When this happens, multiple function bodies are provided, each for a specific opset version.
  //
  // Take LogSoftmax for example. Its latest opset version is 13.
  // In LogSoftmax's function body, ReduceMax (with since_version_ 1, 11, 12, 18) is used.
  // When a model containing LogSoftmax with opset_import version within 13 to 17 is loaded, function body
  // with opset_version 13 is used for inlining.
  // When the same model but opset_import version 18 is loaded, function body
  // with opset_version 18 is used for inlining.
  // Clearly the function body for opset_import version 13 will not work
  // in a model with opset_import version 18 because that function body makes wrong use of ReduceMax(18).
  // Inside GetFunction we ensure that ops being used to construct a function body do not suffer from
  // this issue.
  const FunctionProto* GetFunction(
      int requested_opset_version = OpSchema::kUninitializedSinceVersion,
      bool validate = false) const;
868 | |
869 | std::vector<int> context_dependent_function_opset_versions() const { |
870 | std::vector<int> opset_versions; |
871 | std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it = opset_version_to_function_builder_.cbegin(); |
872 | for (; it != opset_version_to_function_builder_.cend(); ++it) { |
873 | opset_versions.push_back(it->first); |
874 | } |
875 | return opset_versions; |
876 | } |
877 | |
  // True if at least one context-dependent function-body builder is registered.
  bool HasContextDependentFunction() const {
    return !opset_version_to_function_builder_.empty();
  }
881 | |
882 | bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const { |
883 | return opset_version_to_function_builder_.find(opset_version) != opset_version_to_function_builder_.end(); |
884 | } |
885 | |
  // Registers a builder that constructs a function body on demand from the
  // node/attribute/type context of a concrete call site.
  OpSchema& SetContextDependentFunctionBodyBuilder(
      ContextDependentFunctionBodyBuilder,
      int opset_version = kUninitializedSinceVersion);

  // Builds a context-dependent function body into function_proto for the given
  // context; see the implementation for the exact return-value contract.
  bool BuildContextDependentFunction(
      const FunctionBodyBuildContext& ctx,
      FunctionProto& function_proto,
      int requested_opset_version = OpSchema::kUninitializedSinceVersion) const;

  // Verifies that the schema is valid and all specifications are compatible.
  // It will also parse all type strings specified for inputs/outputs into valid
  // TypeProto and create global unique string pointer as the DataType for
  // efficiency.
  void Finalize();

  // Build function with information stored in opschema
  void BuildFunction(FunctionProto& function_body) const;
903 | |
 private:
  // Parses the declared type strings of the given formal parameters into TypeProto form.
  void ParseAndSetTypes(
      /*out*/ std::vector<OpSchema::FormalParameter>* formalParameters);
  // Checks the ops referenced inside `function` against requested_opset_version
  // (see the comment above GetFunction for why this matters).
  bool ValidateReferencedOpsInFuncton(
      const FunctionProto* function,
      int requested_opset_version,
      int function_since_version,
      std::set<std::string>* updated_ops = nullptr) const;
  // Updates function_proto's opset_import version entries to opset_version.
  void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;

  std::string name_; // op_type
  std::string file_; // file where the schema was defined (used in registration diagnostics)
  std::string doc_;
  // Default domain value ("") means it's ONNX domain.
  std::string domain_ = ONNX_DOMAIN;
  std::map<std::string, Attribute> attributes_{};
  bool allows_unchecked_attributes_ = false;
  std::vector<FormalParameter> inputs_;
  std::vector<FormalParameter> outputs_;
  std::vector<TypeConstraintParam> type_constraint_params_;
  TypeConstraintMap type_constraints_;
  int line_ = 0; // line where the schema was defined (used in registration diagnostics)
  SupportType support_;
  int min_input_ = 0;
  int max_input_ = 0;
  int min_output_ = 0;
  int max_output_ = 0;
  // The default is a little goofy, since it is never what you want
  OperatorSetVersion since_version_ = kUninitializedSinceVersion;
  bool deprecated_{};
  // Predicates over the actual input/output counts; defaults accept any count.
  std::function<bool(int)> num_inputs_allowed_ = [](int) { return true; };
  std::function<bool(int)> num_outputs_allowed_ = [](int) { return true; };
  InferenceFunction tensor_inference_function_;
  DataPropagationFunction data_propagation_function_;

  // Fixed function bodies, keyed by the opset version each body targets.
  std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_;
  // Context-dependent body builders, keyed by the opset version each targets.
  std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_;
};
942 | |
// Map type to store operator schemas. The format is,
// <OpName, <Domain, <OperatorSetVersion, OpSchema>>>.
// The innermost map is ordered so the highest registered version of an op can
// be found via rbegin() (see OpSchemaRegistry::Schema below).
using OpName_Domain_Version_Schema_Map =
    std::unordered_map<std::string, std::unordered_map<std::string, std::map<OperatorSetVersion, OpSchema>>>;
947 | |
// Minimal query interface over a schema registry: resolves an op name (plus
// domain) to the newest schema whose version does not exceed maxInclusiveVersion.
class ISchemaRegistry {
 public:
  virtual ~ISchemaRegistry() = default;

  virtual const OpSchema*
  GetSchema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) const = 0;
};
955 | |
956 | /** |
957 | * @brief A registry to hold all the operator schemas. |
958 | */ |
959 | class OpSchemaRegistry final : public ISchemaRegistry { |
960 | public: |
961 | // A singleton class to store domain to min/max op_set version map, as well as |
962 | // domain to last-release op_set version map. |
963 | class DomainToVersionRange final { |
964 | public: |
965 | DomainToVersionRange() { |
966 | // Increase the highest version when you make BC-breaking changes to the |
967 | // operator schema on specific domain. Update the lowest version when it's |
968 | // determined to remove too old version history. |
969 | map_[ONNX_DOMAIN] = std::make_pair(1, 18); |
970 | map_[AI_ONNX_ML_DOMAIN] = std::make_pair(1, 3); |
971 | map_[AI_ONNX_TRAINING_DOMAIN] = std::make_pair(1, 1); |
972 | // ONNX's preview domain contains operators subject to change, so |
973 | // versining is not meaningful and that domain should have only one |
974 | // version. |
975 | map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1); |
976 | // Version corresponding last release of ONNX. Update this to match with |
977 | // the max version above in a *release* version of ONNX. But in other |
978 | // versions, the max version may be ahead of the last-release-version. |
979 | last_release_version_map_[ONNX_DOMAIN] = 18; |
980 | last_release_version_map_[AI_ONNX_ML_DOMAIN] = 3; |
981 | last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1; |
982 | last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1; |
983 | } |
984 | |
985 | const std::unordered_map<std::string, std::pair<int, int>>& Map() const { |
986 | return map_; |
987 | } |
988 | |
989 | const std::unordered_map<std::string, int>& LastReleaseVersionMap() const { |
990 | return last_release_version_map_; |
991 | } |
992 | |
993 | // Add customized domain to min/max version. |
994 | // Onnx partners are able to use onnx operator schema api to |
995 | // register customized op in their own domain. |
996 | // Can optionally specify last_release_version (to make it similar to |
997 | // standard ONNX domains as above). Custom-domains are free to interpret |
998 | // this as appropriate (that is, as relative to releases of custom-domain |
999 | // as opposed to ONNX releases). |
1000 | void |
1001 | AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) { |
1002 | std::lock_guard<std::mutex> lock(mutex_); |
1003 | assert(map_.end() == map_.find(domain)); |
1004 | map_[domain] = std::make_pair(min_version, max_version); |
1005 | // If a last-release-version is not explicitly specified, use max as |
1006 | // last-release-version. |
1007 | if (last_release_version == -1) |
1008 | last_release_version = max_version; |
1009 | assert(last_release_version_map_.end() == last_release_version_map_.find(domain)); |
1010 | last_release_version_map_[domain] = last_release_version; |
1011 | } |
1012 | |
1013 | static DomainToVersionRange& Instance(); |
1014 | |
1015 | private: |
1016 | // Key: domain. Value: <lowest version, highest version> pair. |
1017 | std::unordered_map<std::string, std::pair<int, int>> map_; |
1018 | |
1019 | // Key: domain. Value: most recent release opset version. Note that |
1020 | // the highest opset version may be ahead of the most recent release's opset |
1021 | // version. |
1022 | std::unordered_map<std::string, int> last_release_version_map_; |
1023 | |
1024 | std::mutex mutex_; |
1025 | }; |
1026 | |
1027 | class OpSchemaRegisterOnce final { |
1028 | public: |
1029 | OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0) { |
1030 | ONNX_TRY { |
1031 | op_schema.Finalize(); |
1032 | auto& m = GetMapWithoutEnsuringRegistration(); |
1033 | auto& op_name = op_schema.Name(); |
1034 | auto& op_domain = op_schema.domain(); |
1035 | auto ver = op_schema.SinceVersion(); |
1036 | if (OpSchema::kUninitializedSinceVersion == ver) { |
1037 | op_schema.SinceVersion(1); |
1038 | ver = op_schema.SinceVersion(); |
1039 | } |
1040 | // Stops because the opset_version is higher than opset_version_to_load |
1041 | if (opset_version_to_load != 0 && ver > opset_version_to_load) { |
1042 | return; |
1043 | } |
1044 | if (m[op_name][op_domain].count(ver)) { |
1045 | const auto& schema = m[op_name][op_domain][ver]; |
1046 | std::stringstream err; |
1047 | err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver |
1048 | << ") from file " << op_schema.file() << " line " << op_schema.line() |
1049 | << ", but it is already registered from file " << schema.file() << " line " << schema.line() << std::endl; |
1050 | fail_schema(err.str()); |
1051 | } |
1052 | // Return early if schema for the targeted opset version has already been loaded |
1053 | if (opset_version_to_load != 0 && !m[op_name][op_domain].empty()) { |
1054 | return; |
1055 | } |
1056 | auto ver_range_map = DomainToVersionRange::Instance().Map(); |
1057 | auto ver_range_it = ver_range_map.find(op_domain); |
1058 | if (ver_range_it == ver_range_map.end()) { |
1059 | std::stringstream err; |
1060 | err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver |
1061 | << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its domain is not" |
1062 | << " known by the checker." << std::endl; |
1063 | |
1064 | fail_schema(err.str()); |
1065 | } |
1066 | auto lower_bound_incl = ver_range_it->second.first; |
1067 | auto upper_bound_incl = ver_range_it->second.second; |
1068 | if (!(lower_bound_incl <= ver && upper_bound_incl >= ver)) { |
1069 | std::stringstream err; |
1070 | err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver |
1071 | << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its version is not " |
1072 | << "in the inclusive range [" << lower_bound_incl << ", " << upper_bound_incl |
1073 | << "] (usually, this means you " |
1074 | << "bumped the operator version but " |
1075 | << "forgot to update the version range in DomainToVersionRange " |
1076 | << "in onnx/defs/schema.h)." << std::endl; |
1077 | fail_schema(err.str()); |
1078 | } |
1079 | |
1080 | m[op_name][op_domain].insert(std::pair<int, OpSchema&&>(ver, std::move(op_schema))); |
1081 | } |
1082 | ONNX_CATCH(const std::exception& e) { |
1083 | ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; }); |
1084 | } |
1085 | } |
1086 | }; |
1087 | |
1088 | // Return the latest schema for an operator in specified domain. |
1089 | // Domain with default value ONNX_DOMAIN means ONNX. |
1090 | static const OpSchema* Schema(const std::string& key, const std::string& domain = ONNX_DOMAIN) { |
1091 | auto& m = map(); |
1092 | if (m.count(key) && m[key].count(domain)) { |
1093 | return &m[key][domain].rbegin()->second; |
1094 | } else { |
1095 | return nullptr; |
1096 | } |
1097 | } |
1098 | |
1099 | // Return the schema with biggest version, which is not greater than specified |
1100 | // <maxInclusiveVersion> in specified domain. Domain with default value |
1101 | // ONNX_DOMAIN means ONNX. |
1102 | static const OpSchema* |
1103 | Schema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) { |
1104 | auto& m = map(); |
1105 | if (m.count(key) && m[key].count(domain)) { |
1106 | auto pos = m[key][domain].lower_bound(maxInclusiveVersion); |
1107 | if (m[key][domain].begin() == pos && pos->first > maxInclusiveVersion) { |
1108 | // All versions are greater than specified version. |
1109 | return nullptr; |
1110 | } |
1111 | if (m[key][domain].end() == pos || pos->first > maxInclusiveVersion) { |
1112 | // All versions are less than specified version, or, |
1113 | // The <pos> version is greater than specified version. |
1114 | pos--; |
1115 | } |
1116 | |
1117 | // Schema with exact version as specified one exists. |
1118 | return &(pos->second); |
1119 | } else { |
1120 | return nullptr; |
1121 | } |
1122 | } |
1123 | |
1124 | static OpSchemaRegistry* Instance(); |
1125 | |
1126 | const OpSchema* GetSchema( |
1127 | const std::string& key, |
1128 | const int maxInclusiveVersion, |
1129 | const std::string& domain = ONNX_DOMAIN) const override { |
1130 | return Schema(key, maxInclusiveVersion, domain); |
1131 | } |
1132 | static void SetLoadedSchemaVersion(int target_version) { |
1133 | loaded_schema_version = target_version; |
1134 | } |
1135 | static int GetLoadedSchemaVersion() { |
1136 | return loaded_schema_version; |
1137 | } |
1138 | |
1139 | private: |
1140 | // OpSchemaRegistry should not need to be instantiated except statically |
1141 | // within this class |
1142 | OpSchemaRegistry() = default; |
1143 | |
1144 | /** |
1145 | * @brief Returns the underlying string to OpSchema map. |
1146 | * |
1147 | * You should not manually manipulate the map object returned. Instead, use |
1148 | * the macros defined such as ONNX_OPERATOR_SET_SCHEMA to register your |
1149 | * operator schema. |
1150 | * |
1151 | * We wrap it inside a function to avoid the static initialization order |
1152 | * fiasco. |
1153 | */ |
1154 | static OpName_Domain_Version_Schema_Map& GetMapWithoutEnsuringRegistration(); |
1155 | static OpName_Domain_Version_Schema_Map& map(); |
1156 | static int loaded_schema_version; |
1157 | |
1158 | public: |
1159 | static const std::vector<OpSchema> get_all_schemas_with_history() { |
1160 | std::vector<OpSchema> r; |
1161 | for (auto& x : map()) { |
1162 | for (auto& y : x.second) { |
1163 | for (auto& z : y.second) { |
1164 | r.emplace_back(z.second); |
1165 | } |
1166 | } |
1167 | } |
1168 | return r; |
1169 | } |
1170 | |
1171 | static const std::vector<OpSchema> get_all_schemas() { |
1172 | std::vector<OpSchema> r; |
1173 | for (auto& x : map()) { |
1174 | for (auto& y : x.second) { |
1175 | auto& version2schema = y.second; |
1176 | r.emplace_back(version2schema.rbegin()->second); |
1177 | } |
1178 | } |
1179 | return r; |
1180 | } |
1181 | }; |
1182 | |
// Finalizes `schema` and adds it to the global registry (see OpSchemaRegistry::OpSchemaRegisterOnce).
void RegisterSchema(OpSchema schema, int opset_version_to_load = 0);
1184 | |
1185 | // Registers the latest opset schema before opset_version_to_load |
1186 | // By default opset_version_to_load=0 means it will register all versions |
1187 | template <class T> |
1188 | void RegisterOpSetSchema(int opset_version_to_load = 0) { |
1189 | T::ForEachSchema([opset_version_to_load](OpSchema&& schema) { RegisterSchema(schema, opset_version_to_load); }); |
1190 | }; |
1191 | |
// Forward declaration for the non-specialized GetOpSchema method. This
// enforces a consistent signature on functions that query individual schema,
// which are defined as specializations of this function.
template <typename T>
OpSchema GetOpSchema();

// Convenience wrappers over ONNX_OPERATOR_SET_SCHEMA_EX for the standard domains.
#define ONNX_OPERATOR_SET_SCHEMA(name, ver, impl) ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ONNX_DOMAIN, ver, true, impl)

#define ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxML, AI_ONNX_ML_DOMAIN, ver, true, impl)

#define ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxTraining, AI_ONNX_TRAINING_DOMAIN, ver, true, impl)

#define ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxPreview, AI_ONNX_PREVIEW_TRAINING_DOMAIN, ver, true, impl)

// Defines specialization of GetOpSchema for a class whose name is determined
// based on a convention using name, domain, and version. Operator schema are
// normally included in operator sets and registered in OpSchemaRegistry::map().
// In this case, callers should set dbg_included_in_static_opset to true. This
// assists with runtime validation in DEBUG builds ensuring the intended set
// of operator schema is registered.
#define ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl)  \
  class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);                                         \
  template <>                                                                                           \
  OpSchema GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)>() {                      \
    return impl.SetName(#name).SetDomain(domain_str).SinceVersion(ver).SetLocation(__FILE__, __LINE__); \
  }                                                                                                     \
  size_t dbg_count_check_##name##_##domain##_ver##ver =                                                 \
      (dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() : 0;
// In release builds the DEBUG-only schema counting collapses to a constant.
#ifdef NDEBUG
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
#else
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().IncrementCount()
#define ONNX_DBG_GET_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().GetCount()
1228 | |
// Debug-build helper that counts how many operator schemas the static opsets
// registered (see ONNX_DBG_INCREMENT_COUNT_IN_OPSETS above).
class DbgOperatorSetTracker {
 public:
  // Process-wide tracker instance (defined in the .cc file).
  static DbgOperatorSetTracker& Instance();

  // Records one more registered schema and returns the updated total.
  size_t IncrementCount() {
    registered_count_ += 1;
    return registered_count_;
  }

  // Total number of schemas recorded so far.
  size_t GetCount() const {
    return registered_count_;
  }

 private:
  size_t registered_count_ = 0;
};
1244 | #endif |
1245 | |
// Naming convention for operator schema classes
#define ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name) name##_##domain##_ver##ver

// Naming convention for preview operator schema classes
#define ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name) \
  ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxPreview, ver, name)

// Helper function: in-place replacement of `from` with `to` in `s`.
// (See the .cc file for the exact return-value semantics.)
size_t ReplaceAll(std::string& s, const char* from, const char* to);

// Marks the macro-generated static registrar objects as intentionally unused.
#ifdef __GNUC__
#define ONNX_UNUSED __attribute__((__unused__))
#else
#define ONNX_UNUSED
#endif

// Legacy macros to register schema at static initialization
#define ONNX_OPERATOR_SCHEMA(name) ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)                                                                      \
  static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce(op_schema_register_once##name##Counter) ONNX_UNUSED = \
      OpSchema(#name, __FILE__, __LINE__)
1268 | |
// (The duplicate ReplaceAll declaration that used to be here was removed; the
// function is declared once above, next to the registration macros.)
1271 | |
// Shared doc fragment explaining how optional inputs/outputs are represented;
// appended to the documentation of ops that declare optional arguments.
inline std::string GenerateOptionalArgumentsDoc() {
  std::string doc =
      "This operator has **optional** inputs/outputs. "
      "See [the doc](IR.md) for more details about the representation of "
      "optional arguments. An empty string may be used in the place of "
      "an actual argument's name to indicate a missing argument. "
      "Trailing optional arguments (those not followed by an argument "
      "that is present) may also be simply omitted.\n";
  return doc;
}
1280 | |
// Shared doc fragment for operators with multidirectional (Numpy-style) broadcasting.
inline std::string GenerateBroadcastingDocMul() {
  std::string doc =
      "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**;"
      " for more details please check [the doc](Broadcasting.md).";
  return doc;
}
1285 | |
// Shared doc fragment for operators with unidirectional broadcasting, naming
// which operand (`from`) must be broadcastable to which (`to`).
inline std::string GenerateBroadcastingDocUni(const char* from, const char* to) {
  std::string doc = "This operator supports **unidirectional broadcasting** (";
  doc += from;
  doc += " should be unidirectional broadcastable to ";
  doc += to;
  doc +=
      ");"
      " for more details please check [the doc](Broadcasting.md).";
  return doc;
}
1293 | |
1294 | /* |
1295 | * Macros for setting operator documentation |
1296 | * Use this macro for simple SetDoc() calls that generate documentation |
1297 | * directly. This is the macro to use in almost all cases. |
1298 | * Sample usage guidelines: |
1299 | * const char* doc_str = "foo"; |
1300 | * SetDoc(GET_OP_DOC_STR(doc_str)) |
1301 | * |
1302 | * SetDoc(GET_OP_DOC_STR( |
1303 | std::string(BitShift_ver11_doc) + GenerateBroadcastingDocMul())) |
1304 | */ |
1305 | #ifndef __ONNX_NO_DOC_STRINGS |
1306 | #define GET_OP_DOC_STR(doc_str) (doc_str) |
1307 | #else |
1308 | #define GET_OP_DOC_STR(doc_str) ("") |
1309 | #endif |
1310 | |
1311 | /* |
1312 | * Use this macro when the documentation needs to be populated in some |
1313 | * complicated way like string substitutions, etc before calling SetDoc. |
1314 | * Sample usage guidelines: |
1315 | std::string doc; |
1316 | POPULATE_OP_DOC_STR( |
1317 | doc = R"DOC( |
1318 | Returns the tensor resulted from performing the `{name}` logical operation |
1319 | elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting |
1320 | support). |
1321 | |
1322 | {broadcast_doc} |
1323 | )DOC"; |
1324 | ReplaceAll(doc, "{name}", name); |
1325 | ReplaceAll( |
1326 | doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str());); |
1327 | schema.SetDoc(doc); |
1328 | * |
1329 | */ |
1330 | #ifndef __ONNX_NO_DOC_STRINGS |
1331 | #define POPULATE_OP_DOC_STR(DocPopulatorCode) \ |
1332 | do { \ |
1333 | DocPopulatorCode \ |
1334 | } while (0) |
1335 | #else |
1336 | #define POPULATE_OP_DOC_STR(DocPopulatorCode) |
1337 | #endif |
1338 | |
1339 | } // namespace ONNX_NAMESPACE |
1340 | |