1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#pragma once
6
#include <climits>
#include <cstdint>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <ostream>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "onnx/common/common.h"
#include "onnx/common/constants.h"
#include "onnx/defs/shape_inference.h"
25
26namespace ONNX_NAMESPACE {
27
28struct FunctionBodyBuildContext {
29 virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
30 virtual bool hasInput(int inputIndex) const = 0;
31 virtual bool hasOutput(int inputIndex) const = 0;
32 // getInputType(i) should return null for missing optional inputs, or if
33 // type-inference could not infer the input-type (erroneous model).
34 virtual const TypeProto* getInputType(int inputIndex) const = 0;
35 virtual ~FunctionBodyBuildContext() {}
36};
37
38struct FunctionBodyBuildContextImpl : public FunctionBodyBuildContext {
39 // Input_types: use a default TypeProto for missing types. We use a different convention
40 // here (from FunctionBodyBuildContext) to simplify python interoperability.
41 // The default value for input_types is included only for backward compatibility.
42 // It can be used for functions that do not depend on the type-context, but
43 // will not be sufficient for functions that do use the type-context.
44 FunctionBodyBuildContextImpl(const NodeProto& node_proto, const std::vector<TypeProto>& input_types = {})
45 : node_proto_(node_proto), input_types_(input_types) {
46 for (auto& attr : node_proto.attribute()) {
47 attributesByName_[attr.name()] = &attr;
48 }
49 }
50
51 const AttributeProto* getAttribute(const std::string& name) const override {
52 auto iter = attributesByName_.find(name);
53 if (iter == attributesByName_.end()) {
54 return nullptr;
55 } else {
56 return iter->second;
57 }
58 }
59
60 bool hasInput(int inputIndex) const override {
61 if (inputIndex >= node_proto_.input_size())
62 return false;
63 return node_proto_.input(inputIndex) != "";
64 }
65
66 bool hasOutput(int inputIndex) const override {
67 if (inputIndex >= node_proto_.output_size())
68 return false;
69 return node_proto_.output(inputIndex) != "";
70 }
71
72 const TypeProto* getInputType(int inputIndex) const override {
73 if (inputIndex < 0)
74 return nullptr;
75 size_t j = static_cast<size_t>(inputIndex);
76 if (j >= input_types_.size())
77 return nullptr;
78 // Convert default value (no variant set) into null.
79 if (input_types_[j].value_case() == TypeProto::ValueCase::VALUE_NOT_SET)
80 return nullptr;
81 return &input_types_[j];
82 }
83
84 std::unordered_map<std::string, const AttributeProto*> attributesByName_;
85
86 NodeProto node_proto_;
87 std::vector<TypeProto> input_types_;
88};
89
90using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>;
91
92class OpSchema;
93using ContextDependentFunctionBodyBuilder =
94 std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>;
95
// Exception describing a malformed or violated operator schema.
// Context may be attached after construction; what() then reports the
// original message followed by that context.
class SchemaError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  SchemaError(const std::string& message) : std::runtime_error(message) {}

  // Returns the message with attached context if any, otherwise the
  // message given at construction.
  const char* what() const noexcept override {
    const bool has_context = !expanded_message_.empty();
    return has_context ? expanded_message_.c_str() : std::runtime_error::what();
  }

  // Attaches `context` to the construction-time message. The text is always
  // rebuilt from the original message, so a later call replaces (rather than
  // stacks onto) an earlier context.
  void AppendContext(const std::string& context) {
    std::string combined(std::runtime_error::what());
    combined += "\n\n==> Context: ";
    combined += context;
    expanded_message_.swap(combined);
  }

 private:
  // Empty until AppendContext has been called.
  std::string expanded_message_;
};
116
117#define fail_schema(...) ONNX_THROW_EX(ONNX_NAMESPACE::SchemaError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
118
119using OperatorSetVersion = int;
120
121using DataTypeSet = std::unordered_set<DataType>;
122
123// Type constraint map. Key is type string. Value is data type set and
124// description.
125using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;
126
/**
 * @brief A class to record the schema of an op.
 *
 * OpSchema records the common interface of an op specified by its name.
 *
 * To register an OpSchema, one can use the macro ONNX_OPERATOR_SCHEMA(name) and
 * then append calls to the various setter functions of this class. For example,
 * an op that takes two inputs and produces one output can be written as
 *
 *   ONNX_OPERATOR_SCHEMA(name)
 *       .NumInputs({2}).NumOutputs({1});
 *
 * To manufacture methods that may be used to register an OpSchema
 * non-statically, the following may be used:
 *
 *   ONNX_OPERATOR_SET_SCHEMA(name, version, OpSchema()
 *       .NumInputs({2}).NumOutputs({1}));
 *
 * (Older revisions of this comment also chained `.AllowConsumed({{0, 0}})` to
 * mark the first input and output as in-place; no such setter is declared on
 * this class.)
 */
146class OpSchema final {
147 public:
148 static constexpr int kUninitializedSinceVersion = -1;
149 // Formal parameter options.
150 enum FormalParameterOption : uint8_t {
151 // The formal parameter is single and not optional.
152 // Number of supplied actual parameters must be 1.
153 Single = 0,
154 // The formal parameter is single and optional.
155 // Number of supplied actual parameters may be 0 or 1.
156 Optional = 1,
157 // The formal parameter is variadic.
158 // Number of supplied actual parameters must be N or more, where
159 // the minimum value N is indicated separately (default value 1).
160 Variadic = 2,
161 };
162 enum DifferentiationCategory : uint8_t {
163 // Whether this formal parameter is differentiable or not cannot
164 // be statically determined. It also covers variadic formal
165 // parameters which contain both of differentiable and
166 // non-differentiable variables.
167 Unknown = 0,
168 // This formal parameter is differentiable. That is, this formal
169 // parameter can be differentiable input of Gradient operator.
170 Differentiable = 1,
171 // This formal parameter is not differentiable. That is, this formal
172 // parameter can not be differentiable input of Gradient operator.
173 NonDifferentiable = 2
174 };
175
176 // Formal parameter represenation, including input/output name, typeStr,
177 // description, and type constraints.
178 class FormalParameter final {
179 public:
180 // Constructor.
181 FormalParameter() = default;
182
183 explicit FormalParameter(
184 std::string name,
185 DataTypeSet allowed_type_set,
186 std::string type_str,
187 const std::string& description,
188 FormalParameterOption param_option = Single,
189 bool is_homogeneous = true,
190 int min_arity = 1,
191 DifferentiationCategory differentiation_category = Unknown)
192 : name_(std::move(name)),
193 type_set_(std::move(allowed_type_set)),
194 type_str_(std::move(type_str)),
195#ifndef __ONNX_NO_DOC_STRINGS
196 description_(description),
197#endif
198 param_option_(param_option),
199 is_homogeneous_(is_homogeneous),
200 min_arity_(min_arity),
201 differentiation_category_(differentiation_category) {
202#ifdef __ONNX_NO_DOC_STRINGS
203 ONNX_UNUSED_PARAMETER(description);
204#endif
205 }
206
207 explicit FormalParameter(
208 std::string name,
209 const std::string& description,
210 std::string type_str,
211 FormalParameterOption param_option = Single,
212 bool is_homogeneous = true,
213 int min_arity = 1,
214 DifferentiationCategory differentiation_category = Unknown)
215 : name_(std::move(name)),
216 type_str_(std::move(type_str)),
217#ifndef __ONNX_NO_DOC_STRINGS
218 description_(description),
219#endif
220 param_option_(param_option),
221 is_homogeneous_(is_homogeneous),
222 min_arity_(min_arity),
223 differentiation_category_(differentiation_category) {
224#ifdef __ONNX_NO_DOC_STRINGS
225 ONNX_UNUSED_PARAMETER(description);
226#endif
227 }
228
229 // Get formal parameter name.
230 const std::string& GetName() const;
231
232 // Get allowed data types.
233 const DataTypeSet& GetTypes() const;
234
235 // Get formal parameter type string.
236 const std::string& GetTypeStr() const;
237
238 // Get formal parameter description.
239 const std::string& GetDescription() const;
240
241 // Get the parameter option, it could be Single, Optional or Variadic.
242 FormalParameterOption GetOption() const;
243
244 // Get whether a variadic parameter requires all to be of same type
245 bool GetIsHomogeneous() const;
246
247 // Get minimum arity. Applicable only in the Variadic case.
248 int GetMinArity() const;
249
250 // Get the differentiation property of this formal parameter.
251 DifferentiationCategory GetDifferentiationCategory() const;
252
253 private:
254 friend class OpSchema;
255
256 DataTypeSet& MutableTypes();
257
258 // Formal parameter name.
259 std::string name_;
260
261 // A set of data types supported for <*this> formal parameter.
262 // It should contain at least one element if this formal parameter is good.
263 DataTypeSet type_set_;
264
265 // The <parameter type> string specified when registring an op.
266 // It could be a supported data type or a type constraint key, which
267 // maps to a set of supported data types.
268 std::string type_str_;
269
270 // Formal parameter description.
271 std::string description_;
272
273 // Formal parameter option.
274 FormalParameterOption param_option_;
275
276 // For variadic parameters, a flag indicating if all parameters must be of
277 // same type
278 bool is_homogeneous_;
279
280 // Minimum number of parameters expected. Applicable only for Variadic.
281 int min_arity_;
282
283 // True if this parameter can be an differentiable inputs of Gradient.
284 // Otherwise, using this parameter as an differentiable inputs of Gradient
285 // is prohibited.
286 DifferentiationCategory differentiation_category_;
287 };
288
289 enum class SupportType : uint8_t {
290 COMMON, // Supported by all frameworks that support this IR.
291 EXPERIMENTAL, // This OP is experimental and can be changed or removed in
292 // the future.
293 };
294
295 OpSchema() : OpSchema("unknown", "unknown", 0) {}
296 OpSchema(std::string name, std::string file, int line)
297 : name_(std::move(name)), file_(std::move(file)), line_(line), support_(SupportType::COMMON) {}
298
299 /**
300 * @brief Returns the file that the op schema is registered from.
301 */
302 const std::string& file() const {
303 return file_;
304 }
305
306 /**
307 * @brief Returns the line in file that the op schema is registered from.
308 */
309 int line() const {
310 return line_;
311 }
312
313 /**
314 * @brief Returns the support level of the op schema.
315 */
316 SupportType support_level() const {
317 return support_;
318 }
319
320 /**
321 * @brief Returns the docstring of the op schema.
322 */
323 const char* doc() const {
324 return doc_.empty() ? nullptr : doc_.c_str();
325 }
326
327 // Check if input and output types fall into valid set and match each other
328 void CheckInputOutputType(struct InferenceContext&) const;
329
330 /**
331 * @brief Verifies if a NodeProto matches the pattern specified in
332 * the schema.
333 */
334 void Verify(const NodeProto& node) const;
335
336 // Functions to set the property of the operator schemas.
337 // Sets the number of inputs, either a fixed number or a min and a max.
338
339 /**
340 * The earliest operator set version which this operator was
341 * present in. If an operator has had no BC-breaking changes,
342 * this is simply the first operator set the operator was a member
343 * of; if it has had BC-breaking changes, then for the semantics
344 * /as described/ in the OpSchema entry, this version describes
345 * the operator set which introduced the BC-breaking change.
346 *
347 * For example, suppose op Foo was added in v3, and had a BC-breaking
348 * change in v6. Then there will be an op schema entry for Foo with
349 * SinceVersion(3), and another, updated op schema entry for Foo
350 * with SinceVersion(6).
351 */
352 OpSchema& SinceVersion(OperatorSetVersion n); // aka int
353
354 /**
355 * Marks this op as deprecated as of it's since_version. This will cause the
356 * Schema() lookup functions to return nullptr when the version is in the
357 * deprecated range.
358 */
359 OpSchema& Deprecate();
360
361 bool Deprecated() const {
362 return deprecated_;
363 }
364
365 /**
366 * @brief Input could be one of the values specified in allowed_input_nums.
367 */
368 OpSchema& NumInputs(std::set<int> allowed_input_nums);
369
370 /**
371 * @brief Output could be one of the values specified in allowed_output_nums.
372 */
373 OpSchema& NumOutputs(std::set<int> allowed_output_nums);
374
375 // Shape Inference
376 //
377 // Note that signatures are defined to allow for forward-declaring
378 // any structs used from ir.h
379 OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
380 InferenceFunction GetTypeAndShapeInferenceFunction() const {
381 return tensor_inference_function_ ? tensor_inference_function_ : dummyInferenceFunction;
382 }
383
384 OpSchema& PartialDataPropagationFunction(DataPropagationFunction dataProgationFunction);
385 DataPropagationFunction GetDataPropagationFunction() const {
386 return data_propagation_function_ ? data_propagation_function_ : dummyDataPropagationFunction;
387 }
388
389 // Set the support level for the op schema.
390 OpSchema& SetSupportLevel(SupportType supportType);
391
392 // Functions to do documentation for the operator schema.
393 // This may be disabled to save memory.
394 OpSchema& SetDoc(const char* doc) {
395#ifndef __ONNX_NO_DOC_STRINGS
396 SetDoc(std::string(doc));
397#else
398 ONNX_UNUSED_PARAMETER(doc);
399#endif
400
401 return *this;
402 }
403
404 OpSchema& SetDoc(const std::string& doc) {
405#ifndef __ONNX_NO_DOC_STRINGS
406 doc_ = doc;
407#else
408 ONNX_UNUSED_PARAMETER(doc);
409#endif
410 return *this;
411 }
412
413 // Functions to specify name for the operator schema.
414 OpSchema& SetName(const char* name);
415 OpSchema& SetName(std::string name);
416
417 // Functions to specify code location for the operator schema.
418 OpSchema& SetLocation(const char* file, int line);
419 OpSchema& SetLocation(std::string file, int line);
420
421 // Functions to specify domain for the operator schema.
422 // Default domain value (ONNX_DOMAIN) means it's ONNX domain.
423 OpSchema& SetDomain(const char* domain);
424 OpSchema& SetDomain(std::string domain);
425
426 struct Attribute final {
427 Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)
428 : name(std::move(name_)),
429 description(std::move(description_)),
430 type(type_),
431 required(required_),
432 default_value() {}
433
434 Attribute(std::string name_, std::string description_, AttributeProto default_value_)
435 : name(std::move(name_)),
436 description(std::move(description_)),
437 type(default_value_.type()),
438 required(false),
439 default_value(std::move(default_value_)) {}
440
441 const std::string name;
442 const std::string description;
443 AttributeProto::AttributeType type;
444 bool required;
445 AttributeProto default_value;
446 };
447
448 OpSchema& Attr(Attribute attr);
449
450// Register "optional" attribute with default value.
451#define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName) \
452 OpSchema& Attr( \
453 std::string name, std::string description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
454 /* non-STL wrapper to reduce binary size */ \
455 OpSchema& Attr( \
456 const char* name, const char* description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
457 OpSchema& Attr( \
458 std::string name, \
459 std::string description, \
460 AttributeProto::AttributeType type, \
461 const std::vector<TypeName>& defaultValue);
462
463 ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t)
464 ATTR_SETTER_WITH_DEFAULT_VALUE(float)
465 ATTR_SETTER_WITH_DEFAULT_VALUE(std::string)
466 ATTR_SETTER_WITH_DEFAULT_VALUE(TensorProto)
467 ATTR_SETTER_WITH_DEFAULT_VALUE(GraphProto)
468 ATTR_SETTER_WITH_DEFAULT_VALUE(TypeProto)
469
470 // Register "required" attribute without default value.
471 OpSchema& Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true);
472
473 // Non-STL wrapper to reduce binary size
474 OpSchema& Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required = true);
475
476 OpSchema& AllowUncheckedAttributes();
477
478 // Type constraint.
479 struct TypeConstraintParam final {
480 TypeConstraintParam(
481 std::string type_param_str_,
482 std::vector<std::string> allowed_type_strs_,
483 std::string description_)
484 : type_param_str(std::move(type_param_str_)),
485 allowed_type_strs(std::move(allowed_type_strs_)),
486 description(std::move(description_)) {}
487
488 // Type parameter string, for example, "T", "T1", etc.
489 std::string type_param_str;
490 // Allowed type strings for <*this> type parameter, for example,
491 // "tensor(float)".
492 std::vector<std::string> allowed_type_strs;
493 // Type parameter description.
494 std::string description;
495 };
496
497 // Grammar for type strings used in Input(), Output().
498 // <type> ::= <data_type> |
499 // tensor(<data_type>) |
500 // seq(<type>) |
501 // map(<data_type>, <type>) |
502 // <type_parameter>
503 // <data_type> :: = float | int32 | string | bool | uint8
504 // | int8 | uint16 | int16 | int64 | float16 | double
505 // <type_parameter> ::= any type parameter string, say "T".
506 //
507 // NOTE: 1) <type_parameter> will always be together with a type constraints
508 // specification.
509 // 2) <type> ::= <data_type> means the data is scalar (zero dimension).
510 //
511 // Example:
512 // ONNX_OPERATOR_SET_SCHEMA(Sum, 1, OpSchema()
513 // .Input(0, "input_a", "the first input", "T")
514 // .Input(1, "input_b", "the second input", "T")
515 // .Output(0, "sum", "the sum of two numbers", "T")
516 // .TypeConstraint("T", {"float", "double", "int32"}, "allowed data types for
517 // sum."))
518 //
519 // Optional = true means that the input might have empty input value
520 // (represented as "") in the graph even though the later inputs have values.
521 // It's useful for complex situation when there are several independent
522 // optional inputs.
523 OpSchema& Input(
524 int n,
525 std::string name,
526 const std::string& description,
527 std::string type_str,
528 FormalParameterOption param_option = Single,
529 bool is_homogeneous = true,
530 int min_arity = 1,
531 DifferentiationCategory differentiation_category = Unknown);
532
533 // Non-STL wrapper to reduce binary size
534 OpSchema& Input(
535 int n,
536 const char* name,
537 const char* description,
538 const char* type_str,
539 FormalParameterOption param_option = Single,
540 bool is_homogeneous = true,
541 int min_arity = 1,
542 DifferentiationCategory differentiation_category = Unknown);
543
544 OpSchema& Output(
545 int n,
546 std::string name,
547 const std::string& description,
548 std::string type_str,
549 FormalParameterOption param_option = Single,
550 bool is_homogeneous = true,
551 int min_arity = 1,
552 DifferentiationCategory differentiation_category = Unknown);
553
554 // Non-STL wrapper to reduce binary size
555 OpSchema& Output(
556 int n,
557 const char* name,
558 const char* description,
559 const char* type_str,
560 FormalParameterOption param_option = Single,
561 bool is_homogeneous = true,
562 int min_arity = 1,
563 DifferentiationCategory differentiation_category = Unknown);
564
565 OpSchema& TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description);
566
567 // Non-STL wrapper to reduce binary size
568 OpSchema&
569 TypeConstraint(const char* type_str, std::initializer_list<const char*> constraints, const char* description);
570
571 // Convenience members for types
572
573 // All high-precision numeric types.
574 static const std::vector<std::string>& numeric_types_for_math_reduction_with_bfloat() {
575 static const std::vector<std::string> numeric_types_for_math_reduction_with_bfloat = {
576 "tensor(uint32)",
577 "tensor(uint64)",
578 "tensor(int32)",
579 "tensor(int64)",
580 "tensor(float16)",
581 "tensor(float)",
582 "tensor(double)",
583 "tensor(bfloat16)"};
584 return numeric_types_for_math_reduction_with_bfloat;
585 }
586
587 static const std::vector<std::string>& numeric_types_for_math_reduction() {
588 static const std::vector<std::string> numeric_types_for_math_reduction = {
589 "tensor(uint32)",
590 "tensor(uint64)",
591 "tensor(int32)",
592 "tensor(int64)",
593 "tensor(float16)",
594 "tensor(float)",
595 "tensor(double)"};
596 return numeric_types_for_math_reduction;
597 }
598
599 static const std::vector<std::string>& all_numeric_types_with_bfloat() {
600 static const std::vector<std::string> all_numeric_types_with_bfloat = {
601 "tensor(uint8)",
602 "tensor(uint16)",
603 "tensor(uint32)",
604 "tensor(uint64)",
605 "tensor(int8)",
606 "tensor(int16)",
607 "tensor(int32)",
608 "tensor(int64)",
609 "tensor(float16)",
610 "tensor(float)",
611 "tensor(double)",
612 "tensor(bfloat16)"};
613 return all_numeric_types_with_bfloat;
614 }
615
616 static const std::vector<std::string>& all_numeric_types() {
617 static const std::vector<std::string> all_numeric_types = {
618 "tensor(uint8)",
619 "tensor(uint16)",
620 "tensor(uint32)",
621 "tensor(uint64)",
622 "tensor(int8)",
623 "tensor(int16)",
624 "tensor(int32)",
625 "tensor(int64)",
626 "tensor(float16)",
627 "tensor(float)",
628 "tensor(double)"};
629 return all_numeric_types;
630 }
631
632 static const std::vector<std::string>& all_numeric_sequence_types() {
633 static const std::vector<std::string> all_numeric_sequence_types = {
634 "seq(tensor(uint8))",
635 "seq(tensor(uint16))",
636 "seq(tensor(uint32))",
637 "seq(tensor(uint64))",
638 "seq(tensor(int8))",
639 "seq(tensor(int16))",
640 "seq(tensor(int32))",
641 "seq(tensor(int64))",
642 "seq(tensor(float16))",
643 "seq(tensor(float))",
644 "seq(tensor(double))"};
645 return all_numeric_sequence_types;
646 }
647
648 static const std::vector<std::string>& all_tensor_types() {
649 static const std::vector<std::string> all_tensor_types = {
650 "tensor(uint8)",
651 "tensor(uint16)",
652 "tensor(uint32)",
653 "tensor(uint64)",
654 "tensor(int8)",
655 "tensor(int16)",
656 "tensor(int32)",
657 "tensor(int64)",
658 "tensor(float16)",
659 "tensor(float)",
660 "tensor(double)",
661 "tensor(string)",
662 "tensor(bool)",
663 "tensor(complex64)",
664 "tensor(complex128)"};
665 return all_tensor_types;
666 }
667
668 static const std::vector<std::string>& all_tensor_types_with_bfloat() {
669 static const std::vector<std::string> all_tensor_types_with_bfloat = {
670 "tensor(uint8)",
671 "tensor(uint16)",
672 "tensor(uint32)",
673 "tensor(uint64)",
674 "tensor(int8)",
675 "tensor(int16)",
676 "tensor(int32)",
677 "tensor(int64)",
678 "tensor(bfloat16)",
679 "tensor(float16)",
680 "tensor(float)",
681 "tensor(double)",
682 "tensor(string)",
683 "tensor(bool)",
684 "tensor(complex64)",
685 "tensor(complex128)"};
686 return all_tensor_types_with_bfloat;
687 }
688
689 static const std::vector<std::string>& all_tensor_sequence_types() {
690 static const std::vector<std::string> all_tensor_sequence_types = {
691 "seq(tensor(uint8))",
692 "seq(tensor(uint16))",
693 "seq(tensor(uint32))",
694 "seq(tensor(uint64))",
695 "seq(tensor(int8))",
696 "seq(tensor(int16))",
697 "seq(tensor(int32))",
698 "seq(tensor(int64))",
699 "seq(tensor(float16))",
700 "seq(tensor(float))",
701 "seq(tensor(double))",
702 "seq(tensor(string))",
703 "seq(tensor(bool))",
704 "seq(tensor(complex64))",
705 "seq(tensor(complex128))"};
706 return all_tensor_sequence_types;
707 }
708
709 static const std::vector<std::string>& all_tensor_sequence_types_with_bfloat() {
710 static const std::vector<std::string> all_tensor_sequence_types_with_bfloat = {
711 "seq(tensor(uint8))",
712 "seq(tensor(uint16))",
713 "seq(tensor(uint32))",
714 "seq(tensor(uint64))",
715 "seq(tensor(int8))",
716 "seq(tensor(int16))",
717 "seq(tensor(int32))",
718 "seq(tensor(int64))",
719 "seq(tensor(bfloat16))",
720 "seq(tensor(float16))",
721 "seq(tensor(float))",
722 "seq(tensor(double))",
723 "seq(tensor(string))",
724 "seq(tensor(bool))",
725 "seq(tensor(complex64))",
726 "seq(tensor(complex128))"};
727 return all_tensor_sequence_types_with_bfloat;
728 }
729
730 static const std::vector<std::string>& all_optional_types() {
731 static const std::vector<std::string> all_optional_types = {
732 "optional(seq(tensor(uint8)))", "optional(seq(tensor(uint16)))", "optional(seq(tensor(uint32)))",
733 "optional(seq(tensor(uint64)))", "optional(seq(tensor(int8)))", "optional(seq(tensor(int16)))",
734 "optional(seq(tensor(int32)))", "optional(seq(tensor(int64)))", "optional(seq(tensor(float16)))",
735 "optional(seq(tensor(float)))", "optional(seq(tensor(double)))", "optional(seq(tensor(string)))",
736 "optional(seq(tensor(bool)))", "optional(seq(tensor(complex64)))", "optional(seq(tensor(complex128)))",
737 "optional(tensor(uint8))", "optional(tensor(uint16))", "optional(tensor(uint32))",
738 "optional(tensor(uint64))", "optional(tensor(int8))", "optional(tensor(int16))",
739 "optional(tensor(int32))", "optional(tensor(int64))", "optional(tensor(float16))",
740 "optional(tensor(float))", "optional(tensor(double))", "optional(tensor(string))",
741 "optional(tensor(bool))", "optional(tensor(complex64))", "optional(tensor(complex128))"};
742 return all_optional_types;
743 }
744
745 static const std::vector<std::string>& all_optional_types_with_bfloat() {
746 static const std::vector<std::string> all_optional_types = {
747 "optional(seq(tensor(uint8)))", "optional(seq(tensor(uint16)))", "optional(seq(tensor(uint32)))",
748 "optional(seq(tensor(uint64)))", "optional(seq(tensor(int8)))", "optional(seq(tensor(int16)))",
749 "optional(seq(tensor(int32)))", "optional(seq(tensor(int64)))", "optional(seq(tensor(bfloat16)))",
750 "optional(seq(tensor(float16)))", "optional(seq(tensor(float)))", "optional(seq(tensor(double)))",
751 "optional(seq(tensor(string)))", "optional(seq(tensor(bool)))", "optional(seq(tensor(complex64)))",
752 "optional(seq(tensor(complex128)))", "optional(tensor(uint8))", "optional(tensor(uint16))",
753 "optional(tensor(uint32))", "optional(tensor(uint64))", "optional(tensor(int8))",
754 "optional(tensor(int16))", "optional(tensor(int32))", "optional(tensor(int64))",
755 "optional(tensor(bfloat16))", "optional(tensor(float16))", "optional(tensor(float))",
756 "optional(tensor(double))", "optional(tensor(string))", "optional(tensor(bool))",
757 "optional(tensor(complex64))", "optional(tensor(complex128))"};
758 return all_optional_types;
759 }
760
761 // Calls the passed function with `this` as an argument. Useful for
762 // adding docs for temlated/macro ops.
763 OpSchema& FillUsing(const std::function<void(OpSchema&)>& populator);
764
765 friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);
766
767 const std::string& domain() const {
768 return domain_;
769 }
770
771 const std::map<std::string, Attribute>& attributes() const {
772 return attributes_;
773 }
774
775 // Get input formal parameters.
776 const std::vector<FormalParameter>& inputs() const {
777 return inputs_;
778 }
779
780 // Get output formal parameters.
781 const std::vector<FormalParameter>& outputs() const {
782 return outputs_;
783 }
784
785 const std::vector<TypeConstraintParam>& typeConstraintParams() const {
786 return type_constraint_params_;
787 }
788
789 const std::string& Name() const {
790 return name_;
791 }
792
793 OperatorSetVersion SinceVersion() const {
794 return since_version_;
795 }
796
797 int since_version() const {
798 return since_version_;
799 }
800
801 bool deprecated() const {
802 return deprecated_;
803 }
804
805 int min_input() const {
806 return min_input_;
807 }
808 int max_input() const {
809 return max_input_;
810 }
811 int min_output() const {
812 return min_output_;
813 }
814 int max_output() const {
815 return max_output_;
816 }
817
818 bool has_type_and_shape_inference_function() const {
819 return tensor_inference_function_ ? true : false;
820 }
821
822 bool has_data_propagation_function() const {
823 return data_propagation_function_ ? true : false;
824 }
825
826 std::vector<int> function_opset_versions() const {
827 std::vector<int> opset_versions;
828 std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it = opset_version_to_function_body_.cbegin();
829 for (; it != opset_version_to_function_body_.cend(); ++it) {
830 opset_versions.push_back(it->first);
831 }
832 return opset_versions;
833 }
834
835 bool HasFunction() const {
836 return !opset_version_to_function_body_.empty();
837 }
838
839 OpSchema& FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version = kUninitializedSinceVersion);
840
841 OpSchema& FunctionBody(
842 const std::vector<NodeProto>& func_nodes,
843 const std::vector<OperatorSetIdProto>& opsets,
844 int opset_version = kUninitializedSinceVersion);
845
846 OpSchema& FunctionBody(const char* func_body, int opset_version = kUninitializedSinceVersion);
847
848 // since_version_ of an OpSchema tells the last opset version when an op is defined.
849 // When the op's definition is changed, a new OpSchema (of the same op_type) is created
850 // with a newer since_version_, reflecting the opset version at the time of change.
851 // For a function op, operators used to define its function body may change
852 // while there is no change to the function op definition itself.
853 // When this happens, mutiple function bodies are provided, each for a specific opset version.
854 //
855 // Take LogSoftmax for example. Its latest opset version is 13.
856 // In LogSoftmax's function body, ReduceMax (with since_version_ 1, 11, 12, 18) is used.
857 // When a model containing LogSoftmax with opset_import version within 13 to 17 is loaded, function body
858 // with opset_version 13 is used for inlining.
859 // When the same model but opset_import version 18 is loaded, function body
860 // with opset_version 18 is used for inlining.
861 // Clearly function body for opset_import version 13 will not work
862 // in a model with opset_import version 18 because the function body make worng use of ReduceMax(18).
863 // Inside GetFunction we ensure that ops being used to construct a function body do not endure such
864 // issue.
865 const FunctionProto* GetFunction(
866 int requested_opset_version = OpSchema::kUninitializedSinceVersion,
867 bool validate = false) const;
868
869 std::vector<int> context_dependent_function_opset_versions() const {
870 std::vector<int> opset_versions;
871 std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it = opset_version_to_function_builder_.cbegin();
872 for (; it != opset_version_to_function_builder_.cend(); ++it) {
873 opset_versions.push_back(it->first);
874 }
875 return opset_versions;
876 }
877
878 bool HasContextDependentFunction() const {
879 return !opset_version_to_function_builder_.empty();
880 }
881
882 bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const {
883 return opset_version_to_function_builder_.find(opset_version) != opset_version_to_function_builder_.end();
884 }
885
886 OpSchema& SetContextDependentFunctionBodyBuilder(
887 ContextDependentFunctionBodyBuilder,
888 int opset_version = kUninitializedSinceVersion);
889
890 bool BuildContextDependentFunction(
891 const FunctionBodyBuildContext& ctx,
892 FunctionProto& function_proto,
893 int requested_opset_version = OpSchema::kUninitializedSinceVersion) const;
894
895 // Verifies that the schema is valid and all specifications are compatible.
896 // It will also parse all type strings specified for inputs/outputs into valid
897 // TypeProto and create global unique string pointer as the DataType for
898 // efficiency.
899 void Finalize();
900
901 // Build function with information stored in opschema
902 void BuildFunction(FunctionProto& function_body) const;
903
 private:
  // NOTE(review): presumably the helper behind Finalize()'s "parse all type
  // strings into TypeProto" step for the given formal parameters; confirm.
  void ParseAndSetTypes(
      /*out*/ std::vector<OpSchema::FormalParameter>* formalParameters);
  // NOTE(review): presumably checks that the ops referenced by `function` are
  // compatible with requested_opset_version (the GetFunction concern), optionally
  // collecting op names into updated_ops. The "Functon" spelling is part of the
  // declared name and must stay in sync with the out-of-line definition.
  bool ValidateReferencedOpsInFuncton(
      const FunctionProto* function,
      int requested_opset_version,
      int function_since_version,
      std::set<std::string>* updated_ops = nullptr) const;
  // NOTE(review): presumably rewrites function_proto's opset_import entry to
  // opset_version; confirm against the definition.
  void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;

  std::string name_;
  // Source file of the schema definition; reported by OpSchemaRegistry when a
  // registration fails (through the file() accessor, presumably backed by this).
  std::string file_;
  std::string doc_;
  // Default domain value ("") means it's ONNX domain.
  std::string domain_ = ONNX_DOMAIN;
  std::map<std::string, Attribute> attributes_{};
  bool allows_unchecked_attributes_ = false;
  std::vector<FormalParameter> inputs_;
  std::vector<FormalParameter> outputs_;
  std::vector<TypeConstraintParam> type_constraint_params_;
  TypeConstraintMap type_constraints_;
  // Source line of the schema definition (companion of file_).
  int line_ = 0;
  // NOTE(review): no default member initializer; presumably set by every
  // constructor — confirm.
  SupportType support_;
  int min_input_ = 0;
  int max_input_ = 0;
  int min_output_ = 0;
  int max_output_ = 0;
  // The default is a little goofy, since it is never what you want
  OperatorSetVersion since_version_ = kUninitializedSinceVersion;
  bool deprecated_{};
  // By default every input/output count is accepted.
  std::function<bool(int)> num_inputs_allowed_ = [](int) { return true; };
  std::function<bool(int)> num_outputs_allowed_ = [](int) { return true; };
  InferenceFunction tensor_inference_function_;
  DataPropagationFunction data_propagation_function_;

  // Static function bodies and context-dependent body builders, each keyed by
  // opset version (std::map, so iteration is in ascending version order).
  std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_;
  std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_;
941};
942
// Map type to store operator schemas. The format is,
// <OpName, <Domain, <OperatorSetVersion, OpSchema>>>.
// The innermost level is a std::map, so schemas for one (name, domain) pair
// are ordered by ascending since-version and rbegin() is the newest one.
using OpName_Domain_Version_Schema_Map =
    std::unordered_map<std::string, std::unordered_map<std::string, std::map<OperatorSetVersion, OpSchema>>>;
947
// Abstract lookup interface for operator schemas. OpSchemaRegistry (below) is
// one implementation.
class ISchemaRegistry {
 public:
  virtual ~ISchemaRegistry() = default;

  // Looks up the schema of operator `key` in `domain` for opset
  // maxInclusiveVersion. OpSchemaRegistry returns the schema with the largest
  // version not exceeding maxInclusiveVersion, or nullptr if there is none.
  // Default arguments on virtual functions bind to the static type of the
  // caller, so overrides should keep the same ONNX_DOMAIN default.
  virtual const OpSchema*
  GetSchema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) const = 0;
};
955
956/**
957 * @brief A registry to hold all the operator schemas.
958 */
959class OpSchemaRegistry final : public ISchemaRegistry {
960 public:
961 // A singleton class to store domain to min/max op_set version map, as well as
962 // domain to last-release op_set version map.
963 class DomainToVersionRange final {
964 public:
965 DomainToVersionRange() {
966 // Increase the highest version when you make BC-breaking changes to the
967 // operator schema on specific domain. Update the lowest version when it's
968 // determined to remove too old version history.
969 map_[ONNX_DOMAIN] = std::make_pair(1, 18);
970 map_[AI_ONNX_ML_DOMAIN] = std::make_pair(1, 3);
971 map_[AI_ONNX_TRAINING_DOMAIN] = std::make_pair(1, 1);
972 // ONNX's preview domain contains operators subject to change, so
973 // versining is not meaningful and that domain should have only one
974 // version.
975 map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1);
976 // Version corresponding last release of ONNX. Update this to match with
977 // the max version above in a *release* version of ONNX. But in other
978 // versions, the max version may be ahead of the last-release-version.
979 last_release_version_map_[ONNX_DOMAIN] = 18;
980 last_release_version_map_[AI_ONNX_ML_DOMAIN] = 3;
981 last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
982 last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
983 }
984
985 const std::unordered_map<std::string, std::pair<int, int>>& Map() const {
986 return map_;
987 }
988
989 const std::unordered_map<std::string, int>& LastReleaseVersionMap() const {
990 return last_release_version_map_;
991 }
992
993 // Add customized domain to min/max version.
994 // Onnx partners are able to use onnx operator schema api to
995 // register customized op in their own domain.
996 // Can optionally specify last_release_version (to make it similar to
997 // standard ONNX domains as above). Custom-domains are free to interpret
998 // this as appropriate (that is, as relative to releases of custom-domain
999 // as opposed to ONNX releases).
1000 void
1001 AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
1002 std::lock_guard<std::mutex> lock(mutex_);
1003 assert(map_.end() == map_.find(domain));
1004 map_[domain] = std::make_pair(min_version, max_version);
1005 // If a last-release-version is not explicitly specified, use max as
1006 // last-release-version.
1007 if (last_release_version == -1)
1008 last_release_version = max_version;
1009 assert(last_release_version_map_.end() == last_release_version_map_.find(domain));
1010 last_release_version_map_[domain] = last_release_version;
1011 }
1012
1013 static DomainToVersionRange& Instance();
1014
1015 private:
1016 // Key: domain. Value: <lowest version, highest version> pair.
1017 std::unordered_map<std::string, std::pair<int, int>> map_;
1018
1019 // Key: domain. Value: most recent release opset version. Note that
1020 // the highest opset version may be ahead of the most recent release's opset
1021 // version.
1022 std::unordered_map<std::string, int> last_release_version_map_;
1023
1024 std::mutex mutex_;
1025 };
1026
1027 class OpSchemaRegisterOnce final {
1028 public:
1029 OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0) {
1030 ONNX_TRY {
1031 op_schema.Finalize();
1032 auto& m = GetMapWithoutEnsuringRegistration();
1033 auto& op_name = op_schema.Name();
1034 auto& op_domain = op_schema.domain();
1035 auto ver = op_schema.SinceVersion();
1036 if (OpSchema::kUninitializedSinceVersion == ver) {
1037 op_schema.SinceVersion(1);
1038 ver = op_schema.SinceVersion();
1039 }
1040 // Stops because the opset_version is higher than opset_version_to_load
1041 if (opset_version_to_load != 0 && ver > opset_version_to_load) {
1042 return;
1043 }
1044 if (m[op_name][op_domain].count(ver)) {
1045 const auto& schema = m[op_name][op_domain][ver];
1046 std::stringstream err;
1047 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1048 << ") from file " << op_schema.file() << " line " << op_schema.line()
1049 << ", but it is already registered from file " << schema.file() << " line " << schema.line() << std::endl;
1050 fail_schema(err.str());
1051 }
1052 // Return early if schema for the targeted opset version has already been loaded
1053 if (opset_version_to_load != 0 && !m[op_name][op_domain].empty()) {
1054 return;
1055 }
1056 auto ver_range_map = DomainToVersionRange::Instance().Map();
1057 auto ver_range_it = ver_range_map.find(op_domain);
1058 if (ver_range_it == ver_range_map.end()) {
1059 std::stringstream err;
1060 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1061 << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its domain is not"
1062 << " known by the checker." << std::endl;
1063
1064 fail_schema(err.str());
1065 }
1066 auto lower_bound_incl = ver_range_it->second.first;
1067 auto upper_bound_incl = ver_range_it->second.second;
1068 if (!(lower_bound_incl <= ver && upper_bound_incl >= ver)) {
1069 std::stringstream err;
1070 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1071 << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its version is not "
1072 << "in the inclusive range [" << lower_bound_incl << ", " << upper_bound_incl
1073 << "] (usually, this means you "
1074 << "bumped the operator version but "
1075 << "forgot to update the version range in DomainToVersionRange "
1076 << "in onnx/defs/schema.h)." << std::endl;
1077 fail_schema(err.str());
1078 }
1079
1080 m[op_name][op_domain].insert(std::pair<int, OpSchema&&>(ver, std::move(op_schema)));
1081 }
1082 ONNX_CATCH(const std::exception& e) {
1083 ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; });
1084 }
1085 }
1086 };
1087
1088 // Return the latest schema for an operator in specified domain.
1089 // Domain with default value ONNX_DOMAIN means ONNX.
1090 static const OpSchema* Schema(const std::string& key, const std::string& domain = ONNX_DOMAIN) {
1091 auto& m = map();
1092 if (m.count(key) && m[key].count(domain)) {
1093 return &m[key][domain].rbegin()->second;
1094 } else {
1095 return nullptr;
1096 }
1097 }
1098
1099 // Return the schema with biggest version, which is not greater than specified
1100 // <maxInclusiveVersion> in specified domain. Domain with default value
1101 // ONNX_DOMAIN means ONNX.
1102 static const OpSchema*
1103 Schema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) {
1104 auto& m = map();
1105 if (m.count(key) && m[key].count(domain)) {
1106 auto pos = m[key][domain].lower_bound(maxInclusiveVersion);
1107 if (m[key][domain].begin() == pos && pos->first > maxInclusiveVersion) {
1108 // All versions are greater than specified version.
1109 return nullptr;
1110 }
1111 if (m[key][domain].end() == pos || pos->first > maxInclusiveVersion) {
1112 // All versions are less than specified version, or,
1113 // The <pos> version is greater than specified version.
1114 pos--;
1115 }
1116
1117 // Schema with exact version as specified one exists.
1118 return &(pos->second);
1119 } else {
1120 return nullptr;
1121 }
1122 }
1123
1124 static OpSchemaRegistry* Instance();
1125
1126 const OpSchema* GetSchema(
1127 const std::string& key,
1128 const int maxInclusiveVersion,
1129 const std::string& domain = ONNX_DOMAIN) const override {
1130 return Schema(key, maxInclusiveVersion, domain);
1131 }
1132 static void SetLoadedSchemaVersion(int target_version) {
1133 loaded_schema_version = target_version;
1134 }
1135 static int GetLoadedSchemaVersion() {
1136 return loaded_schema_version;
1137 }
1138
1139 private:
1140 // OpSchemaRegistry should not need to be instantiated except statically
1141 // within this class
1142 OpSchemaRegistry() = default;
1143
1144 /**
1145 * @brief Returns the underlying string to OpSchema map.
1146 *
1147 * You should not manually manipulate the map object returned. Instead, use
1148 * the macros defined such as ONNX_OPERATOR_SET_SCHEMA to register your
1149 * operator schema.
1150 *
1151 * We wrap it inside a function to avoid the static initialization order
1152 * fiasco.
1153 */
1154 static OpName_Domain_Version_Schema_Map& GetMapWithoutEnsuringRegistration();
1155 static OpName_Domain_Version_Schema_Map& map();
1156 static int loaded_schema_version;
1157
1158 public:
1159 static const std::vector<OpSchema> get_all_schemas_with_history() {
1160 std::vector<OpSchema> r;
1161 for (auto& x : map()) {
1162 for (auto& y : x.second) {
1163 for (auto& z : y.second) {
1164 r.emplace_back(z.second);
1165 }
1166 }
1167 }
1168 return r;
1169 }
1170
1171 static const std::vector<OpSchema> get_all_schemas() {
1172 std::vector<OpSchema> r;
1173 for (auto& x : map()) {
1174 for (auto& y : x.second) {
1175 auto& version2schema = y.second;
1176 r.emplace_back(version2schema.rbegin()->second);
1177 }
1178 }
1179 return r;
1180 }
1181};
1182
// Registers `schema` (taken by value) with the schema registry; defined out of line.
// NOTE(review): opset_version_to_load presumably has the same meaning as in
// OpSchemaRegistry::OpSchemaRegisterOnce (0 = register every version); confirm.
void RegisterSchema(OpSchema schema, int opset_version_to_load = 0);
1184
1185// Registers the latest opset schema before opset_version_to_load
1186// By default opset_version_to_load=0 means it will register all versions
1187template <class T>
1188void RegisterOpSetSchema(int opset_version_to_load = 0) {
1189 T::ForEachSchema([opset_version_to_load](OpSchema&& schema) { RegisterSchema(schema, opset_version_to_load); });
1190};
1191
// Forward declaration for the non-specialized GetOpSchema method. This
// enforces a consistent signature on functions that query individual schema,
// which are defined as specializations of this function.
// Specializations are emitted by ONNX_OPERATOR_SET_SCHEMA_EX, keyed by a tag
// class named via ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME. The primary template is
// not defined in this header.
template <typename T>
OpSchema GetOpSchema();
1197
// Per-domain wrappers around ONNX_OPERATOR_SET_SCHEMA_EX. Each fixes the
// class-name tag (Onnx, OnnxML, OnnxTraining, OnnxPreview) and the domain
// string, and always passes dbg_included_in_static_opset = true.
#define ONNX_OPERATOR_SET_SCHEMA(name, ver, impl) ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ONNX_DOMAIN, ver, true, impl)

#define ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxML, AI_ONNX_ML_DOMAIN, ver, true, impl)

#define ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxTraining, AI_ONNX_TRAINING_DOMAIN, ver, true, impl)

#define ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxPreview, AI_ONNX_PREVIEW_TRAINING_DOMAIN, ver, true, impl)
1208
// Defines specialization of GetOpSchema for a class whose name is determined
// based on a convention using name, domain, and version. Operator schema are
// normally included in operator sets and registered in OpSchemaRegistry::map().
// In this case, callers should set dbg_included_in_static_opset to true. This
// assists with runtime validation in DEBUG builds ensuring the intended set
// of operator schema is registered.
//
// The expansion, in order:
//  1. forward-declares the tag class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);
//  2. specializes GetOpSchema<tag>() to return `impl` with its name (#name),
//     domain (domain_str), since-version (ver) and source location
//     (__FILE__, __LINE__) filled in;
//  3. defines the global dbg_count_check_<name>_<domain>_ver<ver>, initialized
//     with ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() when dbg_included_in_static_opset
//     is true and 0 otherwise (that macro itself expands to 0 under NDEBUG).
#define ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl) \
  class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name); \
  template <> \
  OpSchema GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)>() { \
    return impl.SetName(#name).SetDomain(domain_str).SinceVersion(ver).SetLocation(__FILE__, __LINE__); \
  } \
  size_t dbg_count_check_##name##_##domain##_ver##ver = \
      (dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() : 0;
#ifdef NDEBUG
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
#else
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().IncrementCount()
#define ONNX_DBG_GET_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().GetCount()

// Debug-only counter bumped (through ONNX_DBG_INCREMENT_COUNT_IN_OPSETS) by every
// ONNX_OPERATOR_SET_SCHEMA_EX use with dbg_included_in_static_opset == true.
class DbgOperatorSetTracker {
 public:
  static DbgOperatorSetTracker& Instance();

  // Adds one to the counter and reports the updated total.
  size_t IncrementCount() {
    count_ += 1;
    return count_;
  }

  // Reports the current total without modifying it.
  size_t GetCount() const {
    return count_;
  }

 private:
  size_t count_{0};
};
#endif
1245
// Naming convention for operator schema classes: pastes the tokens into
// <name>_<domain>_ver<ver>, e.g. (Onnx, 13, Relu) -> Relu_Onnx_ver13.
#define ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name) name##_##domain##_ver##ver

// Naming convention for preview operator schema classes (domain tag OnnxPreview).
#define ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name) \
  ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxPreview, ver, name)
1252
// Helper function (defined out of line): presumably replaces every occurrence
// of `from` in `s` with `to`, as in the POPULATE_OP_DOC_STR sample below.
// NOTE(review): the size_t result is presumably the number of replacements;
// confirm against the definition.
size_t ReplaceAll(std::string& s, const char* from, const char* to);

// Marks a declaration as intentionally unused so GCC-compatible compilers do
// not warn about it; expands to nothing on other compilers.
#ifdef __GNUC__
#define ONNX_UNUSED __attribute__((__unused__))
#else
#define ONNX_UNUSED
#endif

// Legacy macros to register schema at static initialization.
// ONNX_OPERATOR_SCHEMA_UNIQ_HELPER adds one level of expansion so __COUNTER__
// is replaced by its value before it is token-pasted into a unique variable name.
#define ONNX_OPERATOR_SCHEMA(name) ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name) \
  static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce(op_schema_register_once##name##Counter) ONNX_UNUSED = \
      OpSchema(#name, __FILE__, __LINE__)
1271
// Boilerplate appended to the docs of operators with optional inputs/outputs.
inline std::string GenerateOptionalArgumentsDoc() {
  // Adjacent literals are joined by the compiler into a single C string.
  constexpr const char* kOptionalArgumentsDoc =
      "This operator has **optional** inputs/outputs. "
      "See [the doc](IR.md) for more details about the representation of "
      "optional arguments. An empty string may be used in the place of "
      "an actual argument's name to indicate a missing argument. "
      "Trailing optional arguments (those not followed by an argument "
      "that is present) may also be simply omitted.\n";
  return std::string(kOptionalArgumentsDoc);
}
1280
// Boilerplate for operators that support Numpy-style multidirectional broadcasting.
inline std::string GenerateBroadcastingDocMul() {
  std::string doc = "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**;";
  doc += " for more details please check [the doc](Broadcasting.md).";
  return doc;
}
1285
// Boilerplate for unidirectional broadcasting: `from` must be unidirectionally
// broadcastable to `to`. Both arguments must be non-null C strings.
inline std::string GenerateBroadcastingDocUni(const char* from, const char* to) {
  std::string doc("This operator supports **unidirectional broadcasting** (");
  doc.append(from);
  doc.append(" should be unidirectional broadcastable to ");
  doc.append(to);
  doc.append(
      ");"
      " for more details please check [the doc](Broadcasting.md).");
  return doc;
}
1293
1294/*
1295 * Macros for setting operator documentation
1296 * Use this macro for simple SetDoc() calls that generate documentation
1297 * directly. This is the macro to use in almost all cases.
1298 * Sample usage guidelines:
1299 * const char* doc_str = "foo";
1300 * SetDoc(GET_OP_DOC_STR(doc_str))
1301 *
1302 * SetDoc(GET_OP_DOC_STR(
1303 std::string(BitShift_ver11_doc) + GenerateBroadcastingDocMul()))
1304 */
// Normally passes the documentation expression through unchanged; when
// __ONNX_NO_DOC_STRINGS is defined the expression is discarded and "" is used.
#ifndef __ONNX_NO_DOC_STRINGS
#define GET_OP_DOC_STR(doc_str) (doc_str)
#else
#define GET_OP_DOC_STR(doc_str) ("")
#endif
1310
1311/*
1312 * Use this macro when the documentation needs to be populated in some
1313 * complicated way like string substitutions, etc before calling SetDoc.
1314 * Sample usage guidelines:
1315 std::string doc;
1316 POPULATE_OP_DOC_STR(
1317 doc = R"DOC(
1318Returns the tensor resulted from performing the `{name}` logical operation
1319elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting
1320support).
1321
1322{broadcast_doc}
1323)DOC";
1324 ReplaceAll(doc, "{name}", name);
1325 ReplaceAll(
1326 doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
1327 schema.SetDoc(doc);
1328 *
1329 */
// Wraps the doc-populating statements in do { ... } while (0) so the macro
// behaves as a single statement. When __ONNX_NO_DOC_STRINGS is defined the
// statements are dropped entirely, so they are neither compiled nor executed.
#ifndef __ONNX_NO_DOC_STRINGS
#define POPULATE_OP_DOC_STR(DocPopulatorCode) \
  do {                                        \
    DocPopulatorCode                          \
  } while (0)
#else
#define POPULATE_OP_DOC_STR(DocPopulatorCode)
#endif
1338
1339} // namespace ONNX_NAMESPACE
1340