#pragma once

#include <ATen/ATen.h>

namespace at {
namespace autocast {

TORCH_API bool is_enabled();
TORCH_API void set_enabled(bool enabled);
TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
TORCH_API bool is_cpu_enabled();
TORCH_API void set_cpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_gpu_dtype();
TORCH_API at::ScalarType get_autocast_cpu_dtype();
TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
TORCH_API bool is_xpu_enabled();
TORCH_API void set_xpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_xpu_dtype();
TORCH_API void set_autocast_xpu_dtype(at::ScalarType dtype);
TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);
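
// Illustrative usage sketch (not part of this header's API surface; in
// practice the torch.autocast Python context manager drives these calls).
// Enabling CUDA autocast with fp16 and clearing the cast cache afterward
// might look like:
//
//   at::autocast::set_autocast_gpu_dtype(at::kHalf);
//   at::autocast::set_enabled(true);
//   // ... run ops that hit the Autocast dispatch key ...
//   at::autocast::set_enabled(false);
//   at::autocast::clear_cache(); // drop cached lower-precision copies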

namespace {
bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return (tensor.is_cuda() || tensor.is_xla()) &&
          tensor.is_floating_point();
    case DeviceType::CPU:
      return (tensor.is_cpu() || tensor.is_mkldnn()) &&
          tensor.is_floating_point();
    case DeviceType::XPU:
      return tensor.is_xpu() && tensor.is_floating_point();
    case DeviceType::HPU:
      return tensor.is_hpu() && tensor.is_floating_point();
    default:
      return false;
  }
}
} // namespace
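
// For example (illustrative): under DeviceType::CUDA, a float32 CUDA tensor
// is eligible, while an integer CUDA tensor or a float32 CPU tensor is not:
//
//   auto a = at::ones({2}, at::dtype(at::kFloat).device(at::kCUDA));
//   auto b = at::ones({2}, at::dtype(at::kLong).device(at::kCUDA));
//   is_autocast_eligible(a, DeviceType::CUDA); // true
//   is_autocast_eligible(b, DeviceType::CUDA); // false: not floating-point
//   is_autocast_eligible(a.cpu(), DeviceType::CUDA); // false: wrong device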

inline DispatchKey get_autocast_dispatch_key_from_device_type(
    DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return DispatchKey::Autocast;
    case DeviceType::CPU:
      return DispatchKey::AutocastCPU;
    case DeviceType::XPU:
      return DispatchKey::AutocastXPU;
    case DeviceType::HPU:
      return DispatchKey::AutocastHPU;
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
  }
}
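
// Illustrative sketch of how autocast kernels typically use this key: exclude
// it before redispatching so the op does not re-enter autocast (the guard is
// c10's standard RAII dispatch-key exclusion):
//
//   c10::impl::ExcludeDispatchKeyGuard no_autocast(
//       get_autocast_dispatch_key_from_device_type(DeviceType::CUDA));
//   // ... redispatch to the op's regular kernel here ...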

inline at::ScalarType get_lower_precision_fp_from_device_type(
    DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return get_autocast_gpu_dtype();
    case DeviceType::CPU:
      return get_autocast_cpu_dtype();
    case DeviceType::XPU:
      return get_autocast_xpu_dtype();
    case DeviceType::HPU:
      return get_autocast_hpu_dtype();
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  }
}

/********************************************************************
Logic to extract the promote type from any Tensor or TensorList args.
********************************************************************/

// Overload to catch Tensor args.
// If nextArg is floating-point, compare its scalar_type with our
// current best guess for the promote type, and update if necessary.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const Tensor& nextArg,
    DeviceType device_type = DeviceType::CUDA) {
  if (current == at::kDouble) {
    AT_ERROR("promote type is double in at::autocast::prioritize");
    return current;
  }
  at::ScalarType lower_precision_fp =
      get_lower_precision_fp_from_device_type(device_type);
  if (is_autocast_eligible(nextArg, device_type)) {
    auto next = nextArg.scalar_type();
    if (next == at::kDouble) {
      return current; // ignores double tensors
    } else if (current == at::kFloat || next == at::kFloat) {
      return at::kFloat; // prioritizes float over lower_precision_fp
    } else if (current == lower_precision_fp && next == lower_precision_fp) {
      return lower_precision_fp;
    } else {
      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
      return current;
    }
  } else {
    return current;
  }
}

// Overload to catch TensorList args (e.g. cat, stack).
// Reuses the overload above to process each Tensor in the list.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const TensorList& list,
    DeviceType device_type = DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

inline at::ScalarType prioritize(
    at::ScalarType current,
    const ITensorListRef& list,
    DeviceType device_type = DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

// Template to catch non-Tensor args (no-op that returns current best guess)
template <typename T>
inline at::ScalarType prioritize(
    at::ScalarType current,
    T nextArg,
    DeviceType device_type = DeviceType::CUDA) {
  return current;
}

// Overload for the tail case.
inline at::ScalarType promote_type(
    at::ScalarType current,
    DeviceType device_type) {
  return current;
}

// Unpack args and determine if incoming lower_precision_fp tensors need to be
// promoted to float32. Non-Tensor arguments are ignored.
template <typename Arg0, typename... Args>
inline at::ScalarType promote_type(
    at::ScalarType current,
    DeviceType device_type,
    Arg0 arg0,
    Args... args) {
  auto new_current = prioritize(current, arg0, device_type);
  return promote_type(new_current, device_type, args...);
}
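
// For example (illustrative): with fp16 autocast on CUDA, mixing a kHalf
// tensor and a kFloat tensor promotes to kFloat, while two kHalf tensors
// stay kHalf. A promotion-policy wrapper would seed `current` with the
// device's lower-precision type:
//
//   at::ScalarType dst = promote_type(
//       get_lower_precision_fp_from_device_type(DeviceType::CUDA),
//       DeviceType::CUDA,
//       half_tensor,
//       float_tensor);
//   // dst == at::kFloat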

/****************************************************
Logic to apply cached casting to any Tensor argument.
****************************************************/
inline bool is_eligible(
    const Tensor& arg,
    DeviceType device_type = DeviceType::CUDA) {
  return (
      arg.defined() && is_autocast_eligible(arg, device_type) &&
      (arg.scalar_type() != at::kDouble));
}

// Overload to catch Tensor args
TORCH_API Tensor cached_cast(
    at::ScalarType to_type,
    const Tensor& arg,
    DeviceType device_type = DeviceType::CUDA);

// Overload to process optional<Tensor>
inline c10::optional<Tensor> cached_cast(
    at::ScalarType to_type,
    const c10::optional<Tensor>& arg,
    DeviceType device_type = DeviceType::CUDA) {
  if (arg.has_value()) {
    return cached_cast(to_type, *arg, device_type);
  } else {
    return c10::nullopt;
  }
}

// Overload to process TensorLists
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const TensorList& arg,
    DeviceType device_type = DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const ITensorListRef& arg,
    DeviceType device_type = DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

// Template to catch non-Tensor args.
template <typename T>
inline T cached_cast(
    at::ScalarType to_type,
    T arg,
    DeviceType device_type = DeviceType::CUDA) {
  return arg;
}
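
// Illustrative usage (a sketch of what autocast wrappers do per argument):
// cast an eligible fp32 tensor down to the autocast dtype, passing other
// arguments through untouched. When the autocast cache is enabled, repeated
// casts of the same fp32 source (e.g. a weight) can reuse the cached copy:
//
//   auto y = cached_cast(at::kHalf, x, DeviceType::CUDA);
//   // y is x cast to kHalf if x is eligible, otherwise x unchanged.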

/*******************************************************
Logic to flip an output dtype flag.
Keep it simple for now by assuming only one such flag is
present in the argument list. If I ever need a function
with more than one flag I'll figure out something else.
The policy is:
If the user has explicitly specified a dtype, respect it.
Otherwise, set it to the autocast type.
********************************************************/

// Overload to catch dtype flags
c10::optional<ScalarType> inline set_opt_dtype(
    at::ScalarType to_type,
    const c10::optional<ScalarType>& dtype) {
  return dtype.has_value() ? dtype : to_type;
}

// Template to catch other args
template <typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  return arg;
}
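
// For example (illustrative): an explicit user dtype wins, otherwise the
// autocast type fills the empty optional:
//
//   set_opt_dtype(at::kFloat, c10::optional<ScalarType>{});          // kFloat
//   set_opt_dtype(at::kFloat, c10::optional<ScalarType>(at::kHalf)); // kHalf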

template <typename... Args>
inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
  return is_eligible(arg);
}

template <typename... Args>
inline at::ScalarType type_from_firstarg(
    at::ScalarType to_type,
    const Tensor& arg,
    Args... args) {
  return (is_eligible(arg) ? to_type : arg.scalar_type());
}
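
// For example (illustrative): ops whose output type follows the first
// argument get the autocast dtype when that argument is eligible, and the
// argument's own dtype otherwise:
//
//   type_from_firstarg(at::kHalf, fp32_cuda_tensor);  // -> at::kHalf
//   type_from_firstarg(at::kHalf, int64_cuda_tensor); // -> at::kLong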

} // namespace autocast
} // namespace at