1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5// ATTENTION: The code in this file is highly EXPERIMENTAL.
6// Adventurous users should note that the APIs will probably change.
7
8#pragma once
9#include <stdint.h>
10#include <string>
11#include <unordered_map>
12#include <vector>
13
14namespace ONNX_NAMESPACE {
15
// X-macro listing every builtin symbol name.
//
// Invoke it with a one-argument macro `_`; it expands to `_(name)` once per
// entry, in exactly the order listed here. The BuiltinSymbol enum in this
// header passes a macro that emits one `k<name>` enumerator per entry, so an
// entry's position in this list is its numeric id.
//
// Maintenance notes:
//  - Names must be unique; a duplicate produces a duplicate enumerator.
//  - Appending at the end keeps existing ids stable; inserting or reordering
//    renumbers every entry that follows.
//  - Do not put `//` comments on the continuation lines below: line splicing
//    happens before comment removal, so the comment would swallow the
//    trailing backslash and the next entry with it.
#define FORALL_BUILTIN_SYMBOLS(_) \
  _(spatial) \
  _(select_last_index) \
  _(coordinate_transformation_mode) \
  _(PythonOp) \
  _(CppOp) \
  _(Param) \
  _(Select) \
  _(Return) \
  _(Eval) \
  _(add) \
  _(Add) \
  _(Div) \
  _(Mul) \
  _(Neg) \
  _(Sub) \
  _(Pow) \
  _(Sigmoid) \
  _(ArgMax) \
  _(Concat) \
  _(Softmax) \
  _(LogSoftmax) \
  _(Dropout) \
  _(Tanh) \
  _(mul) \
  _(neg) \
  _(sigmoid) \
  _(tanh) \
  _(Constant) \
  _(cat) \
  _(Slice) \
  _(Squeeze) \
  _(Undefined) \
  _(FusionGroup) \
  _(MatMul) \
  _(Gemm) \
  _(Tile) \
  _(SubConstant) \
  _(Scale) \
  _(Transpose) \
  _(Pad) \
  _(Reshape) \
  _(split) \
  _(chunk) \
  _(Offset) \
  _(value) \
  _(Subgraph) \
  _(BatchNormalization) \
  _(Conv) \
  _(ConvTranspose) \
  _(is_test) \
  _(epsilon) \
  _(expand) \
  _(Expand) \
  _(order) \
  _(momentum) \
  _(consumed_inputs) \
  _(kernels) \
  _(kernel_shape) \
  _(kernel) \
  _(scale) \
  _(strides) \
  _(stride) \
  _(pads) \
  _(pad) \
  _(beta) \
  _(alpha) \
  _(dilations) \
  _(dilation) \
  _(broadcast) \
  _(axis) \
  _(ratio) \
  _(size) \
  _(dim) \
  _(keepdims) \
  _(perm) \
  _(shape) \
  _(axes) \
  _(group) \
  _(inplace) \
  _(transA) \
  _(transB) \
  _(other) \
  _(__and__) \
  _(__lshift__) \
  _(__or__) \
  _(__rshift__) \
  _(__xor__) \
  _(abs) \
  _(acos) \
  _(asin) \
  _(atan) \
  _(atan2) \
  _(ceil) \
  _(clamp) \
  _(cos) \
  _(cosh) \
  _(div) \
  _(eq) \
  _(equal) \
  _(Exp) \
  _(ends) \
  _(expm1) \
  _(floor) \
  _(fmod) \
  _(frac) \
  _(ge) \
  _(gt) \
  _(le) \
  _(lerp) \
  _(lgamma) \
  _(Log) \
  _(log1p) \
  _(lt) \
  _(max) \
  _(min) \
  _(ne) \
  _(ones) \
  _(pow) \
  _(reciprocal) \
  _(remainder) \
  _(round) \
  _(rsqrt) \
  _(sin) \
  _(sinh) \
  _(Sqrt) \
  _(sub) \
  _(starts) \
  _(tan) \
  _(trunc) \
  _(zeros) \
  _(exponent) \
  _(device) \
  _(mode) \
  _(Identity) \
  _(Loop) \
  _(If) \
  _(body) \
  _(then_branch) \
  _(else_branch) \
  _(Captured) \
  _(__control_inputs) \
  _(count_include_pad) \
  _(storage_order) \
  _(Unsqueeze) \
  _(ReduceL1) \
  _(ReduceL2) \
  _(ReduceLogSum) \
  _(ReduceLogSumExp) \
  _(ReduceMax) \
  _(ReduceMean) \
  _(ReduceMin) \
  _(ReduceProd) \
  _(ReduceSum) \
  _(ReduceSumSquare) \
  _(Cast) \
  _(to) \
  _(PRelu) \
  _(Greater) \
  _(Less) \
  _(scales) \
  _(Upsample) \
  _(RNN) \
  _(layout) \
  _(k) \
  _(Flatten) \
  _(ScatterElements) \
  _(Resize) \
  _(ceil_mode) \
  _(num_outputs)
186
187enum BuiltinSymbol {
188#define DEFINE_SYMBOL(s) k##s,
189 FORALL_BUILTIN_SYMBOLS(DEFINE_SYMBOL)
190#undef DEFINE_SYMBOL
191 kLastSymbol, // where we start counting for new symbols
192};
193
194struct Symbol {
195 Symbol() {}
196 /*implicit*/ Symbol(BuiltinSymbol value) : value(value) {}
197 explicit Symbol(const std::string& s);
198 explicit Symbol(uint32_t value) : value(value) {}
199
200 operator uint32_t() const {
201 return value;
202 }
203 const char* toString() const;
204
205 private:
206 uint32_t value;
207};
208
209static inline bool operator==(Symbol lhs, Symbol rhs) {
210 return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
211}
212// necessary to prevent ambiguous overload resolutions
213static inline bool operator==(BuiltinSymbol lhs, Symbol rhs) {
214 return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
215}
216static inline bool operator==(Symbol lhs, BuiltinSymbol rhs) {
217 return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
218}
219
220inline Symbol operator"" _sym(const char* s, size_t) {
221 return Symbol(s);
222}
223
224} // namespace ONNX_NAMESPACE
225
226// make symbol behave like an integer in hash tables
227namespace std {
228template <>
229struct hash<ONNX_NAMESPACE::Symbol> {
230 std::size_t operator()(ONNX_NAMESPACE::Symbol s) const {
231 return std::hash<uint32_t>()(static_cast<uint32_t>(s));
232 }
233};
234
235} // namespace std
236