1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #pragma once |
6 | |
7 | #ifdef ONNX_ML |
8 | |
9 | #include "onnx/defs/schema.h" |
10 | |
11 | namespace ONNX_NAMESPACE { |
12 | |
13 | // Forward declarations for ai.onnx.ml version 1 |
14 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, ArrayFeatureExtractor); |
15 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Binarizer); |
16 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, CastMap); |
17 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, CategoryMapper); |
18 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, DictVectorizer); |
19 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, FeatureVectorizer); |
20 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Imputer); |
21 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LabelEncoder); |
22 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LinearClassifier); |
23 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LinearRegressor); |
24 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Normalizer); |
25 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, OneHotEncoder); |
26 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, SVMClassifier); |
27 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, SVMRegressor); |
28 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Scaler); |
29 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, TreeEnsembleClassifier); |
30 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, TreeEnsembleRegressor); |
31 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, ZipMap); |
32 | |
33 | // Iterate over schema from ai.onnx.ml version 1 |
34 | class OpSet_OnnxML_ver1 { |
35 | public: |
36 | static void ForEachSchema(std::function<void(OpSchema&&)> fn) { |
37 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, ArrayFeatureExtractor)>()); |
38 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Binarizer)>()); |
39 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, CastMap)>()); |
40 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, CategoryMapper)>()); |
41 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, DictVectorizer)>()); |
42 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, FeatureVectorizer)>()); |
43 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Imputer)>()); |
44 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LabelEncoder)>()); |
45 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LinearClassifier)>()); |
46 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LinearRegressor)>()); |
47 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Normalizer)>()); |
48 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, OneHotEncoder)>()); |
49 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, SVMClassifier)>()); |
50 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, SVMRegressor)>()); |
51 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Scaler)>()); |
52 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, TreeEnsembleClassifier)>()); |
53 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, TreeEnsembleRegressor)>()); |
54 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, ZipMap)>()); |
55 | } |
56 | }; |
57 | |
58 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 2, LabelEncoder); |
59 | |
60 | class OpSet_OnnxML_ver2 { |
61 | public: |
62 | static void ForEachSchema(std::function<void(OpSchema&&)> fn) { |
63 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 2, LabelEncoder)>()); |
64 | } |
65 | }; |
66 | |
67 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 3, TreeEnsembleClassifier); |
68 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 3, TreeEnsembleRegressor); |
69 | |
70 | class OpSet_OnnxML_ver3 { |
71 | public: |
72 | static void ForEachSchema(std::function<void(OpSchema&&)> fn) { |
73 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 3, TreeEnsembleClassifier)>()); |
74 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 3, TreeEnsembleRegressor)>()); |
75 | } |
76 | }; |
77 | |
// Registers the ai.onnx.ml operator-set schemas declared in this header by
// calling RegisterOpSetSchema<> for OpSet_OnnxML_ver1, OpSet_OnnxML_ver2 and
// OpSet_OnnxML_ver3, in ascending version order.
// NOTE(review): RegisterOpSetSchema is declared outside this header
// (presumably in onnx/defs/schema.h); how it consumes each OpSet's
// ForEachSchema is not visible here — confirm there.
inline void RegisterOnnxMLOperatorSetSchema() {
  RegisterOpSetSchema<OpSet_OnnxML_ver1>();
  RegisterOpSetSchema<OpSet_OnnxML_ver2>();
  RegisterOpSetSchema<OpSet_OnnxML_ver3>();
}
83 | } // namespace ONNX_NAMESPACE |
84 | |
85 | #endif |
86 | |