#pragma once
#include <ATen/Config.h>
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>

// Defines the accumulation type for a scalar type.
// Example:
//   using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
//
// Accumulation types are an important concept in numeric computing
// because you frequently want to perform intermediate computations
// at a higher precision than the input and output precision, to avoid
// compounding internal rounding errors. Accumulation is the most
// well-known intermediate computation (it is of great importance for
// sum reduction and matrix multiply, for example), but in PyTorch
// acc_type ends up getting used for all sorts of other intermediate
// computations, so it perhaps would be more accurately (ahem) called an
// "accurate" type. acc_type is especially important for reduced
// precision operations like float16 and bfloat16, where relatively
// benign-looking inputs can easily end up overflowing/underflowing:
// for example, the true sum of 256 values of 300.0 is 76800, which
// exceeds the float16 maximum of 65504, so a float16 accumulator
// overflows to infinity.
//
// acc_type is parametrized by whether you are running on CUDA or not,
// because double precision operations are expensive on CUDA, and so
// by default we don't actually want to use double as an acc_type
// there. A lot of things are typed out below, but basically, the
// table is generated by a few rules:
//
// If bool:
//   Use 'bool' as acc_type.
// If floating point:
//   If CUDA, use 'float' as acc_type (unless scalar_t is double);
//   on CPU, use 'double' (except for Half and BFloat16, which
//   accumulate in 'float').
// If complex:
//   Apply the floating point rule to the underlying real type, so
//   e.g. c10::complex<float> accumulates in c10::complex<float> on
//   CUDA and in c10::complex<double> on CPU.
// If integral:
//   Use 'int64_t' as acc_type.
//
// You're not forced to use this template; if you happen to know
// something specific about your use case, you can specify your own
// desired behavior. This template, however, will give you a
// reasonable default that will work for all dtypes supported in
// PyTorch.
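//
// For instance, a sum reduction might accumulate reduced-precision
// inputs at the accumulation type and round back to the input type
// only once at the end. A minimal sketch (sum_all is illustrative,
// not an ATen function):
//
//   template <typename scalar_t>
//   scalar_t sum_all(const scalar_t* data, int64_t n) {
//     using acc_t = acc_type<scalar_t, /*is_cuda=*/false>;
//     acc_t sum = acc_t(0);  // e.g. float when scalar_t is Half
//     for (int64_t i = 0; i < n; i++) {
//       sum += static_cast<acc_t>(data[i]);
//     }
//     return static_cast<scalar_t>(sum);  // single, final rounding
//   }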

#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#elif defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif

namespace at {

template <typename T, bool is_cuda>
struct AccumulateType {};

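// The primary template above deliberately has no 'type' member, so
// instantiating acc_type with a type/backend pair that is not
// specialized below is a compile-time error rather than a silent
// fallback.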
#if defined(__CUDACC__) || defined(__HIPCC__)
template <>
struct AccumulateType<half, true> {
  using type = float;
};
#endif
template <>
struct AccumulateType<BFloat16, true> {
  using type = float;
};
template <>
struct AccumulateType<Half, true> {
  using type = float;
};
template <>
struct AccumulateType<float, true> {
  using type = float;
};
template <>
struct AccumulateType<double, true> {
  using type = double;
};
template <>
struct AccumulateType<int8_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<uint8_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<char, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<int16_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<int32_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<int64_t, true> {
  using type = int64_t;
};
template <>
struct AccumulateType<bool, true> {
  using type = bool;
};
template <>
struct AccumulateType<c10::complex<Half>, true> {
  using type = c10::complex<float>;
};
template <>
struct AccumulateType<c10::complex<float>, true> {
  using type = c10::complex<float>;
};
template <>
struct AccumulateType<c10::complex<double>, true> {
  using type = c10::complex<double>;
};
template <>
struct AccumulateType<Half, false> {
  using type = float;
};
template <>
struct AccumulateType<BFloat16, false> {
  using type = float;
};
template <>
struct AccumulateType<c10::complex<Half>, false> {
  using type = c10::complex<float>;
};
template <>
struct AccumulateType<c10::complex<float>, false> {
  using type = c10::complex<double>;
};
template <>
struct AccumulateType<c10::complex<double>, false> {
  using type = c10::complex<double>;
};
template <>
struct AccumulateType<float, false> {
  using type = double;
};
template <>
struct AccumulateType<double, false> {
  using type = double;
};
template <>
struct AccumulateType<int8_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<uint8_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<char, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<int16_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<int32_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<int64_t, false> {
  using type = int64_t;
};
template <>
struct AccumulateType<bool, false> {
  using type = bool;
};

template <typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;

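// For instance (illustrative; each line follows from the table above):
//
//   static_assert(std::is_same<acc_type<Half, /*is_cuda=*/true>, float>::value, "");
//   static_assert(std::is_same<acc_type<float, /*is_cuda=*/false>, double>::value, "");
//   static_assert(std::is_same<acc_type<int8_t, /*is_cuda=*/true>, int64_t>::value, "");
//
// toAccumulateType is the runtime counterpart of acc_type: given a
// dynamic ScalarType, it returns the ScalarType of the corresponding
// accumulation type (e.g. kHalf maps to kFloat when is_cuda is true).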
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);

} // namespace at