1#pragma once
2
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint2x4.h>
#include <c10/util/quint4x2.h>
#include <c10/util/quint8.h>

#include <complex>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>
16
17namespace c10 {
18
// For the macros below:
// NB: If you want to macro some code for all non-QInt scalar types (i.e. types
// with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
25
// NB: Order matters for this macro. Each entry's position becomes its
// ScalarType enumerator value (shown in the trailing comments), and those
// values are relied upon by _promoteTypesLookup and the serialization format,
// so existing entries must not be reordered.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_)  \
  _(uint8_t, Byte)                       /* [0]  */      \
  _(int8_t, Char)                        /* [1]  */      \
  _(int16_t, Short)                      /* [2]  */      \
  _(int, Int)                            /* [3]  */      \
  _(int64_t, Long)                       /* [4]  */      \
  _(at::Half, Half)                      /* [5]  */      \
  _(float, Float)                        /* [6]  */      \
  _(double, Double)                      /* [7]  */      \
  _(c10::complex<c10::Half>, ComplexHalf) /* [8]  */     \
  _(c10::complex<float>, ComplexFloat)   /* [9]  */      \
  _(c10::complex<double>, ComplexDouble) /* [10] */      \
  _(bool, Bool)                          /* [11] */      \
  _(c10::qint8, QInt8)                   /* [12] */      \
  _(c10::quint8, QUInt8)                 /* [13] */      \
  _(c10::qint32, QInt32)                 /* [14] */      \
  _(at::BFloat16, BFloat16)              /* [15] */      \
  _(c10::quint4x2, QUInt4x2)             /* [16] */      \
  _(c10::quint2x4, QUInt2x4)             /* [17] */
47
// The entries of AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, in the same order, with
// ComplexHalf left out.
// If you want to support ComplexHalf for real, add ComplexHalf into this
// macro (and change the name). But beware: convert() doesn't work for all the
// conversions you need...
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
  _(uint8_t, Byte)                                                \
  _(int8_t, Char)                                                 \
  _(int16_t, Short)                                               \
  _(int, Int)                                                     \
  _(int64_t, Long)                                                \
  _(at::Half, Half)                                               \
  _(float, Float)                                                 \
  _(double, Double)                                               \
  _(c10::complex<float>, ComplexFloat)                            \
  _(c10::complex<double>, ComplexDouble)                          \
  _(bool, Bool)                                                   \
  _(at::BFloat16, BFloat16)
64
// Every non-quantized scalar type, including all three complex types, in
// enum order.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_)   \
  _(uint8_t, Byte)                              \
  _(int8_t, Char)                               \
  _(int16_t, Short)                             \
  _(int, Int)                                   \
  _(int64_t, Long)                              \
  _(at::Half, Half)                             \
  _(float, Float)                               \
  _(double, Double)                             \
  _(c10::complex<c10::Half>, ComplexHalf)       \
  _(c10::complex<float>, ComplexFloat)          \
  _(c10::complex<double>, ComplexDouble)        \
  _(bool, Bool)                                 \
  _(at::BFloat16, BFloat16)
79
80enum class ScalarType : int8_t {
81#define DEFINE_ENUM(_1, n) n,
82 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM)
83#undef DEFINE_ENUM
84 Undefined,
85 NumOptions
86};
87
88constexpr uint16_t NumScalarTypes =
89 static_cast<uint16_t>(ScalarType::NumOptions);
90
91namespace impl {
92
93// These are used to map ScalarTypes to C++ types.
94
95template <c10::ScalarType N>
96struct ScalarTypeToCPPType;
97
98#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
99 template <> \
100 struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
101 using type = cpp_type; \
102 \
103 /* This is a workaround for the CUDA bug which prevents */ \
104 /* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
105 /* ambiguous reference which can't to be resolved. For some reason it */ \
106 /* can't pick between at::detail and at::cuda::detail. */ \
107 /* For repro example, please see: */ \
108 /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
109 /* TODO: remove once the bug is fixed. */ \
110 static type t; \
111 };
112
113AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
114
115#undef SPECIALIZE_ScalarTypeToCPPType
116
117template <c10::ScalarType N>
118using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
119
120} // namespace impl
121
122template <typename T>
123struct CppTypeToScalarType;
124
125#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
126 template <> \
127 struct CppTypeToScalarType<cpp_type> \
128 : std:: \
129 integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
130 };
131
132AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
133
134#undef SPECIALIZE_CppTypeToScalarType
135
// The non-bool integer types, in enum order: Byte, Char, Short, Int, Long.
#define AT_FORALL_INT_TYPES(_)  \
  _(uint8_t, Byte)             \
  _(int8_t, Char)              \
  _(int16_t, Short)            \
  _(int, Int)                  \
  _(int64_t, Long)
142
// The "plain" scalar types: the five integer types followed by Float and
// Double. Half, BFloat16, Bool, complex and quantized types are excluded.
#define AT_FORALL_SCALAR_TYPES(_)  \
  _(uint8_t, Byte)                \
  _(int8_t, Char)                 \
  _(int16_t, Short)               \
  _(int, Int)                     \
  _(int64_t, Long)                \
  _(float, Float)                 \
  _(double, Double)
151
// The seven AT_FORALL_SCALAR_TYPES entries followed by one extra ScalarType,
// given by enumerator name. Its C++ type is recovered through the
// impl::ScalarTypeToCPPType<...>::t member.
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _)                    \
  _(uint8_t, Byte)                                                  \
  _(int8_t, Char)                                                   \
  _(int16_t, Short)                                                 \
  _(int, Int)                                                       \
  _(int64_t, Long)                                                  \
  _(float, Float)                                                   \
  _(double, Double)                                                 \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                      \
             ::c10::ScalarType::SCALARTYPE>::t),                    \
    SCALARTYPE)
163
// The seven AT_FORALL_SCALAR_TYPES entries followed by two extra
// ScalarTypes, in argument order. Their C++ types are recovered through the
// impl::ScalarTypeToCPPType<...>::t member.
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _)     \
  _(uint8_t, Byte)                                                  \
  _(int8_t, Char)                                                   \
  _(int16_t, Short)                                                 \
  _(int, Int)                                                       \
  _(int64_t, Long)                                                  \
  _(float, Float)                                                   \
  _(double, Double)                                                 \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                      \
             ::c10::ScalarType::SCALARTYPE1>::t),                   \
    SCALARTYPE1)                                                    \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                      \
             ::c10::ScalarType::SCALARTYPE2>::t),                   \
    SCALARTYPE2)
178
// The seven AT_FORALL_SCALAR_TYPES entries followed by three extra
// ScalarTypes, in argument order. Their C++ types are recovered through the
// impl::ScalarTypeToCPPType<...>::t member.
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
  _(uint8_t, Byte)                                                           \
  _(int8_t, Char)                                                            \
  _(int16_t, Short)                                                          \
  _(int, Int)                                                                \
  _(int64_t, Long)                                                           \
  _(float, Float)                                                            \
  _(double, Double)                                                          \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                               \
             ::c10::ScalarType::SCALARTYPE1>::t),                            \
    SCALARTYPE1)                                                             \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                               \
             ::c10::ScalarType::SCALARTYPE2>::t),                            \
    SCALARTYPE2)                                                             \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                               \
             ::c10::ScalarType::SCALARTYPE3>::t),                            \
    SCALARTYPE3)
196
// Every quantized type, in enum order.
#define AT_FORALL_QINT_TYPES(_)  \
  _(c10::qint8, QInt8)          \
  _(c10::quint8, QUInt8)        \
  _(c10::qint32, QInt32)        \
  _(c10::quint4x2, QUInt4x2)    \
  _(c10::quint2x4, QUInt2x4)
203
// The single- and double-precision complex types (ComplexHalf is excluded).
#define AT_FORALL_COMPLEX_TYPES(_)       \
  _(c10::complex<float>, ComplexFloat)  \
  _(c10::complex<double>, ComplexDouble)
207
208#define DEFINE_CONSTANT(_, name) \
209 constexpr ScalarType k##name = ScalarType::name;
210
211AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
212#undef DEFINE_CONSTANT
213
214static inline const char* toString(ScalarType t) {
215#define DEFINE_CASE(_, name) \
216 case ScalarType::name: \
217 return #name;
218
219 switch (t) {
220 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
221 default:
222 return "UNKNOWN_SCALAR";
223 }
224#undef DEFINE_CASE
225}
226
227static inline size_t elementSize(ScalarType t) {
228#define CASE_ELEMENTSIZE_CASE(ctype, name) \
229 case ScalarType::name: \
230 return sizeof(ctype);
231
232 switch (t) {
233 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
234 default:
235 TORCH_CHECK(false, "Unknown ScalarType");
236 }
237#undef CASE_ELEMENTSIZE_CASE
238}
239
240static inline bool isIntegralType(ScalarType t, bool includeBool) {
241 bool isIntegral =
242 (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
243 t == ScalarType::Long || t == ScalarType::Short);
244
245 return isIntegral || (includeBool && t == ScalarType::Bool);
246}
247
248C10_DEPRECATED_MESSAGE(
249 "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
250static inline bool isIntegralType(ScalarType t) {
251 return isIntegralType(t, /*includeBool=*/false);
252}
253
254static inline bool isFloatingType(ScalarType t) {
255 return (
256 t == ScalarType::Double || t == ScalarType::Float ||
257 t == ScalarType::Half || t == ScalarType::BFloat16);
258}
259
260static inline bool isComplexType(ScalarType t) {
261 return (
262 t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
263 t == ScalarType::ComplexDouble);
264}
265
266static inline bool isQIntType(ScalarType t) {
267 // Don't forget to extend this when adding new QInt types
268 return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
269 t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
270 t == ScalarType::QUInt2x4;
271}
272
273static inline ScalarType toQIntType(ScalarType t) {
274 switch (t) {
275 case ScalarType::Byte:
276 return ScalarType::QUInt8;
277 case ScalarType::Char:
278 return ScalarType::QInt8;
279 case ScalarType::Int:
280 return ScalarType::QInt32;
281 default:
282 return t;
283 }
284}
285
286static inline ScalarType toUnderlying(ScalarType t) {
287 switch (t) {
288 case ScalarType::QUInt8:
289 return ScalarType::Byte;
290 case ScalarType::QInt8:
291 return ScalarType::Char;
292 case ScalarType::QInt32:
293 return ScalarType::Int;
294 case ScalarType::QUInt4x2:
295 return ScalarType::Byte;
296 case ScalarType::QUInt2x4:
297 return ScalarType::Byte;
298 default:
299 return t;
300 }
301}
302
303static inline bool isSignedType(ScalarType t) {
304 TORCH_CHECK(!isQIntType(t), "isSignedType not supported for quantized types");
305#define CASE_SIGNED(ctype, name) \
306 case ScalarType::name: \
307 return std::numeric_limits<ctype>::is_signed;
308
309 switch (t) {
310 case ScalarType::ComplexHalf:
311 case ScalarType::ComplexFloat:
312 case ScalarType::ComplexDouble:
313 return true;
314 AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
315 default:
316 TORCH_CHECK(false, "Unknown ScalarType");
317 }
318#undef CASE_SIGNED
319}
320
321static inline bool isUnderlying(ScalarType type, ScalarType qtype) {
322 return type == toUnderlying(qtype);
323}
324
325static inline ScalarType toRealValueType(ScalarType t) {
326 switch (t) {
327 case ScalarType::ComplexHalf:
328 return ScalarType::Half;
329 case ScalarType::ComplexFloat:
330 return ScalarType::Float;
331 case ScalarType::ComplexDouble:
332 return ScalarType::Double;
333 default:
334 return t;
335 }
336}
337
338static inline ScalarType toComplexType(ScalarType t) {
339 switch (t) {
340 case ScalarType::BFloat16:
341 // BFloat16 has range equivalent to Float,
342 // so we map it to ComplexFloat.
343 return ScalarType::ComplexFloat;
344 case ScalarType::Half:
345 return ScalarType::ComplexHalf;
346 case ScalarType::Float:
347 return ScalarType::ComplexFloat;
348 case ScalarType::Double:
349 return ScalarType::ComplexDouble;
350 case ScalarType::ComplexHalf:
351 return ScalarType::ComplexHalf;
352 case ScalarType::ComplexFloat:
353 return ScalarType::ComplexFloat;
354 case ScalarType::ComplexDouble:
355 return ScalarType::ComplexDouble;
356 default:
357 TORCH_CHECK(false, "Unknown Complex ScalarType for ", t);
358 }
359}
360
361// see tensor_attributes.rst for detailed explanation and examples
362// of casting rules.
363static inline bool canCast(const ScalarType from, const ScalarType to) {
364 // We disallow complex -> non complex, e.g., float_tensor *= complex is
365 // disallowed.
366 if (isComplexType(from) && !isComplexType(to)) {
367 return false;
368 }
369 // We disallow float -> integral, e.g., int_tensor *= float is disallowed.
370 if (isFloatingType(from) && isIntegralType(to, false)) {
371 return false;
372 }
373
374 // Treat bool as a distinct "category," to be consistent with type promotion
375 // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same
376 // category as `bool_tensor`, we would not promote. Differing categories
377 // implies `bool_tensor += 5` is disallowed.
378 //
379 // NB: numpy distinguishes "unsigned" as a category to get the desired
380 // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because:
381 // * We don't want the performance hit of checking the runtime sign of
382 // Scalars.
383 // * `uint8_tensor + 5 -> int64_tensor` would be undesirable.
384 if (from != ScalarType::Bool && to == ScalarType::Bool) {
385 return false;
386 }
387 return true;
388}
389
390static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
391 // This is generated according to NumPy's promote_types
392 constexpr auto u1 = ScalarType::Byte;
393 constexpr auto i1 = ScalarType::Char;
394 constexpr auto i2 = ScalarType::Short;
395 constexpr auto i4 = ScalarType::Int;
396 constexpr auto i8 = ScalarType::Long;
397 constexpr auto f2 = ScalarType::Half;
398 constexpr auto f4 = ScalarType::Float;
399 constexpr auto f8 = ScalarType::Double;
400 constexpr auto c2 = ScalarType::ComplexHalf;
401 constexpr auto c4 = ScalarType::ComplexFloat;
402 constexpr auto c8 = ScalarType::ComplexDouble;
403 constexpr auto b1 = ScalarType::Bool;
404 constexpr auto bf = ScalarType::BFloat16;
405 constexpr auto ud = ScalarType::Undefined;
406 if (a == ud || b == ud) {
407 return ScalarType::Undefined;
408 }
409
410 // For QInt types, we only allow exact match
411 if (isQIntType(a) && a == b) {
412 return a;
413 }
414
415 if (isQIntType(a) || isQIntType(b)) {
416 TORCH_CHECK(
417 false,
418 "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: ",
419 toString(a),
420 " ",
421 toString(b));
422 }
423
424 // this matrix has to be consistent with
425 // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we
426 // are not sure about the correct value for type promotion.
427 static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
428 ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
429 /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
430 /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
431 /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
432 /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
433 /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
434 /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
435 /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
436 /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
437 /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
438 /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
439 /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
440 /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
441 /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
442 /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
443 /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
444 /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
445 /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
446 };
447 return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
448}
449
450inline std::ostream& operator<<(
451 std::ostream& stream,
452 at::ScalarType scalar_type) {
453 return stream << toString(scalar_type);
454}
455
// Half and BFloat16, each paired with a lower-case spelling as its first
// argument; the trailing comments give each entry's position.
#define AT_FORAUTOCAST_SCALAR_TYPES(_)  \
  _(half, Half)         /* [0] */      \
  _(bfloat16, BFloat16) /* [1] */
459
460} // namespace c10
461