1 | #pragma once |
2 | |
3 | #include <c10/util/BFloat16.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/Half.h> |
6 | #include <c10/util/complex.h> |
7 | #include <c10/util/qint32.h> |
8 | #include <c10/util/qint8.h> |
9 | #include <c10/util/quint2x4.h> |
10 | #include <c10/util/quint4x2.h> |
11 | #include <c10/util/quint8.h> |
12 | |
13 | #include <complex> |
14 | #include <cstdint> |
15 | #include <ostream> |
16 | |
17 | namespace c10 { |
18 | |
// For the macros below:
// NB: If you want to macro some code for all non-QInt scalar types (i.e. types
// with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
// macros below, which are designed to behave similarly to the Dispatch macros
// with the same name.
25 | |
// Applies `_(cpp_type, Name)` once per scalar type: the plain numeric types,
// the three complex types, Bool, BFloat16 and the quantized types. The
// trailing number on each entry is its position in the list; ScalarType's
// enumerators are generated from this list in the same order.
//
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
  _(uint8_t, Byte) /* 0 */                               \
  _(int8_t, Char) /* 1 */                                \
  _(int16_t, Short) /* 2 */                              \
  _(int, Int) /* 3 */                                    \
  _(int64_t, Long) /* 4 */                               \
  _(at::Half, Half) /* 5 */                              \
  _(float, Float) /* 6 */                                \
  _(double, Double) /* 7 */                              \
  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */        \
  _(c10::complex<float>, ComplexFloat) /* 9 */           \
  _(c10::complex<double>, ComplexDouble) /* 10 */        \
  _(bool, Bool) /* 11 */                                 \
  _(c10::qint8, QInt8) /* 12 */                          \
  _(c10::quint8, QUInt8) /* 13 */                        \
  _(c10::qint32, QInt32) /* 14 */                        \
  _(at::BFloat16, BFloat16) /* 15 */                     \
  _(c10::quint4x2, QUInt4x2) /* 16 */                    \
  _(c10::quint2x4, QUInt2x4) /* 17 */
47 | |
// Applies `_(cpp_type, Name)` to every non-quantized scalar type except
// ComplexHalf.
//
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
  _(uint8_t, Byte)                                                 \
  _(int8_t, Char)                                                  \
  _(int16_t, Short)                                                \
  _(int, Int)                                                      \
  _(int64_t, Long)                                                 \
  _(at::Half, Half)                                                \
  _(float, Float)                                                  \
  _(double, Double)                                                \
  _(c10::complex<float>, ComplexFloat)                             \
  _(c10::complex<double>, ComplexDouble)                           \
  _(bool, Bool)                                                    \
  _(at::BFloat16, BFloat16)
64 | |
// Applies `_(cpp_type, Name)` to every non-quantized scalar type, ComplexHalf
// included. This is the full list minus the QInt entries.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
  _(uint8_t, Byte)                             \
  _(int8_t, Char)                              \
  _(int16_t, Short)                            \
  _(int, Int)                                  \
  _(int64_t, Long)                             \
  _(at::Half, Half)                            \
  _(float, Float)                              \
  _(double, Double)                            \
  _(c10::complex<c10::Half>, ComplexHalf)      \
  _(c10::complex<float>, ComplexFloat)         \
  _(c10::complex<double>, ComplexDouble)       \
  _(bool, Bool)                                \
  _(at::BFloat16, BFloat16)
79 | |
80 | enum class ScalarType : int8_t { |
81 | #define DEFINE_ENUM(_1, n) n, |
82 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM) |
83 | #undef DEFINE_ENUM |
84 | Undefined, |
85 | NumOptions |
86 | }; |
87 | |
88 | constexpr uint16_t NumScalarTypes = |
89 | static_cast<uint16_t>(ScalarType::NumOptions); |
90 | |
91 | namespace impl { |
92 | |
93 | // These are used to map ScalarTypes to C++ types. |
94 | |
95 | template <c10::ScalarType N> |
96 | struct ScalarTypeToCPPType; |
97 | |
98 | #define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ |
99 | template <> \ |
100 | struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \ |
101 | using type = cpp_type; \ |
102 | \ |
103 | /* This is a workaround for the CUDA bug which prevents */ \ |
104 | /* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \ |
105 | /* ambiguous reference which can't to be resolved. For some reason it */ \ |
106 | /* can't pick between at::detail and at::cuda::detail. */ \ |
107 | /* For repro example, please see: */ \ |
108 | /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \ |
109 | /* TODO: remove once the bug is fixed. */ \ |
110 | static type t; \ |
111 | }; |
112 | |
113 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) |
114 | |
115 | #undef SPECIALIZE_ScalarTypeToCPPType |
116 | |
117 | template <c10::ScalarType N> |
118 | using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type; |
119 | |
120 | } // namespace impl |
121 | |
122 | template <typename T> |
123 | struct CppTypeToScalarType; |
124 | |
125 | #define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \ |
126 | template <> \ |
127 | struct CppTypeToScalarType<cpp_type> \ |
128 | : std:: \ |
129 | integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \ |
130 | }; |
131 | |
132 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) |
133 | |
134 | #undef SPECIALIZE_CppTypeToScalarType |
135 | |
// Applies `_(cpp_type, Name)` to the five integer types (Bool is not
// included).
#define AT_FORALL_INT_TYPES(_) \
  _(uint8_t, Byte)             \
  _(int8_t, Char)              \
  _(int16_t, Short)            \
  _(int, Int)                  \
  _(int64_t, Long)
142 | |
// Applies `_(cpp_type, Name)` to the five integer types plus Float and
// Double. Half, BFloat16, Bool, complex and quantized types are not included.
#define AT_FORALL_SCALAR_TYPES(_) \
  _(uint8_t, Byte)                \
  _(int8_t, Char)                 \
  _(int16_t, Short)               \
  _(int, Int)                     \
  _(int64_t, Long)                \
  _(float, Float)                 \
  _(double, Double)
151 | |
// Same entries as AT_FORALL_SCALAR_TYPES, followed by one extra entry for
// SCALARTYPE (a bare ScalarType enumerator name such as Half). The extra
// entry's C++ type is recovered with decltype on the static member `t` of
// impl::ScalarTypeToCPPType.
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
  _(uint8_t, Byte)                                \
  _(int8_t, Char)                                 \
  _(int16_t, Short)                               \
  _(int, Int)                                     \
  _(int64_t, Long)                                \
  _(float, Float)                                 \
  _(double, Double)                               \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE>::t),  \
    SCALARTYPE)
163 | |
// Same entries as AT_FORALL_SCALAR_TYPES, followed by extra entries for
// SCALARTYPE1 and SCALARTYPE2 (bare ScalarType enumerator names), in that
// order. Each extra C++ type is recovered with decltype on the static member
// `t` of impl::ScalarTypeToCPPType.
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
  _(uint8_t, Byte)                                               \
  _(int8_t, Char)                                                \
  _(int16_t, Short)                                              \
  _(int, Int)                                                    \
  _(int64_t, Long)                                               \
  _(float, Float)                                                \
  _(double, Double)                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                   \
             ::c10::ScalarType::SCALARTYPE1>::t),                \
    SCALARTYPE1)                                                 \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                   \
             ::c10::ScalarType::SCALARTYPE2>::t),                \
    SCALARTYPE2)
178 | |
// Same entries as AT_FORALL_SCALAR_TYPES, followed by extra entries for
// SCALARTYPE1, SCALARTYPE2 and SCALARTYPE3 (bare ScalarType enumerator
// names), in that order. Each extra C++ type is recovered with decltype on
// the static member `t` of impl::ScalarTypeToCPPType.
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
  _(uint8_t, Byte)                                                            \
  _(int8_t, Char)                                                             \
  _(int16_t, Short)                                                           \
  _(int, Int)                                                                 \
  _(int64_t, Long)                                                            \
  _(float, Float)                                                             \
  _(double, Double)                                                           \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE1>::t),                             \
    SCALARTYPE1)                                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE2>::t),                             \
    SCALARTYPE2)                                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE3>::t),                             \
    SCALARTYPE3)
196 | |
// Applies `_(cpp_type, Name)` to the five quantized types. The set matches
// the types isQIntType() reports as quantized.
#define AT_FORALL_QINT_TYPES(_) \
  _(c10::qint8, QInt8)          \
  _(c10::quint8, QUInt8)        \
  _(c10::qint32, QInt32)        \
  _(c10::quint4x2, QUInt4x2)    \
  _(c10::quint2x4, QUInt2x4)
203 | |
// Applies `_(cpp_type, Name)` to ComplexFloat and ComplexDouble. ComplexHalf
// is not part of this list.
#define AT_FORALL_COMPLEX_TYPES(_)     \
  _(c10::complex<float>, ComplexFloat) \
  _(c10::complex<double>, ComplexDouble)
207 | |
208 | #define DEFINE_CONSTANT(_, name) \ |
209 | constexpr ScalarType k##name = ScalarType::name; |
210 | |
211 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) |
212 | #undef DEFINE_CONSTANT |
213 | |
214 | static inline const char* toString(ScalarType t) { |
215 | #define DEFINE_CASE(_, name) \ |
216 | case ScalarType::name: \ |
217 | return #name; |
218 | |
219 | switch (t) { |
220 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) |
221 | default: |
222 | return "UNKNOWN_SCALAR" ; |
223 | } |
224 | #undef DEFINE_CASE |
225 | } |
226 | |
227 | static inline size_t elementSize(ScalarType t) { |
228 | #define CASE_ELEMENTSIZE_CASE(ctype, name) \ |
229 | case ScalarType::name: \ |
230 | return sizeof(ctype); |
231 | |
232 | switch (t) { |
233 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) |
234 | default: |
235 | TORCH_CHECK(false, "Unknown ScalarType" ); |
236 | } |
237 | #undef CASE_ELEMENTSIZE_CASE |
238 | } |
239 | |
240 | static inline bool isIntegralType(ScalarType t, bool includeBool) { |
241 | bool isIntegral = |
242 | (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || |
243 | t == ScalarType::Long || t == ScalarType::Short); |
244 | |
245 | return isIntegral || (includeBool && t == ScalarType::Bool); |
246 | } |
247 | |
248 | C10_DEPRECATED_MESSAGE( |
249 | "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead." ) |
250 | static inline bool isIntegralType(ScalarType t) { |
251 | return isIntegralType(t, /*includeBool=*/false); |
252 | } |
253 | |
254 | static inline bool isFloatingType(ScalarType t) { |
255 | return ( |
256 | t == ScalarType::Double || t == ScalarType::Float || |
257 | t == ScalarType::Half || t == ScalarType::BFloat16); |
258 | } |
259 | |
260 | static inline bool isComplexType(ScalarType t) { |
261 | return ( |
262 | t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || |
263 | t == ScalarType::ComplexDouble); |
264 | } |
265 | |
266 | static inline bool isQIntType(ScalarType t) { |
267 | // Don't forget to extend this when adding new QInt types |
268 | return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || |
269 | t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || |
270 | t == ScalarType::QUInt2x4; |
271 | } |
272 | |
273 | static inline ScalarType toQIntType(ScalarType t) { |
274 | switch (t) { |
275 | case ScalarType::Byte: |
276 | return ScalarType::QUInt8; |
277 | case ScalarType::Char: |
278 | return ScalarType::QInt8; |
279 | case ScalarType::Int: |
280 | return ScalarType::QInt32; |
281 | default: |
282 | return t; |
283 | } |
284 | } |
285 | |
286 | static inline ScalarType toUnderlying(ScalarType t) { |
287 | switch (t) { |
288 | case ScalarType::QUInt8: |
289 | return ScalarType::Byte; |
290 | case ScalarType::QInt8: |
291 | return ScalarType::Char; |
292 | case ScalarType::QInt32: |
293 | return ScalarType::Int; |
294 | case ScalarType::QUInt4x2: |
295 | return ScalarType::Byte; |
296 | case ScalarType::QUInt2x4: |
297 | return ScalarType::Byte; |
298 | default: |
299 | return t; |
300 | } |
301 | } |
302 | |
303 | static inline bool isSignedType(ScalarType t) { |
304 | TORCH_CHECK(!isQIntType(t), "isSignedType not supported for quantized types" ); |
305 | #define CASE_SIGNED(ctype, name) \ |
306 | case ScalarType::name: \ |
307 | return std::numeric_limits<ctype>::is_signed; |
308 | |
309 | switch (t) { |
310 | case ScalarType::ComplexHalf: |
311 | case ScalarType::ComplexFloat: |
312 | case ScalarType::ComplexDouble: |
313 | return true; |
314 | AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED) |
315 | default: |
316 | TORCH_CHECK(false, "Unknown ScalarType" ); |
317 | } |
318 | #undef CASE_SIGNED |
319 | } |
320 | |
321 | static inline bool isUnderlying(ScalarType type, ScalarType qtype) { |
322 | return type == toUnderlying(qtype); |
323 | } |
324 | |
325 | static inline ScalarType toRealValueType(ScalarType t) { |
326 | switch (t) { |
327 | case ScalarType::ComplexHalf: |
328 | return ScalarType::Half; |
329 | case ScalarType::ComplexFloat: |
330 | return ScalarType::Float; |
331 | case ScalarType::ComplexDouble: |
332 | return ScalarType::Double; |
333 | default: |
334 | return t; |
335 | } |
336 | } |
337 | |
338 | static inline ScalarType toComplexType(ScalarType t) { |
339 | switch (t) { |
340 | case ScalarType::BFloat16: |
341 | // BFloat16 has range equivalent to Float, |
342 | // so we map it to ComplexFloat. |
343 | return ScalarType::ComplexFloat; |
344 | case ScalarType::Half: |
345 | return ScalarType::ComplexHalf; |
346 | case ScalarType::Float: |
347 | return ScalarType::ComplexFloat; |
348 | case ScalarType::Double: |
349 | return ScalarType::ComplexDouble; |
350 | case ScalarType::ComplexHalf: |
351 | return ScalarType::ComplexHalf; |
352 | case ScalarType::ComplexFloat: |
353 | return ScalarType::ComplexFloat; |
354 | case ScalarType::ComplexDouble: |
355 | return ScalarType::ComplexDouble; |
356 | default: |
357 | TORCH_CHECK(false, "Unknown Complex ScalarType for " , t); |
358 | } |
359 | } |
360 | |
361 | // see tensor_attributes.rst for detailed explanation and examples |
362 | // of casting rules. |
363 | static inline bool canCast(const ScalarType from, const ScalarType to) { |
364 | // We disallow complex -> non complex, e.g., float_tensor *= complex is |
365 | // disallowed. |
366 | if (isComplexType(from) && !isComplexType(to)) { |
367 | return false; |
368 | } |
369 | // We disallow float -> integral, e.g., int_tensor *= float is disallowed. |
370 | if (isFloatingType(from) && isIntegralType(to, false)) { |
371 | return false; |
372 | } |
373 | |
374 | // Treat bool as a distinct "category," to be consistent with type promotion |
375 | // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same |
376 | // category as `bool_tensor`, we would not promote. Differing categories |
377 | // implies `bool_tensor += 5` is disallowed. |
378 | // |
379 | // NB: numpy distinguishes "unsigned" as a category to get the desired |
380 | // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because: |
381 | // * We don't want the performance hit of checking the runtime sign of |
382 | // Scalars. |
383 | // * `uint8_tensor + 5 -> int64_tensor` would be undesirable. |
384 | if (from != ScalarType::Bool && to == ScalarType::Bool) { |
385 | return false; |
386 | } |
387 | return true; |
388 | } |
389 | |
390 | static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { |
391 | // This is generated according to NumPy's promote_types |
392 | constexpr auto u1 = ScalarType::Byte; |
393 | constexpr auto i1 = ScalarType::Char; |
394 | constexpr auto i2 = ScalarType::Short; |
395 | constexpr auto i4 = ScalarType::Int; |
396 | constexpr auto i8 = ScalarType::Long; |
397 | constexpr auto f2 = ScalarType::Half; |
398 | constexpr auto f4 = ScalarType::Float; |
399 | constexpr auto f8 = ScalarType::Double; |
400 | constexpr auto c2 = ScalarType::ComplexHalf; |
401 | constexpr auto c4 = ScalarType::ComplexFloat; |
402 | constexpr auto c8 = ScalarType::ComplexDouble; |
403 | constexpr auto b1 = ScalarType::Bool; |
404 | constexpr auto bf = ScalarType::BFloat16; |
405 | constexpr auto ud = ScalarType::Undefined; |
406 | if (a == ud || b == ud) { |
407 | return ScalarType::Undefined; |
408 | } |
409 | |
410 | // For QInt types, we only allow exact match |
411 | if (isQIntType(a) && a == b) { |
412 | return a; |
413 | } |
414 | |
415 | if (isQIntType(a) || isQIntType(b)) { |
416 | TORCH_CHECK( |
417 | false, |
418 | "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: " , |
419 | toString(a), |
420 | " " , |
421 | toString(b)); |
422 | } |
423 | |
424 | // this matrix has to be consistent with |
425 | // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we |
426 | // are not sure about the correct value for type promotion. |
427 | static constexpr ScalarType _promoteTypesLookup[static_cast<int>( |
428 | ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = { |
429 | /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/ |
430 | /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf}, |
431 | /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf}, |
432 | /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf}, |
433 | /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf}, |
434 | /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf}, |
435 | /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4}, |
436 | /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4}, |
437 | /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8}, |
438 | /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4}, |
439 | /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4}, |
440 | /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8}, |
441 | /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf}, |
442 | /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, |
443 | /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, |
444 | /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, |
445 | /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf}, |
446 | }; |
447 | return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)]; |
448 | } |
449 | |
450 | inline std::ostream& operator<<( |
451 | std::ostream& stream, |
452 | at::ScalarType scalar_type) { |
453 | return stream << toString(scalar_type); |
454 | } |
455 | |
// Applies `_(tag, Name)` to Half and BFloat16. Unlike the lists above, the
// first argument here is a lowercase tag (`half`, `bfloat16`), not a C++ type
// declared in this header; its meaning is up to the macro passed as `_`.
// NOTE(review): presumably the dtypes supported by autocast - confirm against
// the users of this macro.
#define AT_FORAUTOCAST_SCALAR_TYPES(_) \
  _(half, Half) /* 0 */                \
  _(bfloat16, BFloat16) /* 1 */
459 | |
460 | } // namespace c10 |
461 | |