1#pragma once
2
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/variant.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
8
9namespace torch {
10namespace jit {
11namespace fuser {
12namespace cuda {
13
14class TORCH_CUDA_CU_API IntOrDouble {
15 c10::variant<double, int64_t> value_;
16
17 public:
18 IntOrDouble(int64_t i) : value_(i) {}
19 IntOrDouble(double d) : value_(d) {}
20 IntOrDouble(int i) : value_((int64_t)i) {}
21 IntOrDouble(size_t i) : value_((int64_t)i) {}
22 IntOrDouble() : IntOrDouble(0) {}
23
24 // Avoid using copy constructor of c10::variant as it's
25 // deprecated.
26 IntOrDouble(const IntOrDouble& other) {
27 value_ = other.value_;
28 }
29
30 // Explicitly define copy assignment operator as its implicit definition is
31 // deprecated
32 IntOrDouble& operator=(const IntOrDouble& other) {
33 value_ = other.value_;
34 return *this;
35 }
36
37 bool is_int() const {
38 return c10::holds_alternative<int64_t>(value_);
39 }
40
41 template <typename T>
42 T as() const {
43 TORCH_CHECK(
44 c10::holds_alternative<T>(value_),
45 "The expected dtype and the actual dtype does not match in IntOrDouble");
46 return c10::get<T>(value_);
47 }
48
49 template <typename T>
50 T cast() const;
51
52#define DEFINE_ARITHMETIC_OP(op) \
53 IntOrDouble operator op(const IntOrDouble& other) const { \
54 switch ((int)is_int() << 1 | (int)other.is_int()) { \
55 case 0b00: \
56 return IntOrDouble(as<double>() op other.as<double>()); \
57 case 0b01: \
58 return IntOrDouble(as<double>() op other.as<int64_t>()); \
59 case 0b10: \
60 return IntOrDouble(as<int64_t>() op other.as<double>()); \
61 case 0b11: \
62 return IntOrDouble(as<int64_t>() op other.as<int64_t>()); \
63 } \
64 TORCH_INTERNAL_ASSERT(false); \
65 } \
66 template <typename T> \
67 IntOrDouble operator op(T other) const { \
68 if (is_int()) { \
69 return IntOrDouble(as<int64_t>() op other); \
70 } \
71 return IntOrDouble(as<double>() op other); \
72 }
73
74 DEFINE_ARITHMETIC_OP(+)
75 DEFINE_ARITHMETIC_OP(-)
76 DEFINE_ARITHMETIC_OP(*)
77 DEFINE_ARITHMETIC_OP(/)
78 DEFINE_ARITHMETIC_OP(&&)
79
80#undef DEFINE_ARITHMETIC_OP
81
82#define DEFINE_ASSIGN_OP(assign, op) \
83 IntOrDouble& operator assign(const IntOrDouble& other) { \
84 switch ((int)is_int() << 1 | (int)other.is_int()) { \
85 case 0b00: \
86 return *this = IntOrDouble(as<double>() op other.as<double>()); \
87 case 0b01: \
88 return *this = IntOrDouble(as<double>() op other.as<int64_t>()); \
89 case 0b10: \
90 return *this = IntOrDouble(as<int64_t>() op other.as<double>()); \
91 case 0b11: \
92 return *this = IntOrDouble(as<int64_t>() op other.as<int64_t>()); \
93 } \
94 TORCH_INTERNAL_ASSERT(false); \
95 } \
96 template <typename T> \
97 IntOrDouble& operator assign(T other) { \
98 if (is_int()) { \
99 return *this = IntOrDouble(as<int64_t>() op other); \
100 } \
101 return *this = IntOrDouble(as<double>() op other); \
102 }
103
104 DEFINE_ASSIGN_OP(+=, +)
105 DEFINE_ASSIGN_OP(-=, -)
106 DEFINE_ASSIGN_OP(*=, *)
107 DEFINE_ASSIGN_OP(/=, /)
108
109#undef DEFINE_ASSIGN_OP
110
111 IntOrDouble operator%(const IntOrDouble& other) const {
112 if (is_int() && other.is_int()) {
113 return IntOrDouble(as<int64_t>() % other.as<int64_t>());
114 }
115 TORCH_INTERNAL_ASSERT(false);
116 }
117 IntOrDouble operator%(int64_t other) const {
118 if (is_int()) {
119 return IntOrDouble(as<int64_t>() % other);
120 }
121 TORCH_INTERNAL_ASSERT(false);
122 }
123 IntOrDouble& operator%=(const IntOrDouble& other) {
124 if (is_int() && other.is_int()) {
125 return *this = IntOrDouble(as<int64_t>() % other.as<int64_t>());
126 }
127 TORCH_INTERNAL_ASSERT(false);
128 }
129 IntOrDouble& operator%=(int64_t other) {
130 if (is_int()) {
131 return *this = IntOrDouble(as<int64_t>() % other);
132 }
133 TORCH_INTERNAL_ASSERT(false);
134 }
135
136#define DEFINE_COMPARE_OP(op) \
137 bool operator op(const IntOrDouble& other) const { \
138 switch ((int)is_int() << 1 | (int)other.is_int()) { \
139 case 0b00: \
140 return as<double>() op other.as<double>(); \
141 case 0b01: \
142 return as<double>() op other.as<int64_t>(); \
143 case 0b10: \
144 return as<int64_t>() op other.as<double>(); \
145 case 0b11: \
146 return as<int64_t>() op other.as<int64_t>(); \
147 } \
148 TORCH_INTERNAL_ASSERT(false); \
149 } \
150 bool operator op(double other) { \
151 if (is_int()) { \
152 return as<int64_t>() op other; \
153 } \
154 return as<double>() op other; \
155 } \
156 bool operator op(int64_t other) { \
157 if (is_int()) { \
158 return as<int64_t>() op other; \
159 } \
160 return as<double>() op other; \
161 } \
162 bool operator op(int other) { \
163 if (is_int()) { \
164 return as<int64_t>() op other; \
165 } \
166 return as<double>() op other; \
167 }
168
169 DEFINE_COMPARE_OP(>)
170 DEFINE_COMPARE_OP(>=)
171 DEFINE_COMPARE_OP(<)
172 DEFINE_COMPARE_OP(<=)
173 DEFINE_COMPARE_OP(==)
174 DEFINE_COMPARE_OP(!=)
175
176#undef DEFINE_COMPARE_OP
177
178 IntOrDouble operator-() const {
179 if (is_int()) {
180 return IntOrDouble(-as<int64_t>());
181 }
182 return IntOrDouble(-as<double>());
183 }
184
185 explicit operator double() const;
186 explicit operator int64_t() const;
187 explicit operator size_t() const;
188 explicit operator int() const;
189};
190
191#define DEFINE_ARITHMETIC_OP(op) \
192 template <typename T> \
193 inline IntOrDouble operator op(T lhs, IntOrDouble rhs) { \
194 if (rhs.is_int()) { \
195 return IntOrDouble(lhs op rhs.as<int64_t>()); \
196 } \
197 return IntOrDouble(lhs op rhs.as<double>()); \
198 }
199
200DEFINE_ARITHMETIC_OP(+)
201DEFINE_ARITHMETIC_OP(-)
202DEFINE_ARITHMETIC_OP(*)
203DEFINE_ARITHMETIC_OP(/)
204
205#undef DEFINE_ARITHMETIC_OP
206
207template <>
208inline double IntOrDouble::cast<double>() const {
209 if (is_int()) {
210 return (double)as<int64_t>();
211 }
212 return as<double>();
213}
214
215template <>
216inline int64_t IntOrDouble::cast<int64_t>() const {
217 if (!is_int()) {
218 return (int64_t)as<double>();
219 }
220 return as<int64_t>();
221}
222
223inline IntOrDouble::operator double() const {
224 return as<double>();
225}
226
227inline IntOrDouble::operator int64_t() const {
228 return as<int64_t>();
229}
230
231inline IntOrDouble::operator size_t() const {
232 return as<int64_t>();
233}
234
235inline IntOrDouble::operator int() const {
236 return as<int64_t>();
237}
238
239#define DEFINE_EQ_OP(op) \
240 inline bool operator op(double lhs, const IntOrDouble& rhs) { \
241 if (rhs.is_int()) { \
242 return false; \
243 } \
244 return lhs op rhs.as<double>(); \
245 } \
246 \
247 inline bool operator op(int64_t lhs, const IntOrDouble& rhs) { \
248 if (rhs.is_int()) { \
249 return lhs op rhs.as<int64_t>(); \
250 } \
251 return false; \
252 } \
253 \
254 inline bool operator op(int lhs, const IntOrDouble& rhs) { \
255 return operator op((int64_t)lhs, rhs); \
256 }
257
258DEFINE_EQ_OP(==)
259DEFINE_EQ_OP(!=)
260
261#undef DEFINE_EQ_OP
262
263inline std::ostream& operator<<(std::ostream& os, const IntOrDouble& val) {
264 if (val.is_int()) {
265 return os << val.as<int64_t>();
266 }
267 return os << val.as<double>();
268}
269
270namespace IntOrDouble_functions {
271
272inline IntOrDouble ceildiv(const IntOrDouble& a, const IntOrDouble& b) {
273 if (a.is_int() && b.is_int()) {
274 auto aa = a.as<int64_t>();
275 auto bb = b.as<int64_t>();
276 if (bb > 0) {
277 return (aa + bb - 1) / bb;
278 } else {
279 return (aa + bb + 1) / bb;
280 }
281 }
282 return std::ceil((a / b).as<double>());
283}
284
285inline IntOrDouble max(const IntOrDouble& a, const IntOrDouble& b) {
286 if (a.is_int() && b.is_int()) {
287 return std::max(a.as<int64_t>(), b.as<int64_t>());
288 }
289 return (a > b ? a : b).cast<double>();
290}
291
292inline IntOrDouble min(const IntOrDouble& a, const IntOrDouble& b) {
293 if (a.is_int() && b.is_int()) {
294 return std::min(a.as<int64_t>(), b.as<int64_t>());
295 }
296 return (a < b ? a : b).cast<double>();
297}
298
299inline IntOrDouble abs(const IntOrDouble& a) {
300 if (a.is_int()) {
301 return IntOrDouble(std::abs(a.as<int64_t>()));
302 } else {
303 return IntOrDouble(std::abs(a.as<double>()));
304 }
305}
306
307} // namespace IntOrDouble_functions
308
309} // namespace cuda
310} // namespace fuser
311} // namespace jit
312} // namespace torch
313