1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#pragma once
6
7#include "onnx/defs/schema.h"
8
9namespace ONNX_NAMESPACE {
10
11// Declare training operators.
12
13class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient);
14class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum);
15class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad);
16class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam);
17
18// Iterate over schema from ai.onnx.training version 1
19class OpSet_OnnxPreview_ver1 {
20 public:
21 static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
22 fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient)>());
23 fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum)>());
24 fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad)>());
25 fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam)>());
26 }
27};
28
29// Register preview operators.
30inline void RegisterOnnxPreviewOperatorSetSchema() {
31 // Preview operators should have only one version.
32 // If changes are needed for a specific preview operator,
33 // its spec should be modified without increasing its version.
34 RegisterOpSetSchema<OpSet_OnnxPreview_ver1>();
35}
36
37} // namespace ONNX_NAMESPACE
38