1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. |
6 | // Adventurous users should note that the APIs will probably change. |
7 | |
8 | #pragma once |
9 | #include <stdint.h> |
10 | #include <string> |
11 | #include <unordered_map> |
12 | #include <vector> |
13 | |
14 | namespace ONNX_NAMESPACE { |
15 | |
// X-macro listing every builtin symbol name. Invoke it with a one-argument
// macro; it expands to `_(name)` for each entry below, in list order. In this
// header it is expanded once to declare the BuiltinSymbol enumerators
// `k<name>` (e.g. `_(Add)` -> `kAdd`).
//
// Names are case-sensitive, so pairs such as `add`/`Add` or `tanh`/`Tanh` are
// distinct entries. Because enumerators are numbered in list order, inserting
// or reordering entries renumbers every later builtin symbol.
// NOTE(review): confirm nothing persists these numeric ids before reordering.
//
// Do not put `//` comments on the continuation lines: line splicing happens
// before comment removal, so a trailing backslash inside a line comment would
// swallow the next entry.
#define FORALL_BUILTIN_SYMBOLS(_) \
  _(spatial) \
  _(select_last_index) \
  _(coordinate_transformation_mode) \
  _(PythonOp) \
  _(CppOp) \
  _(Param) \
  _(Select) \
  _(Return) \
  _(Eval) \
  _(add) \
  _(Add) \
  _(Div) \
  _(Mul) \
  _(Neg) \
  _(Sub) \
  _(Pow) \
  _(Sigmoid) \
  _(ArgMax) \
  _(Concat) \
  _(Softmax) \
  _(LogSoftmax) \
  _(Dropout) \
  _(Tanh) \
  _(mul) \
  _(neg) \
  _(sigmoid) \
  _(tanh) \
  _(Constant) \
  _(cat) \
  _(Slice) \
  _(Squeeze) \
  _(Undefined) \
  _(FusionGroup) \
  _(MatMul) \
  _(Gemm) \
  _(Tile) \
  _(SubConstant) \
  _(Scale) \
  _(Transpose) \
  _(Pad) \
  _(Reshape) \
  _(split) \
  _(chunk) \
  _(Offset) \
  _(value) \
  _(Subgraph) \
  _(BatchNormalization) \
  _(Conv) \
  _(ConvTranspose) \
  _(is_test) \
  _(epsilon) \
  _(expand) \
  _(Expand) \
  _(order) \
  _(momentum) \
  _(consumed_inputs) \
  _(kernels) \
  _(kernel_shape) \
  _(kernel) \
  _(scale) \
  _(strides) \
  _(stride) \
  _(pads) \
  _(pad) \
  _(beta) \
  _(alpha) \
  _(dilations) \
  _(dilation) \
  _(broadcast) \
  _(axis) \
  _(ratio) \
  _(size) \
  _(dim) \
  _(keepdims) \
  _(perm) \
  _(shape) \
  _(axes) \
  _(group) \
  _(inplace) \
  _(transA) \
  _(transB) \
  _(other) \
  _(__and__) \
  _(__lshift__) \
  _(__or__) \
  _(__rshift__) \
  _(__xor__) \
  _(abs) \
  _(acos) \
  _(asin) \
  _(atan) \
  _(atan2) \
  _(ceil) \
  _(clamp) \
  _(cos) \
  _(cosh) \
  _(div) \
  _(eq) \
  _(equal) \
  _(Exp) \
  _(ends) \
  _(expm1) \
  _(floor) \
  _(fmod) \
  _(frac) \
  _(ge) \
  _(gt) \
  _(le) \
  _(lerp) \
  _(lgamma) \
  _(Log) \
  _(log1p) \
  _(lt) \
  _(max) \
  _(min) \
  _(ne) \
  _(ones) \
  _(pow) \
  _(reciprocal) \
  _(remainder) \
  _(round) \
  _(rsqrt) \
  _(sin) \
  _(sinh) \
  _(Sqrt) \
  _(sub) \
  _(starts) \
  _(tan) \
  _(trunc) \
  _(zeros) \
  _(exponent) \
  _(device) \
  _(mode) \
  _(Identity) \
  _(Loop) \
  _(If) \
  _(body) \
  _(then_branch) \
  _(else_branch) \
  _(Captured) \
  _(__control_inputs) \
  _(count_include_pad) \
  _(storage_order) \
  _(Unsqueeze) \
  _(ReduceL1) \
  _(ReduceL2) \
  _(ReduceLogSum) \
  _(ReduceLogSumExp) \
  _(ReduceMax) \
  _(ReduceMean) \
  _(ReduceMin) \
  _(ReduceProd) \
  _(ReduceSum) \
  _(ReduceSumSquare) \
  _(Cast) \
  _(to) \
  _(PRelu) \
  _(Greater) \
  _(Less) \
  _(scales) \
  _(Upsample) \
  _(RNN) \
  _(layout) \
  _(k) \
  _(Flatten) \
  _(ScatterElements) \
  _(Resize) \
  _(ceil_mode) \
  _(num_outputs)
186 | |
187 | enum BuiltinSymbol { |
188 | #define DEFINE_SYMBOL(s) k##s, |
189 | FORALL_BUILTIN_SYMBOLS(DEFINE_SYMBOL) |
190 | #undef DEFINE_SYMBOL |
191 | kLastSymbol, // where we start counting for new symbols |
192 | }; |
193 | |
194 | struct Symbol { |
195 | Symbol() {} |
196 | /*implicit*/ Symbol(BuiltinSymbol value) : value(value) {} |
197 | explicit Symbol(const std::string& s); |
198 | explicit Symbol(uint32_t value) : value(value) {} |
199 | |
200 | operator uint32_t() const { |
201 | return value; |
202 | } |
203 | const char* toString() const; |
204 | |
205 | private: |
206 | uint32_t value; |
207 | }; |
208 | |
209 | static inline bool operator==(Symbol lhs, Symbol rhs) { |
210 | return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs); |
211 | } |
212 | // necessary to prevent ambiguous overload resolutions |
213 | static inline bool operator==(BuiltinSymbol lhs, Symbol rhs) { |
214 | return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs); |
215 | } |
216 | static inline bool operator==(Symbol lhs, BuiltinSymbol rhs) { |
217 | return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs); |
218 | } |
219 | |
220 | inline Symbol operator"" _sym(const char* s, size_t) { |
221 | return Symbol(s); |
222 | } |
223 | |
224 | } // namespace ONNX_NAMESPACE |
225 | |
226 | // make symbol behave like an integer in hash tables |
227 | namespace std { |
228 | template <> |
229 | struct hash<ONNX_NAMESPACE::Symbol> { |
230 | std::size_t operator()(ONNX_NAMESPACE::Symbol s) const { |
231 | return std::hash<uint32_t>()(static_cast<uint32_t>(s)); |
232 | } |
233 | }; |
234 | |
235 | } // namespace std |
236 | |