1#include <ATen/AccumulateType.h>
2
3namespace at {
4
5c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) {
6 switch (type) {
7#define DEFINE_CASE(scalar_t, TypeNum) \
8 case ScalarType::TypeNum: \
9 return is_cuda ? \
10 CppTypeToScalarType<at::acc_type<scalar_t, true>>::value : \
11 CppTypeToScalarType<at::acc_type<scalar_t, false>>::value;
12
13 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CASE)
14#undef DEFINE_CASE
15
16 default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
17 }
18}
19
20}
21