1#pragma once
2
3#include "onnx/defs/schema.h"
4
5namespace ONNX_NAMESPACE {
6
7class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
8 PyTorch,
9 1,
10 SparseLengthsSumFused8BitRowwise);
11class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum);
12class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum);
13class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather);
14class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct);
15class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed);
16class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchMatMul);
17class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, ExpandDims);
18
19// Iterate over schema from ai.onnx.pytorch domain opset 1
20class OpSet_PyTorch_ver1 {
21 public:
22 static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
23 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
24 PyTorch, 1, SparseLengthsSumFused8BitRowwise)>());
25 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
26 PyTorch, 1, SparseLengthsSum)>());
27 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
28 PyTorch, 1, SparseLengthsWeightedSum)>());
29 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
30 PyTorch, 1, BatchGather)>());
31 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
32 PyTorch, 1, DotProduct)>());
33 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
34 PyTorch, 1, FCTransposed)>());
35 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
36 PyTorch, 1, BatchMatMul)>());
37 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
38 PyTorch, 1, ExpandDims)>());
39 }
40};
41
42inline void RegisterPyTorchOperatorSetSchema() {
43 RegisterOpSetSchema<OpSet_PyTorch_ver1>();
44}
45
46} // namespace ONNX_NAMESPACE
47