1 | #pragma once |
2 | |
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/variant.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | class TORCH_CUDA_CU_API IntOrDouble { |
15 | c10::variant<double, int64_t> value_; |
16 | |
17 | public: |
18 | IntOrDouble(int64_t i) : value_(i) {} |
19 | IntOrDouble(double d) : value_(d) {} |
20 | IntOrDouble(int i) : value_((int64_t)i) {} |
21 | IntOrDouble(size_t i) : value_((int64_t)i) {} |
22 | IntOrDouble() : IntOrDouble(0) {} |
23 | |
24 | // Avoid using copy constructor of c10::variant as it's |
25 | // deprecated. |
26 | IntOrDouble(const IntOrDouble& other) { |
27 | value_ = other.value_; |
28 | } |
29 | |
30 | // Explicitly define copy assignment operator as its implicit definition is |
31 | // deprecated |
32 | IntOrDouble& operator=(const IntOrDouble& other) { |
33 | value_ = other.value_; |
34 | return *this; |
35 | } |
36 | |
37 | bool is_int() const { |
38 | return c10::holds_alternative<int64_t>(value_); |
39 | } |
40 | |
41 | template <typename T> |
42 | T as() const { |
43 | TORCH_CHECK( |
44 | c10::holds_alternative<T>(value_), |
45 | "The expected dtype and the actual dtype does not match in IntOrDouble" ); |
46 | return c10::get<T>(value_); |
47 | } |
48 | |
49 | template <typename T> |
50 | T cast() const; |
51 | |
52 | #define DEFINE_ARITHMETIC_OP(op) \ |
53 | IntOrDouble operator op(const IntOrDouble& other) const { \ |
54 | switch ((int)is_int() << 1 | (int)other.is_int()) { \ |
55 | case 0b00: \ |
56 | return IntOrDouble(as<double>() op other.as<double>()); \ |
57 | case 0b01: \ |
58 | return IntOrDouble(as<double>() op other.as<int64_t>()); \ |
59 | case 0b10: \ |
60 | return IntOrDouble(as<int64_t>() op other.as<double>()); \ |
61 | case 0b11: \ |
62 | return IntOrDouble(as<int64_t>() op other.as<int64_t>()); \ |
63 | } \ |
64 | TORCH_INTERNAL_ASSERT(false); \ |
65 | } \ |
66 | template <typename T> \ |
67 | IntOrDouble operator op(T other) const { \ |
68 | if (is_int()) { \ |
69 | return IntOrDouble(as<int64_t>() op other); \ |
70 | } \ |
71 | return IntOrDouble(as<double>() op other); \ |
72 | } |
73 | |
74 | DEFINE_ARITHMETIC_OP(+) |
75 | DEFINE_ARITHMETIC_OP(-) |
76 | DEFINE_ARITHMETIC_OP(*) |
77 | DEFINE_ARITHMETIC_OP(/) |
78 | DEFINE_ARITHMETIC_OP(&&) |
79 | |
80 | #undef DEFINE_ARITHMETIC_OP |
81 | |
82 | #define DEFINE_ASSIGN_OP(assign, op) \ |
83 | IntOrDouble& operator assign(const IntOrDouble& other) { \ |
84 | switch ((int)is_int() << 1 | (int)other.is_int()) { \ |
85 | case 0b00: \ |
86 | return *this = IntOrDouble(as<double>() op other.as<double>()); \ |
87 | case 0b01: \ |
88 | return *this = IntOrDouble(as<double>() op other.as<int64_t>()); \ |
89 | case 0b10: \ |
90 | return *this = IntOrDouble(as<int64_t>() op other.as<double>()); \ |
91 | case 0b11: \ |
92 | return *this = IntOrDouble(as<int64_t>() op other.as<int64_t>()); \ |
93 | } \ |
94 | TORCH_INTERNAL_ASSERT(false); \ |
95 | } \ |
96 | template <typename T> \ |
97 | IntOrDouble& operator assign(T other) { \ |
98 | if (is_int()) { \ |
99 | return *this = IntOrDouble(as<int64_t>() op other); \ |
100 | } \ |
101 | return *this = IntOrDouble(as<double>() op other); \ |
102 | } |
103 | |
104 | DEFINE_ASSIGN_OP(+=, +) |
105 | DEFINE_ASSIGN_OP(-=, -) |
106 | DEFINE_ASSIGN_OP(*=, *) |
107 | DEFINE_ASSIGN_OP(/=, /) |
108 | |
109 | #undef DEFINE_ASSIGN_OP |
110 | |
111 | IntOrDouble operator%(const IntOrDouble& other) const { |
112 | if (is_int() && other.is_int()) { |
113 | return IntOrDouble(as<int64_t>() % other.as<int64_t>()); |
114 | } |
115 | TORCH_INTERNAL_ASSERT(false); |
116 | } |
117 | IntOrDouble operator%(int64_t other) const { |
118 | if (is_int()) { |
119 | return IntOrDouble(as<int64_t>() % other); |
120 | } |
121 | TORCH_INTERNAL_ASSERT(false); |
122 | } |
123 | IntOrDouble& operator%=(const IntOrDouble& other) { |
124 | if (is_int() && other.is_int()) { |
125 | return *this = IntOrDouble(as<int64_t>() % other.as<int64_t>()); |
126 | } |
127 | TORCH_INTERNAL_ASSERT(false); |
128 | } |
129 | IntOrDouble& operator%=(int64_t other) { |
130 | if (is_int()) { |
131 | return *this = IntOrDouble(as<int64_t>() % other); |
132 | } |
133 | TORCH_INTERNAL_ASSERT(false); |
134 | } |
135 | |
136 | #define DEFINE_COMPARE_OP(op) \ |
137 | bool operator op(const IntOrDouble& other) const { \ |
138 | switch ((int)is_int() << 1 | (int)other.is_int()) { \ |
139 | case 0b00: \ |
140 | return as<double>() op other.as<double>(); \ |
141 | case 0b01: \ |
142 | return as<double>() op other.as<int64_t>(); \ |
143 | case 0b10: \ |
144 | return as<int64_t>() op other.as<double>(); \ |
145 | case 0b11: \ |
146 | return as<int64_t>() op other.as<int64_t>(); \ |
147 | } \ |
148 | TORCH_INTERNAL_ASSERT(false); \ |
149 | } \ |
150 | bool operator op(double other) { \ |
151 | if (is_int()) { \ |
152 | return as<int64_t>() op other; \ |
153 | } \ |
154 | return as<double>() op other; \ |
155 | } \ |
156 | bool operator op(int64_t other) { \ |
157 | if (is_int()) { \ |
158 | return as<int64_t>() op other; \ |
159 | } \ |
160 | return as<double>() op other; \ |
161 | } \ |
162 | bool operator op(int other) { \ |
163 | if (is_int()) { \ |
164 | return as<int64_t>() op other; \ |
165 | } \ |
166 | return as<double>() op other; \ |
167 | } |
168 | |
169 | DEFINE_COMPARE_OP(>) |
170 | DEFINE_COMPARE_OP(>=) |
171 | DEFINE_COMPARE_OP(<) |
172 | DEFINE_COMPARE_OP(<=) |
173 | DEFINE_COMPARE_OP(==) |
174 | DEFINE_COMPARE_OP(!=) |
175 | |
176 | #undef DEFINE_COMPARE_OP |
177 | |
178 | IntOrDouble operator-() const { |
179 | if (is_int()) { |
180 | return IntOrDouble(-as<int64_t>()); |
181 | } |
182 | return IntOrDouble(-as<double>()); |
183 | } |
184 | |
185 | explicit operator double() const; |
186 | explicit operator int64_t() const; |
187 | explicit operator size_t() const; |
188 | explicit operator int() const; |
189 | }; |
190 | |
191 | #define DEFINE_ARITHMETIC_OP(op) \ |
192 | template <typename T> \ |
193 | inline IntOrDouble operator op(T lhs, IntOrDouble rhs) { \ |
194 | if (rhs.is_int()) { \ |
195 | return IntOrDouble(lhs op rhs.as<int64_t>()); \ |
196 | } \ |
197 | return IntOrDouble(lhs op rhs.as<double>()); \ |
198 | } |
199 | |
200 | DEFINE_ARITHMETIC_OP(+) |
201 | DEFINE_ARITHMETIC_OP(-) |
202 | DEFINE_ARITHMETIC_OP(*) |
203 | DEFINE_ARITHMETIC_OP(/) |
204 | |
205 | #undef DEFINE_ARITHMETIC_OP |
206 | |
207 | template <> |
208 | inline double IntOrDouble::cast<double>() const { |
209 | if (is_int()) { |
210 | return (double)as<int64_t>(); |
211 | } |
212 | return as<double>(); |
213 | } |
214 | |
215 | template <> |
216 | inline int64_t IntOrDouble::cast<int64_t>() const { |
217 | if (!is_int()) { |
218 | return (int64_t)as<double>(); |
219 | } |
220 | return as<int64_t>(); |
221 | } |
222 | |
223 | inline IntOrDouble::operator double() const { |
224 | return as<double>(); |
225 | } |
226 | |
227 | inline IntOrDouble::operator int64_t() const { |
228 | return as<int64_t>(); |
229 | } |
230 | |
231 | inline IntOrDouble::operator size_t() const { |
232 | return as<int64_t>(); |
233 | } |
234 | |
235 | inline IntOrDouble::operator int() const { |
236 | return as<int64_t>(); |
237 | } |
238 | |
239 | #define DEFINE_EQ_OP(op) \ |
240 | inline bool operator op(double lhs, const IntOrDouble& rhs) { \ |
241 | if (rhs.is_int()) { \ |
242 | return false; \ |
243 | } \ |
244 | return lhs op rhs.as<double>(); \ |
245 | } \ |
246 | \ |
247 | inline bool operator op(int64_t lhs, const IntOrDouble& rhs) { \ |
248 | if (rhs.is_int()) { \ |
249 | return lhs op rhs.as<int64_t>(); \ |
250 | } \ |
251 | return false; \ |
252 | } \ |
253 | \ |
254 | inline bool operator op(int lhs, const IntOrDouble& rhs) { \ |
255 | return operator op((int64_t)lhs, rhs); \ |
256 | } |
257 | |
258 | DEFINE_EQ_OP(==) |
259 | DEFINE_EQ_OP(!=) |
260 | |
261 | #undef DEFINE_EQ_OP |
262 | |
263 | inline std::ostream& operator<<(std::ostream& os, const IntOrDouble& val) { |
264 | if (val.is_int()) { |
265 | return os << val.as<int64_t>(); |
266 | } |
267 | return os << val.as<double>(); |
268 | } |
269 | |
270 | namespace IntOrDouble_functions { |
271 | |
272 | inline IntOrDouble ceildiv(const IntOrDouble& a, const IntOrDouble& b) { |
273 | if (a.is_int() && b.is_int()) { |
274 | auto aa = a.as<int64_t>(); |
275 | auto bb = b.as<int64_t>(); |
276 | if (bb > 0) { |
277 | return (aa + bb - 1) / bb; |
278 | } else { |
279 | return (aa + bb + 1) / bb; |
280 | } |
281 | } |
282 | return std::ceil((a / b).as<double>()); |
283 | } |
284 | |
285 | inline IntOrDouble max(const IntOrDouble& a, const IntOrDouble& b) { |
286 | if (a.is_int() && b.is_int()) { |
287 | return std::max(a.as<int64_t>(), b.as<int64_t>()); |
288 | } |
289 | return (a > b ? a : b).cast<double>(); |
290 | } |
291 | |
292 | inline IntOrDouble min(const IntOrDouble& a, const IntOrDouble& b) { |
293 | if (a.is_int() && b.is_int()) { |
294 | return std::min(a.as<int64_t>(), b.as<int64_t>()); |
295 | } |
296 | return (a < b ? a : b).cast<double>(); |
297 | } |
298 | |
299 | inline IntOrDouble abs(const IntOrDouble& a) { |
300 | if (a.is_int()) { |
301 | return IntOrDouble(std::abs(a.as<int64_t>())); |
302 | } else { |
303 | return IntOrDouble(std::abs(a.as<double>())); |
304 | } |
305 | } |
306 | |
307 | } // namespace IntOrDouble_functions |
308 | |
309 | } // namespace cuda |
310 | } // namespace fuser |
311 | } // namespace jit |
312 | } // namespace torch |
313 | |