1 | /* |
---|---|
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #pragma once |
6 | |
7 | #include "onnx/defs/schema.h" |
8 | |
9 | namespace ONNX_NAMESPACE { |
10 | |
11 | // Declare training operators. |
12 | |
13 | class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient); |
14 | class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum); |
15 | class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad); |
16 | class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam); |
17 | |
18 | // Iterate over schema from ai.onnx.training version 1 |
19 | class OpSet_OnnxPreview_ver1 { |
20 | public: |
21 | static void ForEachSchema(std::function<void(OpSchema&&)> fn) { |
22 | fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient)>()); |
23 | fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum)>()); |
24 | fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad)>()); |
25 | fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam)>()); |
26 | } |
27 | }; |
28 | |
29 | // Register preview operators. |
30 | inline void RegisterOnnxPreviewOperatorSetSchema() { |
31 | // Preview operators should have only one version. |
32 | // If changes are needed for a specific preview operator, |
33 | // its spec should be modified without increasing its version. |
34 | RegisterOpSetSchema<OpSet_OnnxPreview_ver1>(); |
35 | } |
36 | |
37 | } // namespace ONNX_NAMESPACE |
38 |