#pragma once
#include <ATen/Config.h>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>

// Defines the accumulation type for a scalar type.
// Example:
//   using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
//
// Accumulation types are an important concept in numeric computing
// because you frequently want to perform intermediate computations
// at a higher precision than the input and output precision, to avoid
// compounding internal rounding errors. Accumulation is the most
// well-known intermediate computation (it is of great importance for
// sum reduction and matrix multiply, for example), but in PyTorch
// acc_type ends up getting used for all sorts of other intermediate
// computations, so it perhaps would be more accurately (ahem) called an
// "accurate" type. acc_type is especially important for reduced
// precision operations like float16 and bfloat16, where relatively
// benign-looking inputs can easily end up overflowing/underflowing.
//
// acc_type is parametrized by whether or not you are running on CUDA,
// because on CUDA double-precision operations are expensive, and so by
// default we don't want to use double as an acc_type there. A lot of
// specializations are typed out below, but basically the table is
// generated by a few rules:
//
//   If bool:
//     Use 'bool' as acc_type.
//   If floating point:
//     If CUDA, use 'float' as acc_type (unless scalar_t is double);
//     otherwise (CPU), use 'double', except that Half and BFloat16
//     still accumulate in 'float'.
//   If complex:
//     Apply the floating-point rule to the underlying real type.
//   If integral:
//     Use 'int64_t' as acc_type.
//
// You're not forced to use this template; if you happen to know
// something specific about your use case, you can specify your own
// desired behavior. This template, however, will give you a reasonable
// default that will work for all dtypes supported in PyTorch.

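// As an illustrative sketch only (sum_kernel_cpu is hypothetical, not a
// PyTorch API; only acc_type itself is real), a CPU sum reduction would
// pick its accumulator like this:
//
//   template <typename scalar_t>
//   void sum_kernel_cpu(const scalar_t* data, int64_t n, scalar_t* out) {
//     // float when scalar_t is at::Half, double when scalar_t is float.
//     using acc_t = at::acc_type<scalar_t, /*is_cuda*/false>;
//     acc_t sum = acc_t(0);
//     for (int64_t i = 0; i < n; i++) {
//       sum += static_cast<acc_t>(data[i]);
//     }
//     *out = static_cast<scalar_t>(sum);
//   }
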
#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#elif defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif

namespace at {

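// The generic template has no ::type member, so using acc_type with a
// (T, is_cuda) combination that is not specialized below is a
// compile-time error rather than a silent fallback.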
template <typename T, bool is_cuda>
struct AccumulateType {};

#if defined(__CUDACC__) || defined(__HIPCC__)
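// 'half' here is the vendor (CUDA/HIP) half-precision type from the
// headers included above; the portable at::Half type is handled by its
// own specialization below.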
template <>
struct AccumulateType<half, true> {
  using type = float;
};
#endif
template <>
struct AccumulateType<BFloat16, true> {
  using type = float;
};
template <>
struct AccumulateType<Half, true> {
  using type = float;
};
template <>
struct AccumulateType<float, true> {
  using type = float;
};
template <>
struct AccumulateType<double, true> {
  using type = double;
};
template <>
struct AccumulateType<int8_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<uint8_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<char, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<int16_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<int32_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<int64_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<bool, true> {
  using type = bool;
};
template <>
struct AccumulateType<Half, false> {
  using type = float;
};
template <>
struct AccumulateType<BFloat16, false> {
  using type = float;
};
template <>
struct AccumulateType<c10::complex<Half>, false> {
  using type = c10::complex<float>;
};
template <>
struct AccumulateType<c10::complex<float>, false> {
  using type = c10::complex<double>;
};
template <>
struct AccumulateType<c10::complex<double>, false> {
  using type = c10::complex<double>;
};
template <>
struct AccumulateType<c10::complex<Half>, true> {
  using type = c10::complex<float>;
};
template <>
struct AccumulateType<c10::complex<float>, true> {
  using type = c10::complex<float>;
};
template <>
struct AccumulateType<c10::complex<double>, true> {
  using type = c10::complex<double>;
};
template <>
struct AccumulateType<float, false> {
  using type = double;
};
template <>
struct AccumulateType<double, false> {
  using type = double;
};
template <>
struct AccumulateType<int8_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<uint8_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<char, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<int16_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<int32_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<int64_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<bool, false> {
  using type = bool;
};

template <typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;

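// A few examples (illustrative; these follow directly from the table
// above):
//
//   static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "");
//   static_assert(std::is_same<acc_type<float, false>, double>::value, "");
//   static_assert(std::is_same<acc_type<int32_t, true>, int64_t>::value, "");

// Runtime (ScalarType-based) counterpart of the compile-time table above;
// for example, toAccumulateType(ScalarType::Half, /*is_cuda=*/true) is
// expected to return ScalarType::Float.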
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);

} // namespace at