#pragma once

#include <ATen/ATen.h>

namespace at {
namespace autocast {

TORCH_API bool is_enabled();
TORCH_API void set_enabled(bool enabled);
TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
TORCH_API bool is_cpu_enabled();
TORCH_API void set_cpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_gpu_dtype();
TORCH_API at::ScalarType get_autocast_cpu_dtype();
TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
TORCH_API bool is_xpu_enabled();
TORCH_API void set_xpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_xpu_dtype();
TORCH_API void set_autocast_xpu_dtype(at::ScalarType dtype);
TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);

namespace {
bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return (tensor.is_cuda() || tensor.is_xla()) &&
          tensor.is_floating_point();
    case DeviceType::CPU:
      return (tensor.is_cpu() || tensor.is_mkldnn()) &&
          tensor.is_floating_point();
    case DeviceType::XPU:
      return tensor.is_xpu() && tensor.is_floating_point();
    case DeviceType::HPU:
      return tensor.is_hpu() && tensor.is_floating_point();
    default:
      return false;
  }
}
} // namespace

inline DispatchKey get_autocast_dispatch_key_from_device_type(
    DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return DispatchKey::Autocast;
    case DeviceType::CPU:
      return DispatchKey::AutocastCPU;
    case DeviceType::XPU:
      return DispatchKey::AutocastXPU;
    case DeviceType::HPU:
      return DispatchKey::AutocastHPU;
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
  }
}

inline at::ScalarType get_lower_precision_fp_from_device_type(
    DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return get_autocast_gpu_dtype();
    case DeviceType::CPU:
      return get_autocast_cpu_dtype();
    case DeviceType::XPU:
      return get_autocast_xpu_dtype();
    case DeviceType::HPU:
      return get_autocast_hpu_dtype();
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  }
}
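
// Note: the "lower precision" dtype returned above is whatever the
// corresponding setter was last given. As a hedged illustration only
// (defaults can differ across versions/builds), CUDA autocast commonly
// defaults to at::kHalf and CPU autocast to at::kBFloat16:
//
//   at::ScalarType cuda_fp =
//       get_lower_precision_fp_from_device_type(DeviceType::CUDA);
//   // typically at::kHalf unless set_autocast_gpu_dtype() changed it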

/********************************************************************
Logic to extract the promote type from any Tensor or TensorList args.
********************************************************************/

// Overload to catch Tensor args.
// If nextArg is floating-point, compare its scalar_type with our
// current best guess for the promote type, and update if necessary.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const Tensor& nextArg,
    DeviceType device_type = DeviceType::CUDA) {
  if (current == at::kDouble) {
    AT_ERROR("promote type is double in at::autocast::prioritize");
    return current;
  }
  at::ScalarType lower_precision_fp =
      get_lower_precision_fp_from_device_type(device_type);
  if (is_autocast_eligible(nextArg, device_type)) {
    auto next = nextArg.scalar_type();
    if (next == at::kDouble) {
      return current; // ignores double tensors
    } else if (current == at::kFloat || next == at::kFloat) {
      return at::kFloat; // prioritizes float over lower_precision_fp
    } else if (current == lower_precision_fp && next == lower_precision_fp) {
      return lower_precision_fp;
    } else {
      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
      return current;
    }
  } else {
    return current;
  }
}
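
// A worked illustration of the rule above (hypothetical CUDA tensors `a` and
// `b`, not declared in this header; assumes the GPU autocast dtype is its
// usual at::kHalf): a single float32 input widens the promote type to float32.
//
//   at::ScalarType cur =
//       get_lower_precision_fp_from_device_type(DeviceType::CUDA);
//   cur = prioritize(cur, a, DeviceType::CUDA); // a is kHalf  -> cur stays kHalf
//   cur = prioritize(cur, b, DeviceType::CUDA); // b is kFloat -> cur becomes kFloat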

// Overload to catch TensorList args (e.g. cat, stack).
// Reuses the overload above to process each Tensor in the list.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const TensorList& list,
    DeviceType device_type = DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

inline at::ScalarType prioritize(
    at::ScalarType current,
    const ITensorListRef& list,
    DeviceType device_type = DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

// Template to catch non-Tensor args (no-op that returns current best guess)
template <typename T>
inline at::ScalarType prioritize(
    at::ScalarType current,
    T nextArg,
    DeviceType device_type = DeviceType::CUDA) {
  return current;
}

// Overload for the tail case.
inline at::ScalarType promote_type(
    at::ScalarType current,
    DeviceType device_type) {
  return current;
}

// Unpack args and determine if incoming lower_precision_fp tensors need to be
// promoted to float32. Non-Tensor arguments are ignored.
template <typename Arg0, typename... Args>
inline at::ScalarType promote_type(
    at::ScalarType current,
    DeviceType device_type,
    Arg0 arg0,
    Args... args) {
  auto new_current = prioritize(current, arg0, device_type);
  return promote_type(new_current, device_type, args...);
}
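
// A minimal usage sketch (illustrative only; `self` and `other` are
// hypothetical CUDA floating-point tensors, not declared here): a
// multi-argument op picks one execution dtype for all of its inputs,
// widening to float32 if any eligible input is already float32.
//
//   at::ScalarType to_type = promote_type(
//       get_lower_precision_fp_from_device_type(DeviceType::CUDA),
//       DeviceType::CUDA,
//       self,
//       other);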

/****************************************************
Logic to apply cached casting to any Tensor argument.
****************************************************/
inline bool is_eligible(
    const Tensor& arg,
    DeviceType device_type = DeviceType::CUDA) {
  return (
      arg.defined() && is_autocast_eligible(arg, device_type) &&
      (arg.scalar_type() != at::kDouble));
}

// Overload to catch Tensor args
TORCH_API Tensor cached_cast(
    at::ScalarType to_type,
    const Tensor& arg,
    DeviceType device_type = DeviceType::CUDA);

// Overload to process optional<Tensor>
inline c10::optional<Tensor> cached_cast(
    at::ScalarType to_type,
    const c10::optional<Tensor>& arg,
    DeviceType device_type = DeviceType::CUDA) {
  if (arg.has_value()) {
    return cached_cast(to_type, *arg, device_type);
  } else {
    return c10::nullopt;
  }
}

// Overload to process TensorLists
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const TensorList& arg,
    DeviceType device_type = DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const ITensorListRef& arg,
    DeviceType device_type = DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

// Template to catch non-Tensor args.
template <typename T>
inline T cached_cast(
    at::ScalarType to_type,
    T arg,
    DeviceType device_type = DeviceType::CUDA) {
  return arg;
}
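
// A minimal usage sketch (illustrative only; `self` is a hypothetical CUDA
// float32 tensor, not declared here): cast an eligible argument to the chosen
// execution dtype. Ineligible tensors (non-floating-point, double, undefined)
// and non-Tensor arguments fall through the overloads above unchanged.
//
//   Tensor self_lp = cached_cast(at::kHalf, self, DeviceType::CUDA);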

/*******************************************************
Logic to flip an output dtype flag.
Keep it simple for now by assuming only one such flag is
present in the argument list. If I ever need a function
with more than one flag, I'll figure out something else.
The policy is:
If the user has explicitly specified a dtype, respect it.
Otherwise, set it to the autocast type.
*******************************************************/

// Overload to catch dtype flags
c10::optional<ScalarType> inline set_opt_dtype(
    at::ScalarType to_type,
    const c10::optional<ScalarType>& dtype) {
  return dtype.has_value() ? dtype : to_type;
}

// Template to catch other args
template <typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  return arg;
}
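
// Illustration of the policy above (hypothetical values, not declared here):
//
//   set_opt_dtype(at::kHalf, c10::optional<ScalarType>{});           // -> kHalf (autocast type)
//   set_opt_dtype(at::kHalf, c10::optional<ScalarType>(at::kFloat)); // -> kFloat (user's choice wins)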

template <typename... Args>
inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
  return is_eligible(arg);
}

template <typename... Args>
inline at::ScalarType type_from_firstarg(
    at::ScalarType to_type,
    const Tensor& arg,
    Args... args) {
  return (is_eligible(arg) ? to_type : arg.scalar_type());
}
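
// Illustration of the helpers above (hypothetical tensor `t`, not declared
// here): type_from_firstarg(at::kHalf, t, ...) yields at::kHalf when `t` is
// autocast-eligible and t.scalar_type() otherwise, so an op can follow the
// dtype of its first Tensor argument; firstarg_is_eligible(t, ...) is the
// matching boolean check.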

} // namespace autocast
} // namespace at