1 | #pragma once |
2 | |
#include <ATen/core/DeprecatedTypeProperties.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
#include <c10/util/string_view.h>

#include <climits>
10 | |
11 | #ifdef __CUDACC__ |
12 | #include <cuda.h> // For CUDA_VERSION |
13 | #endif |
14 | |
15 | #ifdef TEMPLATE_SELECTIVE_BUILD |
16 | #include <ATen/selected_mobile_ops.h> |
17 | #else |
18 | namespace at { |
19 | /** |
20 | * The method should_include_kernel_dtype() returns true/false |
21 | * based on whether the switching code for a specific dtype should be |
22 | * included based on build time constants generated from tracing model |
23 | * execution. This method will be implmeneted via code-generation and |
24 | * included in this file when code-gen is ready. |
25 | */ |
26 | inline constexpr bool should_include_kernel_dtype( |
27 | const char* /*kernel_tag_str*/, |
28 | at::ScalarType /*scalar_type*/ |
29 | ) { |
30 | return true; |
31 | } |
32 | } // namespace at |
33 | #endif |
34 | |
/**
 * RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) records that the kernel
 * tagged NAME was dispatched for dtype enum_type, as the string
 * "<NAME>$<dtype>".
 *
 * In the Facebook internal build (using BUCK), recording is enabled by
 * passing -c pt.enable_record_kernel_dtype=1 when building the tracer
 * binary; otherwise the macro does nothing.
 *
 * Both variants expand to exactly one statement that expects a trailing
 * semicolon (the do { ... } while (0) idiom). This avoids the stray empty
 * statement `;;` at every call site and keeps the macro correct inside an
 * unbraced if/else.
 */
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
namespace at {
namespace detail {
TORCH_API void record_kernel_function_dtype(std::string name);
} // namespace detail
} // namespace at

#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)   \
  do {                                                  \
    at::detail::record_kernel_function_dtype(           \
        std::string(NAME) + "$" + toString(enum_type)); \
  } while (0)
#else
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
  do {                                                \
  } while (0)
#endif
53 | |
// AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) raises an error if the kernel
// tagged `at_dispatch_name` (a constexpr const char* put in scope by
// AT_DISPATCH_SWITCH) was not selected for dtype `enum_type`. The condition
// is a constant expression, so the check is decided at compile time.
//
// The at::guts::if_constexpr emulation is more expensive to compile than a
// native `if constexpr`, so it is only the fallback for compilers that lack
// the language feature.
#ifndef __cpp_if_constexpr
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type)        \
  at::guts::if_constexpr<!at::should_include_kernel_dtype( \
      at_dispatch_name, enum_type)>([&] {                  \
    AT_ERROR(                                              \
        "dtype '",                                         \
        toString(enum_type),                               \
        "' not selected for kernel tag ",                  \
        at_dispatch_name);                                 \
  })
#else // __cpp_if_constexpr
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type)   \
  do {                                                \
    if constexpr (!at::should_include_kernel_dtype(   \
                      at_dispatch_name, enum_type)) { \
      AT_ERROR(                                       \
          "dtype '",                                  \
          toString(enum_type),                        \
          "' not selected for kernel tag ",           \
          at_dispatch_name);                          \
    }                                                 \
  } while (0)
#endif // __cpp_if_constexpr
78 | |
// Emits one `case` label for AT_DISPATCH_SWITCH. After the selective-build
// check, HINT names the C++ type of enum_type and the kernel lambda
// (__VA_ARGS__) is called; its result is returned from the switch lambda.
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...)           \
  case enum_type: {                                                     \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                        \
    using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
    return __VA_ARGS__();                                               \
  }

// The common spelling: the dtype's C++ type is exposed as `scalar_t`.
#define AT_DISPATCH_CASE(enum_type, ...) \
  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
88 | |
// AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) emits the `case` for a
// quantized dtype. The kernel lambda can use:
//   scalar_t        - the quantized C++ type passed as `scalar_type`
//   underlying_t    - scalar_t::underlying
//   SCALAR_TYPE     - the dtype enum value itself
//   UNDERLYING_TYPE - toUnderlying(enum_type)
#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...)            \
  case enum_type: {                                                   \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                      \
    using scalar_t = scalar_type;                                     \
    using underlying_t C10_UNUSED = typename scalar_t::underlying;    \
    const auto& SCALAR_TYPE C10_UNUSED = enum_type;                   \
    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
    return __VA_ARGS__();                                             \
  }
98 | |
// Like AT_DISPATCH_CASE_QINT, but the kernel lambda additionally sees the
// dtype's bit width and quantization range as the (non-const) locals
// bit_width, quant_min and quant_max.
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                           \
    enum_type, scalar_type, bitwidth, qmin, qmax, ...)                \
  case enum_type: {                                                   \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                      \
    using scalar_t = scalar_type;                                     \
    using underlying_t C10_UNUSED = typename scalar_t::underlying;    \
    const auto& SCALAR_TYPE C10_UNUSED = enum_type;                   \
    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
    C10_UNUSED int bit_width = bitwidth;                              \
    C10_UNUSED int64_t quant_min = qmin;                              \
    C10_UNUSED int64_t quant_max = qmax;                              \
    return __VA_ARGS__();                                             \
  }
112 | |
113 | namespace detail { |
114 | |
115 | inline at::ScalarType scalar_type(at::ScalarType s) { |
116 | return s; |
117 | } |
118 | |
119 | C10_DEPRECATED_MESSAGE( |
120 | "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, " |
121 | "pass an at::ScalarType instead" ) |
122 | inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) { |
123 | return t.scalarType(); |
124 | } |
125 | |
126 | C10_DEPRECATED_MESSAGE( |
127 | "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, " |
128 | "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead" ) |
129 | inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {} |
130 | |
131 | C10_DEPRECATED_MESSAGE( |
132 | "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, " |
133 | "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) " |
134 | "instead" ) |
135 | inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} |
136 | |
137 | } // namespace detail |
138 | |
139 | // The AT_DISPATCH_* family of macros provides the ability to |
140 | // conveniently generate specializations of a kernel over all of the |
141 | // dtypes we care about in PyTorch. We call it "dispatch" because |
142 | // we are "dispatching" to the correct, dtype-specific kernel. |
143 | // |
144 | // A standard usage looks like: |
145 | // |
146 | // AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] { |
147 | // // Your code here, with 'scalar_t' now defined to |
148 | // // be the dtype in question |
149 | // }); |
150 | // |
151 | // There are many variations of this macro, so it's important to |
152 | // understand exactly /which/ dtypes you want to get instantiated, as |
153 | // well as what the "default" set is. |
154 | // |
155 | // The default set of dtypes that are instantiated (e.g., by |
156 | // AT_DISPATCH_ALL_TYPES) are floating point types (float, double), |
157 | // and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t), |
// but NOT booleans (bool), half-precision floats (Half) or
// complex numbers (c10::complex<float>, c10::complex<double>).
160 | // This "cut" is somewhat historical (the default types are the |
161 | // ones that TH historically supported), but it also reflects the |
162 | // fact that the non-default types are "poorly" behaved (booleans |
163 | // are NOT integers mod 2, half precision operations ~essentially |
164 | // don't exist on CPU, complex numbers are an experimental application). |
165 | // |
166 | // Here are the questions you should generally ask to decide which |
167 | // dispatch you want: |
168 | // |
169 | // 1. Is this an integral or floating point specific operation? |
170 | // (If so, you'll want one of the FLOATING or INTEGRAL macros.) |
171 | // |
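// 2. Should half be supported? (If you're on CPU, the answer is almost
//    definitely no. If you do want support, use one of the _AND macros
//    with at::ScalarType::Half, e.g.
//    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...), or
//    AT_DISPATCH_FLOATING_TYPES_AND_HALF. Avoid the deprecated
//    AT_DISPATCH_ALL_TYPES_AND_HALF.)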
175 | // |
176 | // Much rarer situations: |
177 | // |
// 3. Should bool be supported? (You often have to write your kernel
//    differently if arithmetic operations are involved.) If so,
//    use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool.
181 | // |
182 | // 4. Should complex be supported? The answer is almost always no, |
183 | // unless you are working on "generic" code that should work on |
184 | // all dtypes. |
185 | // |
186 | // Parameters: |
187 | // ----------- |
188 | // |
189 | // 1. The NAME argument is a "tag" that is used to trace and then |
190 | // conditionally compile fragments of the case statements such |
191 | // that the kernel functions are specialized only for the dtypes |
192 | // that are needed. The NAME parameter *must* be a build time |
193 | // const char* (can't be std::string, etc...) |
194 | // |
//    Please ensure that the NAME is unique for every implementation,
//    or you run the risk of over-including code for the kernel
//    functions. A shared NAME can never cause needed code to be left
//    out; the only risk is keeping more dtype specializations than
//    necessary.
199 | // |
200 | // Switch-like syntax: |
201 | // ------------------- |
202 | // There is also a switch-case like syntax which is useful if a kernel |
203 | // needs to be specialized for particular scalar types |
204 | // |
205 | // AT_DISPATCH_SWITCH(self.scalar_type(), "op_name", |
206 | // AT_DISPATCH_CASE_INTEGRAL_TYPES([&] { |
207 | // op_integral<scalar_t>(iter); |
208 | // }) |
209 | // AT_DISPATCH_CASE_FLOATING_TYPES([&] { |
210 | // op_floating<scalar_t>(iter); |
211 | // }) |
212 | // AT_DISPATCH_CASE(kBool, [&] { |
213 | // op_bool(iter); |
214 | // }) |
215 | // ); |
216 | // |
217 | // For each AT_DISPATCH_FOO macro, there is a corresponding |
218 | // AT_DISPATCH_CASE_FOO macro which can be used inside of an |
219 | // AT_DISPATCH_SWITCH block. |
220 | |
// NB: the_type is kept mainly for backwards compatibility: the kernel
// lambda is defined inside the switch lambda and can therefore refer to it.
// It's probably not used by anyone, but keeping it doesn't hurt. The macro
// itself reads it exactly once (to compute _st), which also silences
// warnings about an unused variable.
225 | |
// AT_DISPATCH_SWITCH(TYPE, NAME, cases...) evaluates TYPE exactly once,
// converts it with ::detail::scalar_type, records the (NAME, dtype) pair
// with RECORD_KERNEL_FUNCTION_DTYPE, and switches on the dtype. The `case`
// labels come from the AT_DISPATCH_CASE* macros passed as trailing
// arguments. Any dtype they do not cover falls into `default`, which raises
// an error naming the kernel tag and the dtype.
//
// The body is an immediately-invoked [&] lambda, so the macro is a single
// expression whose value is whatever the selected case returns. The_type,
// at_dispatch_name and _st are all in scope for the case bodies.
// at_dispatch_name must be constexpr because the selective-build check uses
// it in a constant expression.
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                        \
  [&] {                                                            \
    /* Bind TYPE once: it may be expensive or have side effects. */ \
    const auto& the_type = TYPE;                                   \
    constexpr const char* at_dispatch_name = NAME;                 \
    at::ScalarType _st = ::detail::scalar_type(the_type);          \
    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);           \
    switch (_st) {                                                 \
      __VA_ARGS__                                                  \
      default:                                                     \
        AT_ERROR(                                                  \
            '"',                                                   \
            at_dispatch_name,                                      \
            "\" not implemented for '",                            \
            toString(_st),                                         \
            "'");                                                  \
    }                                                              \
  }()
244 | |
// ----- Floating-point dtypes ------------------------------------------------
// Base case list: Double and Float.
#define AT_DISPATCH_CASE_FLOATING_TYPES(...)            \
  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// Base list plus one caller-chosen dtype.
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                               \
      TYPE,                                                         \
      NAME,                                                         \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// Base list plus Half: shorthand for the _AND form with at::ScalarType::Half.
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
  AT_DISPATCH_CASE_FLOATING_TYPES_AND(at::ScalarType::Half, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, TYPE, NAME, __VA_ARGS__)

// Base list plus two caller-chosen dtypes (the _AND list plus one more).
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE1, __VA_ARGS__)             \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_TYPES_AND2(       \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                          \
      TYPE,                                    \
      NAME,                                    \
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(    \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
283 | |
// ----- Complex dtypes -------------------------------------------------------
// Base case list: ComplexDouble and ComplexFloat.
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...)                    \
  AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)

#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))

// Complex list plus one caller-chosen dtype.
#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                              \
      TYPE,                                                        \
      NAME,                                                        \
      AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
298 | |
// ----- Floating-point and complex dtypes ------------------------------------
// Base case list: the floating-point list followed by the complex list.
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)           \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                           \
      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))

// Each _ANDn case list is the _AND(n-1) list extended by one extra dtype,
// so the extra dtypes appear in argument order after the base list.
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
    SCALARTYPE, TYPE, NAME, ...)                     \
  AT_DISPATCH_SWITCH(                                \
      TYPE,                                          \
      NAME,                                          \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(                   \
    SCALARTYPE1, SCALARTYPE2, ...)                                          \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE1, __VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(  \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)        \
  AT_DISPATCH_SWITCH(                                 \
      TYPE,                                           \
      NAME,                                           \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)           \
  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(       \
      SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)              \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(        \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(     \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
347 | |
// ----- Integral dtypes and the default "all types" set ---------------------
// Case list for the integral dtypes Byte, Char, Int, Long and Short.
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)

#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

// Integral list plus one caller-chosen dtype.
#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                               \
      TYPE,                                                         \
      NAME,                                                         \
      AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// The default set: the integral list followed by the floating-point list.
#define AT_DISPATCH_CASE_ALL_TYPES(...)        \
  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
374 | |
// ----- Quantized dtypes -----------------------------------------------------
// Byte-sized quantized dtypes: kQInt8 and kQUInt8.
#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...)                  \
  AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__)   \
  AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)

#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))

// All full-width quantized dtypes: the byte-sized ones followed by kQInt32.
#define AT_DISPATCH_CASE_QINT_TYPES(...)        \
  AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__) \
  AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)

#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
389 | |
// Quantized dtypes, including the packed sub-byte ones. Every case also
// exposes bit_width, quant_min and quant_max to the kernel lambda:
//
//   dtype       bit_width               [quant_min, quant_max]
//   kQInt8      CHAR_BIT                [SCHAR_MIN, SCHAR_MAX]
//   kQUInt8     CHAR_BIT                [0, UCHAR_MAX]
//   kQInt32     CHAR_BIT * sizeof(int)  [INT_MIN, INT_MAX]
//   kQUInt4x2   4                       [0, 15]
//   kQUInt2x4   2                       [0, 3]
//
// The limit macros come from <climits>. They are expanded at the call site
// of this macro, so this header includes <climits> itself instead of
// relying on every caller to have included it.
#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...)                     \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__)       \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQInt32,                                                        \
      at::qint32,                                                         \
      CHAR_BIT * sizeof(int),                                             \
      INT_MIN,                                                            \
      INT_MAX,                                                            \
      __VA_ARGS__)                                                        \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__)                 \
  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
      at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)

#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                        \
      TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
410 | |
// ----- "All types" variants -------------------------------------------------
// The default set plus the complex dtypes.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)           \
  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                      \
      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))

// Every _ANDn case list below is the matching _AND(n-1) list extended by
// one extra dtype, so the extra dtypes follow the base list in argument
// order.

// Default set + 1.
#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                          \
      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))

// Default set + complex + 1.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)               \
  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                      \
      TYPE,                                                                \
      NAME,                                                                \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))

// Default set + 2.
#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE1, __VA_ARGS__)             \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                         \
      TYPE,                                                                   \
      NAME,                                                                   \
      AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

// Default set + complex + 2.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(              \
    SCALARTYPE1, SCALARTYPE2, ...)                                \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, __VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)  \
  AT_DISPATCH_SWITCH(                           \
      TYPE,                                     \
      NAME,                                     \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))

// Default set + 3.
#define AT_DISPATCH_CASE_ALL_TYPES_AND3(                            \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)                     \
  AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__) \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND3(                         \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_ALL_TYPES_AND3(                      \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

// Default set + complex + 3.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)      \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(       \
      SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)         \
  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(             \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                       \
      TYPE,                                                 \
      NAME,                                                 \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(          \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

// Default set + complex + 4.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(            \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...)    \
  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(                  \
      SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)       \
  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                          \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                    \
      TYPE,                                                              \
      NAME,                                                              \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(                       \
          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
507 | |
// Dispatches over the index dtypes Int and Long. Inside the kernel lambda
// the selected C++ type is named `index_t` (not `scalar_t`).
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...)                              \
  AT_DISPATCH_SWITCH(                                                         \
      TYPE,                                                                   \
      NAME,                                                                   \
      AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Int, index_t, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Long, index_t, __VA_ARGS__))
516 | |
// ----------------------------------------------------------------------------
// DEPRECATED MACROS, DON'T USE THESE
// ----------------------------------------------------------------------------

// Deprecated: use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...).
//
// The call to ::detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() does
// nothing at runtime; it only triggers the deprecation warning at the use
// site. The call and the dispatch are joined with the comma operator
// (whose left operand is void, so no user-defined operator, can kick in).
// As a result the macro is a single expression, like every other
// AT_DISPATCH_* macro: it is correct after an unbraced `if`, and the
// kernel's result can be returned or assigned.
//
// The leading `::` (the same qualification AT_DISPATCH_SWITCH uses for
// ::detail::scalar_type) keeps a `detail` namespace at the call site, e.g.
// at::detail, from hijacking the lookup.
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)   \
  (::detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(), \
   AT_DISPATCH_SWITCH(                                    \
       TYPE,                                              \
       NAME,                                              \
       AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__)))
527 | |