1 | #pragma once |
---|---|
2 | |
3 | #include <c10/core/ScalarType.h> |
4 | #include <c10/util/BFloat16.h> |
5 | #include <c10/util/Exception.h> |
6 | #include <c10/util/Half.h> |
7 | |
8 | namespace at { |
9 | |
// Trait mapping a storage scalar type to the type in which an op should carry
// out its internal arithmetic. The primary template is the identity mapping;
// reduced-precision types (FP16, BFloat16) are specialized to compute in FP32.
template <typename Scalar>
struct OpMathType {
  using type = Scalar;
};
15 | template <> |
16 | struct OpMathType<at::Half> { |
17 | using type = float; |
18 | }; |
19 | template <> |
20 | struct OpMathType<at::BFloat16> { |
21 | using type = float; |
22 | }; |
23 | template <> |
24 | struct OpMathType<c10::complex<Half>> { |
25 | using type = c10::complex<float>; |
26 | }; |
27 | |
28 | template <typename T> |
29 | using opmath_type = typename OpMathType<T>::type; |
30 | |
31 | namespace { |
32 | |
33 | inline c10::ScalarType toOpMathType(const c10::ScalarType type) { |
34 | switch (type) { |
35 | #define DEFINE_CASE(scalar_t, TypeNum) \ |
36 | case ScalarType::TypeNum: \ |
37 | return CppTypeToScalarType<at::opmath_type<scalar_t>>::value; |
38 | |
39 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) |
40 | #undef DEFINE_CASE |
41 | |
42 | default: |
43 | TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type); |
44 | } |
45 | } |
46 | |
47 | } // namespace |
48 | |
49 | } // namespace at |
50 |