1 | #include <ATen/AccumulateType.h> |
---|---|
2 | |
3 | namespace at { |
4 | |
5 | c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) { |
6 | switch (type) { |
7 | #define DEFINE_CASE(scalar_t, TypeNum) \ |
8 | case ScalarType::TypeNum: \ |
9 | return is_cuda ? \ |
10 | CppTypeToScalarType<at::acc_type<scalar_t, true>>::value : \ |
11 | CppTypeToScalarType<at::acc_type<scalar_t, false>>::value; |
12 | |
13 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CASE) |
14 | #undef DEFINE_CASE |
15 | |
16 | default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type); |
17 | } |
18 | } |
19 | |
20 | } |
21 |