1#pragma once
2
3#include <c10/core/ScalarType.h>
4#include <c10/util/BFloat16.h>
5#include <c10/util/Exception.h>
6#include <c10/util/Half.h>
7
8namespace at {
9
// Trait selecting the type an op should use for its internal arithmetic.
// The primary template is the identity mapping: a type computes in itself
// unless a specialization of OpMathType names a different math type (the
// reduced-precision floating-point types are widened to FP32 that way).
template <typename T>
struct OpMathType {
  using type = T;
};
15template <>
16struct OpMathType<at::Half> {
17 using type = float;
18};
19template <>
20struct OpMathType<at::BFloat16> {
21 using type = float;
22};
23template <>
24struct OpMathType<c10::complex<Half>> {
25 using type = c10::complex<float>;
26};
27
28template <typename T>
29using opmath_type = typename OpMathType<T>::type;
30
31namespace {
32
33inline c10::ScalarType toOpMathType(const c10::ScalarType type) {
34 switch (type) {
35#define DEFINE_CASE(scalar_t, TypeNum) \
36 case ScalarType::TypeNum: \
37 return CppTypeToScalarType<at::opmath_type<scalar_t>>::value;
38
39 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
40#undef DEFINE_CASE
41
42 default:
43 TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
44 }
45}
46
47} // namespace
48
49} // namespace at
50