1 | #pragma once |
2 | |
3 | #include "onnx/defs/schema.h" |
4 | |
5 | namespace ONNX_NAMESPACE { |
6 | |
7 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
8 | PyTorch, |
9 | 1, |
10 | SparseLengthsSumFused8BitRowwise); |
11 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum); |
12 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum); |
13 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather); |
14 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct); |
15 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed); |
16 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchMatMul); |
17 | class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, ExpandDims); |
18 | |
19 | // Iterate over schema from ai.onnx.pytorch domain opset 1 |
20 | class OpSet_PyTorch_ver1 { |
21 | public: |
22 | static void ForEachSchema(std::function<void(OpSchema&&)> fn) { |
23 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
24 | PyTorch, 1, SparseLengthsSumFused8BitRowwise)>()); |
25 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
26 | PyTorch, 1, SparseLengthsSum)>()); |
27 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
28 | PyTorch, 1, SparseLengthsWeightedSum)>()); |
29 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
30 | PyTorch, 1, BatchGather)>()); |
31 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
32 | PyTorch, 1, DotProduct)>()); |
33 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
34 | PyTorch, 1, FCTransposed)>()); |
35 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
36 | PyTorch, 1, BatchMatMul)>()); |
37 | fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( |
38 | PyTorch, 1, ExpandDims)>()); |
39 | } |
40 | }; |
41 | |
42 | inline void RegisterPyTorchOperatorSetSchema() { |
43 | RegisterOpSetSchema<OpSet_PyTorch_ver1>(); |
44 | } |
45 | |
46 | } // namespace ONNX_NAMESPACE |
47 | |