1#pragma once
2
3#include <ATen/core/DeprecatedTypeProperties.h>
4#include <c10/macros/Macros.h>
5#include <c10/util/Exception.h>
6#include <c10/util/Half.h>
7#include <c10/util/Metaprogramming.h>
8#include <c10/util/complex.h>
9#include <c10/util/string_view.h>
10
11#ifdef __CUDACC__
12#include <cuda.h> // For CUDA_VERSION
13#endif
14
15#ifdef TEMPLATE_SELECTIVE_BUILD
16#include <ATen/selected_mobile_ops.h>
17#else
18namespace at {
19/**
20 * The method should_include_kernel_dtype() returns true/false
21 * based on whether the switching code for a specific dtype should be
22 * included based on build time constants generated from tracing model
23 * execution. This method will be implmeneted via code-generation and
24 * included in this file when code-gen is ready.
25 */
26inline constexpr bool should_include_kernel_dtype(
27 const char* /*kernel_tag_str*/,
28 at::ScalarType /*scalar_type*/
29) {
30 return true;
31}
32} // namespace at
33#endif
34
35/**
36 * In the Facebook internal build (using BUCK), this macro is enabled by
37 * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
38 * binary.
39 */
#ifdef ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
namespace at {
namespace detail {
// Hook called by RECORD_KERNEL_FUNCTION_DTYPE with the string
// "<kernel tag>$<dtype name>". Declared here; defined elsewhere.
TORCH_API void record_kernel_function_dtype(std::string name);
} // namespace detail
} // namespace at

// Reports the (NAME, enum_type) pair to the recording hook.
// NB: the expansion already ends in ';', so `MACRO(...);` yields a harmless
// extra empty statement. The semicolon is kept for existing call sites.
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
  at::detail::record_kernel_function_dtype(           \
      std::string(NAME) + "$" + toString(enum_type));
#else
// Recording disabled: the macro expands to nothing.
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
#endif
53
// Avoid if_constexpr if possible, as it's more expensive to compile
// AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type)
//
// Used inside a dispatch case. If at::should_include_kernel_dtype() reports
// that `enum_type` was not selected for the enclosing `at_dispatch_name` tag,
// the case raises AT_ERROR. `at_dispatch_name` must be in scope at the point
// of expansion (AT_DISPATCH_SWITCH declares it).
#ifndef __cpp_if_constexpr
// No language-level `if constexpr`: fall back to the library helper.
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type)        \
  at::guts::if_constexpr<!at::should_include_kernel_dtype( \
      at_dispatch_name, enum_type)>([&] {                  \
    AT_ERROR(                                              \
        "dtype '",                                         \
        toString(enum_type),                               \
        "' not selected for kernel tag ",                  \
        at_dispatch_name);                                 \
  })
#else // defined __cpp_if_constexpr
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type)  \
  do {                                               \
    if constexpr (!at::should_include_kernel_dtype(  \
                      at_dispatch_name, enum_type)) { \
      AT_ERROR(                                      \
          "dtype '",                                 \
          toString(enum_type),                       \
          "' not selected for kernel tag ",          \
          at_dispatch_name);                         \
    }                                                \
  } while (0)
#endif
78
// Emits one `case` label for a dispatch switch:
//   1. runs the selective-build check for SCALARTYPE,
//   2. introduces ALIAS as the C++ type mapped from SCALARTYPE,
//   3. returns the result of invoking the trailing callable with no arguments.
#define AT_PRIVATE_CASE_TYPE_USING_HINT(SCALARTYPE, ALIAS, ...)             \
  case SCALARTYPE: {                                                        \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(SCALARTYPE);                           \
    using ALIAS C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<SCALARTYPE>;   \
    return __VA_ARGS__();                                                   \
  }
85
// One dispatch case for SCALARTYPE; the callable sees the mapped C++ type
// under the alias `scalar_t`.
#define AT_DISPATCH_CASE(SCALARTYPE, ...) \
  AT_PRIVATE_CASE_TYPE_USING_HINT(SCALARTYPE, scalar_t, __VA_ARGS__)
88
// One dispatch case for a quantized dtype. Inside the callable:
//   scalar_t        -> QTYPE
//   underlying_t    -> QTYPE::underlying
//   SCALAR_TYPE     -> QTYPE_ENUM
//   UNDERLYING_TYPE -> toUnderlying(QTYPE_ENUM)
#define AT_DISPATCH_CASE_QINT(QTYPE_ENUM, QTYPE, ...)                    \
  case QTYPE_ENUM: {                                                     \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(QTYPE_ENUM);                        \
    using scalar_t = QTYPE;                                              \
    using underlying_t C10_UNUSED = typename scalar_t::underlying;       \
    const auto& SCALAR_TYPE C10_UNUSED = QTYPE_ENUM;                     \
    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(QTYPE_ENUM);   \
    return __VA_ARGS__();                                                \
  }
98
// Like AT_DISPATCH_CASE_QINT, but additionally exposes the quantization
// range to the callable through three locals:
//   bit_width (int), quant_min (int64_t), quant_max (int64_t)
// initialized from BITS, QMIN and QMAX respectively.
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                              \
    QTYPE_ENUM, QTYPE, BITS, QMIN, QMAX, ...)                            \
  case QTYPE_ENUM: {                                                     \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(QTYPE_ENUM);                        \
    using scalar_t = QTYPE;                                              \
    using underlying_t C10_UNUSED = typename scalar_t::underlying;       \
    const auto& SCALAR_TYPE C10_UNUSED = QTYPE_ENUM;                     \
    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(QTYPE_ENUM);   \
    C10_UNUSED int bit_width = BITS;                                     \
    C10_UNUSED int64_t quant_min = QMIN;                                 \
    C10_UNUSED int64_t quant_max = QMAX;                                 \
    return __VA_ARGS__();                                                \
  }
112
113namespace detail {
114
115inline at::ScalarType scalar_type(at::ScalarType s) {
116 return s;
117}
118
119C10_DEPRECATED_MESSAGE(
120 "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
121 "pass an at::ScalarType instead")
122inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
123 return t.scalarType();
124}
125
126C10_DEPRECATED_MESSAGE(
127 "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
128 "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
129inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
130
131C10_DEPRECATED_MESSAGE(
132 "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
133 "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
134 "instead")
135inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
136
137} // namespace detail
138
139// The AT_DISPATCH_* family of macros provides the ability to
140// conveniently generate specializations of a kernel over all of the
141// dtypes we care about in PyTorch. We call it "dispatch" because
142// we are "dispatching" to the correct, dtype-specific kernel.
143//
144// A standard usage looks like:
145//
146// AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
147// // Your code here, with 'scalar_t' now defined to
148// // be the dtype in question
149// });
150//
151// There are many variations of this macro, so it's important to
152// understand exactly /which/ dtypes you want to get instantiated, as
153// well as what the "default" set is.
154//
155// The default set of dtypes that are instantiated (e.g., by
156// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
157// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
// but NOT booleans (bool), half-precision floats (Half), or
// complex numbers (c10::complex<float>, c10::complex<double>).
160// This "cut" is somewhat historical (the default types are the
161// ones that TH historically supported), but it also reflects the
162// fact that the non-default types are "poorly" behaved (booleans
163// are NOT integers mod 2, half precision operations ~essentially
164// don't exist on CPU, complex numbers are an experimental application).
165//
166// Here are the questions you should generally ask to decide which
167// dispatch you want:
168//
169// 1. Is this an integral or floating point specific operation?
170// (If so, you'll want one of the FLOATING or INTEGRAL macros.)
171//
172// 2. Should half be supported? (If you're on CPU, the answer is almost
173// definitely no. If you do want support, use one of the AND_HALF
174// macros)
175//
176// Much rarer situations:
177//
// 3. Should bool be supported? (You often have to write your kernel
// differently if arithmetic operations are involved.) If so,
// use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool.
181//
182// 4. Should complex be supported? The answer is almost always no,
183// unless you are working on "generic" code that should work on
184// all dtypes.
185//
186// Parameters:
187// -----------
188//
189// 1. The NAME argument is a "tag" that is used to trace and then
190// conditionally compile fragments of the case statements such
191// that the kernel functions are specialized only for the dtypes
192// that are needed. The NAME parameter *must* be a build time
193// const char* (can't be std::string, etc...)
194//
195// Please ensure that the NAME is unique for every implementation
196// or you run the risk of over-including code for the kernel
197// functions. There is no risk of missing out on any code, so
198// it's mostly a risk of a Type-2 error, and not a Type-1 error.
199//
200// Switch-like syntax:
201// -------------------
202// There is also a switch-case like syntax which is useful if a kernel
203// needs to be specialized for particular scalar types
204//
205// AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
206// AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
207// op_integral<scalar_t>(iter);
208// })
209// AT_DISPATCH_CASE_FLOATING_TYPES([&] {
210// op_floating<scalar_t>(iter);
211// })
212// AT_DISPATCH_CASE(kBool, [&] {
213// op_bool(iter);
214// })
215// );
216//
217// For each AT_DISPATCH_FOO macro, there is a corresponding
218// AT_DISPATCH_CASE_FOO macro which can be used inside of an
219// AT_DISPATCH_SWITCH block.
220
// AT_DISPATCH_SWITCH(TYPE, NAME, <cases>)
//
// Expands to an immediately-invoked `[&]` lambda that switches on the scalar
// type obtained from TYPE and executes the matching case from <cases>.
//
//  - TYPE is evaluated exactly once (bound to `the_type`) and converted with
//    ::detail::scalar_type(); the result is stored in `_st`.
//  - NAME initializes a `constexpr const char*` named `at_dispatch_name`, so
//    it must be a constant expression. The case macros read this variable.
//  - The (NAME, dtype) pair is passed to RECORD_KERNEL_FUNCTION_DTYPE.
//  - If no case matches, AT_ERROR is raised with the message
//    "NAME" not implemented for '<dtype>'.
//
// NB: `the_type` is only needed to compute `_st`; its name has been kept for
// backwards compatibility in case caller code refers to it.

#define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                 \
  [&] {                                                                     \
    const auto& the_type = TYPE;                                            \
    constexpr const char* at_dispatch_name = NAME;                          \
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
    switch (_st) {                                                          \
      __VA_ARGS__                                                           \
      default:                                                              \
        AT_ERROR(                                                           \
            '"',                                                            \
            at_dispatch_name,                                               \
            "\" not implemented for '",                                     \
            toString(_st),                                                  \
            "'");                                                           \
    }                                                                       \
  }()
244
// ---- Floating point families ------------------------------------------------
// Base set: Double, Float. The larger sets are built on top of the smaller
// ones, so the generated cases keep the order: base set, then extras in the
// order they were passed.

// Cases: Double, Float.
#define AT_DISPATCH_CASE_FLOATING_TYPES(...)            \
  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// Cases: Double, Float, Half.
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)        \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                        \
      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))

// Cases: Double, Float, SCALARTYPE.
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                               \
      TYPE,                                                         \
      NAME,                                                         \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// Cases: Double, Float, SCALARTYPE1, SCALARTYPE2.
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE1, __VA_ARGS__)             \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND2(       \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                          \
      TYPE,                                    \
      NAME,                                    \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(    \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
283
// ---- Complex families -------------------------------------------------------

// Cases: ComplexDouble, ComplexFloat.
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...)                    \
  AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)

#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))

// Cases: complex set, then SCALARTYPE.
#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                              \
      TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// ---- Floating + complex families --------------------------------------------
// Each ANDn set is the ANDn-1 set followed by one more case, so extras are
// emitted in the order they were passed.

// Cases: floating set, then complex set.
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)           \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                           \
      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)                \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(    \
    SCALARTYPE, TYPE, NAME, ...)                        \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
          SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(                     \
    SCALARTYPE1, SCALARTYPE2, ...)                                            \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE1, __VA_ARGS__)  \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(    \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)          \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)           \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(       \
      SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)              \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(        \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(     \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
347
// ---- Integral families ------------------------------------------------------

// Cases: Byte, Char, Int, Long, Short.
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)

#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                               \
      TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

// Cases: integral set, then SCALARTYPE.
#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                               \
      TYPE,                                                         \
      NAME,                                                         \
      AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// ---- "All types": the default set -------------------------------------------

// Cases: integral set, then floating set.
#define AT_DISPATCH_CASE_ALL_TYPES(...)        \
  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
374
// ---- Quantized families -----------------------------------------------------

// Cases: kQInt8, kQUInt8.
#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...)               \
  AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
  AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)

#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))

// Cases: kQInt8, kQUInt8, kQInt32 (the byte set followed by kQInt32).
#define AT_DISPATCH_CASE_QINT_TYPES(...)        \
  AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)

#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))

// Cases: kQInt8, kQUInt8, kQInt32, kQUInt4x2, kQUInt2x4. Each case also
// passes a bit width and quantization range (see
// AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE for how they are exposed).
#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...)                      \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
      at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__)  \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
      at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__)        \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
      at::kQInt32,                                                         \
      at::qint32,                                                          \
      CHAR_BIT * sizeof(int),                                              \
      INT_MIN,                                                             \
      INT_MAX,                                                             \
      __VA_ARGS__)                                                         \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
      at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__)                  \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
      at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)

#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                        \
      TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
410
// ---- "All types" plus extras ------------------------------------------------
// Two parallel ladders are defined here:
//   ALL_TYPES_AND, _AND2, _AND3                      (default set + extras)
//   ALL_TYPES_AND_COMPLEX, _AND, _AND2, _AND3, _AND4 (default + complex + extras)
// Every rung is the previous rung followed by one more case, so the generated
// cases are: base set, then the extras in the order they were passed.

// Cases: default set, then complex set.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)           \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                      \
      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))

// Cases: default set, then SCALARTYPE.
#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                          \
      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// Cases: default set, complex set, then SCALARTYPE.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                      \
      TYPE,                                                                \
      NAME,                                                                \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE1, __VA_ARGS__)             \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                         \
      TYPE,                                                                   \
      NAME,                                                                   \
      AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(                    \
    SCALARTYPE1, SCALARTYPE2, ...)                                      \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, __VA_ARGS__)  \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(    \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)     \
  AT_DISPATCH_SWITCH(                              \
      TYPE,                                        \
      NAME,                                        \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND3(                                  \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)                           \
  AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)  \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND3(                         \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_ALL_TYPES_AND3(                      \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)      \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(       \
      SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)         \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(             \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(          \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(         \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(               \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)    \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                          \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                    \
      TYPE,                                                              \
      NAME,                                                              \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(                       \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
507
// ---- Index types ------------------------------------------------------------
// Dispatch over the two dtypes used for indices, Int and Long. Unlike the
// other families, the callable sees the selected C++ type as `index_t`
// rather than `scalar_t`.
//
// AT_DISPATCH_CASE_INDEX_TYPES is the AT_DISPATCH_SWITCH-compatible form,
// matching the AT_DISPATCH_FOO / AT_DISPATCH_CASE_FOO pairing used by the
// other families in this file.
#define AT_DISPATCH_CASE_INDEX_TYPES(...)                                    \
  AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Int, index_t, __VA_ARGS__) \
  AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Long, index_t, __VA_ARGS__)

#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INDEX_TYPES(__VA_ARGS__))
516
517// ----------------------------------------------------------------------------
518// DEPRECATED MACROS, DON'T USE THESE
519// ----------------------------------------------------------------------------
520
// Deprecated: use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...).
//
// Calls the deprecation marker (so the compiler reports the deprecation
// message at the use site) and then dispatches over the default set plus
// Half.
//
// The expansion is a single parenthesized comma expression, so the macro
// stays one statement under an unbraced `if`/`else` and yields the value of
// the dispatch. The marker is addressed as `::detail::...` (global
// namespace), the same way AT_DISPATCH_SWITCH addresses ::detail::scalar_type,
// so that expanding the macro inside namespace `at` does not resolve `detail`
// to `at::detail`.
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)     \
  (::detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(),   \
   AT_DISPATCH_SWITCH(                                      \
       TYPE,                                                \
       NAME,                                                \
       AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__)))
527