1 | // Copyright (c) Facebook Inc. and Microsoft Corporation. |
2 | // Licensed under the MIT license. |
3 | |
4 | #include "./schema.h" |
5 | |
6 | namespace ONNX_NAMESPACE { |
7 | |
8 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) |
9 | ONNX_PYTORCH_OPERATOR_SET_SCHEMA( |
10 | SparseLengthsSumFused8BitRowwise, |
11 | 1, |
12 | OpSchema() |
13 | .SetDoc("Mirror Caffe2 SparseLengthsSumFused8BitRowwise operator" ) |
14 | .Input(0, "DATA" , "data tensor" , "T1" ) |
15 | .Input(1, "INDICES" , "indices tensor" , "T2" ) |
16 | .Input(2, "LENGTHS" , "lengths tensor" , "T2" ) |
17 | .Output(0, "output" , "Output tensor" , "T2" ) |
18 | .TypeConstraint( |
19 | "T1" , |
20 | {"tensor(uint8)" }, |
21 | "Constrain input data to uint8 tensors." ) |
22 | .TypeConstraint( |
23 | "T2" , |
24 | {"tensor(int8)" , |
25 | "tensor(int16)" , |
26 | "tensor(int32)" , |
27 | "tensor(int64)" , |
28 | "tensor(uint8)" , |
29 | "tensor(uint16)" , |
30 | "tensor(uint32)" , |
31 | "tensor(uint64)" }, |
32 | "Constrain index and length to integral tensors." )); |
33 | |
34 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) |
35 | ONNX_PYTORCH_OPERATOR_SET_SCHEMA( |
36 | SparseLengthsSum, |
37 | 1, |
38 | OpSchema() |
39 | .SetDoc("Mirror Caffe2 SparseLengthsSum operator" ) |
40 | .Input(0, "DATA" , "data tensor" , "T1" ) |
41 | .Input(1, "INDICES" , "indices tensor" , "T2" ) |
42 | .Input(2, "LENGTHS" , "lengths tensor" , "T2" ) |
43 | .Output(0, "output" , "Output tensor" , "T1" ) |
44 | .TypeConstraint( |
45 | "T1" , |
46 | {"tensor(float16)" , "tensor(float)" , "tensor(double)" }, |
47 | "Constrain input and output types to float tensors." ) |
48 | .TypeConstraint( |
49 | "T2" , |
50 | {"tensor(int8)" , |
51 | "tensor(int16)" , |
52 | "tensor(int32)" , |
53 | "tensor(int64)" , |
54 | "tensor(uint8)" , |
55 | "tensor(uint16)" , |
56 | "tensor(uint32)" , |
57 | "tensor(uint64)" }, |
58 | "Constrain index and length to integral tensors." )); |
59 | |
60 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) |
61 | ONNX_PYTORCH_OPERATOR_SET_SCHEMA( |
62 | SparseLengthsWeightedSum, |
63 | 1, |
64 | OpSchema() |
65 | .SetDoc("Mirror Caffe2 SparseLengthsWeightedSum operator" ) |
66 | .Input(0, "DATA" , "data tensor" , "T1" ) |
67 | .Input(1, "WEIGHTS" , "data tensor" , "T1" ) |
68 | .Input(2, "INDICES" , "indices tensor" , "T2" ) |
69 | .Input(3, "LENGTHS" , "lengths tensor" , "T2" ) |
70 | .Output(0, "output" , "Output tensor" , "T1" ) |
71 | .TypeConstraint( |
72 | "T1" , |
73 | {"tensor(float16)" , "tensor(float)" , "tensor(double)" }, |
74 | "Constrain input and output types to float tensors." ) |
75 | .TypeConstraint( |
76 | "T2" , |
77 | {"tensor(int8)" , |
78 | "tensor(int16)" , |
79 | "tensor(int32)" , |
80 | "tensor(int64)" , |
81 | "tensor(uint8)" , |
82 | "tensor(uint16)" , |
83 | "tensor(uint32)" , |
84 | "tensor(uint64)" }, |
85 | "Constrain index and length to integral tensors." )); |
86 | |
87 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) |
88 | ONNX_PYTORCH_OPERATOR_SET_SCHEMA( |
89 | BatchGather, |
90 | 1, |
91 | OpSchema() |
92 | .SetDoc("Mirror Caffe2 BatchGather operator" ) |
93 | .Input(0, "DATA" , "data tensor" , "T1" ) |
94 | .Input(1, "INDICES" , "indices tensor" , "T2" ) |
95 | .Output(0, "output" , "Output tensor" , "T1" ) |
96 | .TypeConstraint( |
97 | "T1" , |
98 | {"tensor(float16)" , "tensor(float)" , "tensor(double)" }, |
99 | "Constrain input and output types to float tensors." ) |
100 | .TypeConstraint( |
101 | "T2" , |
102 | {"tensor(int8)" , |
103 | "tensor(int16)" , |
104 | "tensor(int32)" , |
105 | "tensor(int64)" , |
106 | "tensor(uint8)" , |
107 | "tensor(uint16)" , |
108 | "tensor(uint32)" , |
109 | "tensor(uint64)" }, |
110 | "Constrain index and length to integral tensors." )); |
111 | |
112 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) |
113 | ONNX_PYTORCH_OPERATOR_SET_SCHEMA( |
114 | DotProduct, |
115 | 1, |
116 | OpSchema() |
117 | .SetDoc("Mirror Caffe2 DotProduct operator" ) |
118 | .Input(0, "X" , "Input 1 tensor" , "T" ) |
119 | .Input(1, "Y" , "Input 2 tensor" , "T" ) |
120 | .Output(0, "Z" , "Output tensor" , "T" ) |
121 | .TypeConstraint( |
122 | "T" , |
123 | {"tensor(float16)" , "tensor(float)" , "tensor(double)" }, |
124 | "Constrain input and output types to float tensors." )); |
125 | |
126 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) |
127 | ONNX_PYTORCH_OPERATOR_SET_SCHEMA( |
128 | FCTransposed, |
129 | 1, |
130 | OpSchema() |
131 | .SetDoc("Mirror Caffe2 FCTransposed operator" ) |
132 | .Input(0, "X" , "Input tensor" , "T" ) |
133 | .Input(1, "W" , "Weight tensor" , "T" ) |
134 | .Input(2, "B" , "Bias tensor" , "T" ) |
135 | .Output(0, "Z" , "Output tensor" , "T" ) |
136 | .TypeConstraint( |
137 | "T" , |
138 | {"tensor(float16)" , "tensor(float)" , "tensor(double)" }, |
139 | "Constrain input and output types to float tensors." )); |
140 | |
141 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) |
142 | ONNX_PYTORCH_OPERATOR_SET_SCHEMA( |
143 | BatchMatMul, |
144 | 1, |
145 | OpSchema() |
146 | .SetDoc("Mirror Caffe2 BatchMatMul operator" ) |
147 | .Input(0, "X" , "tensor of shape (dim0, dim1 ... M, K)" , "T" ) |
148 | .Input(1, "Y" , "tensor of shape (dim0, dim2 ... K, N)" , "T" ) |
149 | .Output(0, "Z" , "tensor of shape (dim0, dim1 ... M, N)" , "T" ) |
150 | .TypeConstraint( |
151 | "T" , |
152 | {"tensor(float16)" , "tensor(float)" , "tensor(double)" }, |
153 | "Constrain input and output types to float tensors." )); |
154 | |
155 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) |
156 | ONNX_PYTORCH_OPERATOR_SET_SCHEMA( |
157 | ExpandDims, |
158 | 1, |
159 | OpSchema() |
160 | .SetDoc("Mirror Caffe2 ExpandDims operator" ) |
161 | .Input(0, "X" , "Input tensor" , "T" ) |
162 | .Output(0, "Y" , "Output tensor" , "T" ) |
163 | .TypeConstraint( |
164 | "T" , |
165 | {"tensor(float16)" , "tensor(float)" , "tensor(double)" }, |
166 | "Constrain input and output types to float tensors." )); |
167 | |
168 | } // namespace ONNX_NAMESPACE |
169 | |