#include <ATen/ATen.h>
#include <torch/library.h>
#include <ATen/NativeFunctions.h>
#include <ATen/autocast_mode.h>
#include <ATen/Operators.h>

#include <c10/util/intrusive_ptr.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

#include <iostream>
#include <exception>
#include <mutex>
#include <tuple>
#include <unordered_map>

namespace at {
namespace autocast {
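
// Autocast's per-backend "enabled" state rides on the thread-local dispatch key
// exclusion set: autocast is enabled for a backend iff that backend's Autocast
// dispatch key is NOT excluded. Hence the negations in the helpers below.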
bool is_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
}

void set_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled);
}

bool is_cpu_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCPU);
}

void set_cpu_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU, !new_enabled);
}

bool is_xpu_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastXPU);
}

void set_xpu_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastXPU, !new_enabled);
}

bool is_hpu_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastHPU);
}

void set_hpu_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
}

namespace {
// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below).
//
// After discussion with @ezyang, the cache uses the following structure:
// The key is the fp32 source tensor's TensorImpl*, a proxy for a Tensor uuid that's
// unchanged across shallow copies.
// The value is a tuple with a weakref to the source tensor's TensorImpl as the first
// element and the casted tensor as the second element.
//
// The weakref keeps the source's TensorImpl from being deleted. We need this because we're
// using the source TensorImpl* as the key. If it were deleted, another random Tensor could
// be allocated whose TensorImpl* happened to have the same value. This TensorImpl* would
// then mistakenly hit in cache: a rare, intermittent, unpredictable bug.
//
// I'm not using the weak_intrusive_ptr as the key because it's more difficult to compare
// directly against incoming TensorImpl*s.
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;
std::unordered_map<TensorImpl*, val_type> cached_casts;
std::mutex cached_casts_mutex;

// nesting tracks the nesting depth of the Python-side context manager.
// When the autocast context manager exits to a nesting level that's outside
// any instance of autocast (which should occur at the end of each forward pass),
// it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region.
thread_local int nesting = 0;

// autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;

// autocast_xpu_dtype is the lower_precision_fp used by AutocastXPU.
thread_local at::ScalarType autocast_xpu_dtype = at::kBFloat16;

// autocast_hpu_dtype is the lower_precision_fp used by AutocastHPU.
thread_local at::ScalarType autocast_hpu_dtype = at::kBFloat16;

// Should we enable the cache inside autocast?
thread_local bool cache_enabled = true;

// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
} // anonymous namespace

void clear_cache() {
  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
  cached_casts.clear();
}

int increment_nesting() {
  return ++nesting;
}

int decrement_nesting() {
  return --nesting;
}
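
// Illustrative sketch (not code in this file): the Python-side autocast context
// manager is expected to drive the helpers above roughly like this on enter/exit,
// assuming the bindings call straight through:
//
//   increment_nesting();               // __enter__
//   ...run ops in the autocast region...
//   if (decrement_nesting() == 0) {    // __exit__, outermost region done
//     clear_cache();                   // don't let cached casts outlive the region
//   }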

at::ScalarType get_autocast_gpu_dtype() {
  return autocast_gpu_dtype;
}

at::ScalarType get_autocast_cpu_dtype() {
  return autocast_cpu_dtype;
}

at::ScalarType get_autocast_xpu_dtype() {
  return autocast_xpu_dtype;
}

at::ScalarType get_autocast_hpu_dtype() {
  return autocast_hpu_dtype;
}

void set_autocast_cpu_dtype(at::ScalarType dtype) {
  TORCH_CHECK(
      dtype == at::kBFloat16,
      "Currently, AutocastCPU only supports BFloat16 as the autocast_cpu_dtype");
  autocast_cpu_dtype = dtype;
}

void set_autocast_gpu_dtype(at::ScalarType dtype) {
  autocast_gpu_dtype = dtype;
}

void set_autocast_xpu_dtype(at::ScalarType dtype) {
  autocast_xpu_dtype = dtype;
}

void set_autocast_hpu_dtype(at::ScalarType dtype) {
  autocast_hpu_dtype = dtype;
}

bool is_autocast_cache_enabled() {
  return cache_enabled;
}

void set_autocast_cache_enabled(bool enabled) {
  cache_enabled = enabled;
}

// Overload to catch Tensor args.
// TODO (possible optimization):
// Move cached_cast to an inline function in a header, with cached_casts declared as
// extern in the header.
Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_type) {
  if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
    // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
    // See cached_casts declaration above for detailed strategy.
    bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                          arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                          arg.is_leaf() && !arg.is_view() && cache_enabled);
    if (can_try_cache) {
      const std::lock_guard<std::mutex> lock(cached_casts_mutex);
      auto it = cached_casts.find(arg.unsafeGetTensorImpl());
      if (it != cached_casts.end()) {
        return std::get<1>(it->second);
      } else {
        auto casted_arg = arg.to(to_type);
        cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
        return casted_arg;
      }
    } else {
      return arg.to(to_type);
    }
  } else {
    return arg;
  }
}
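
// Illustrative sketch of the caching above (assumes a CUDA build with fp16 as
// the lower_precision_fp; not part of the dispatch path):
//
//   at::Tensor w = at::randn({16, 16},
//       at::TensorOptions().device(at::kCUDA).requires_grad(true)); // fp32 leaf
//   auto a = cached_cast(at::kHalf, w, DeviceType::CUDA); // miss: casts and caches
//   auto b = cached_cast(at::kHalf, w, DeviceType::CUDA); // hit: returns cached cast
//   // clear_cache() (on outermost autocast exit) then drops the entry.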

// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
  lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before running the op.
                          // Currently, lower_precision_fp is fp16 for AutocastCUDA, and is
                          // defined by the user (default bf16) for AutocastCPU.
  fp32, // Cast all inputs to at::kFloat before running the op.
  fp32_set_opt_dtype, // Treats functions (like softmax) that
                      //  1. we'd like to run in fp32 and
                      //  2. have a c10::optional<ScalarType> arg that controls the output type.
                      // fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
                      // don't touch it, otherwise, set it to at::kFloat.
  fp32_append_dtype, // Treats functions (like norm) that
                     //  1. we'd like to run in fp32 and
                     //  2. have some overloads that accept an output type and other overloads that don't.
                     // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
                     // The wrapper policy is: append at::kFloat to the args, and redispatch to the
                     // type-aware overload.
  promote, // Run in the widest dtype among several args.
};
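
// Concrete effect of each policy (illustrative, assuming CUDA with fp16 as the
// lower_precision_fp):
//   lower_precision_fp: mm(fp32 a, fp32 b)           -> mm(fp16 a, fp16 b)
//   fp32:               acos(fp16 x)                 -> acos(fp32 x)
//   fp32_set_opt_dtype: softmax(fp16 x, dim)         -> dtype arg set to at::kFloat
//                       softmax(fp16 x, dim, kHalf)  -> user's dtype left untouched
//   fp32_append_dtype:  norm(fp16 x, p)              -> redispatch to norm(x, p, at::kFloat)
//   promote:            addcmul(fp16, fp16, fp32)    -> all args cast to fp32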

/********************************************************************************************************
Templates to provide wrapper functions

I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to extract args and return type.
(see also https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)

This strategy uses an exterior "WrapFunction" that extracts arguments on behalf of
(in my case several specializations of) an interior "WrapFunction_".
Interior WrapFunction_ specializations are defined for each CastPolicy.
********************************************************************************************************/

// Base template for WrapFunction_, which is specialized to contain a "call" method for each CastPolicy
template<CastPolicy policy, DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class ArgList> struct WrapFunction_ {};

// CastPolicy::lower_precision_fp General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::lower_precision_fp, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(get_lower_precision_fp_from_device_type(device_type), args, device_type)...);
  }
};

// CastPolicy::fp32 General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(at::kFloat, args, device_type)...);
  }
};

// CastPolicy::fp32_set_opt_dtype DeviceType::CUDA
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, DeviceType::CUDA, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    if (firstarg_is_eligible(args...)) {
      return (*F)(set_opt_dtype(at::kFloat, args)...);
    } else {
      // If ineligible, calls F with unaltered args. Does not set opt dtype, because setting
      // opt dtype explicitly may interfere with internal implicit promotion decisions.
      return (*F)(args...);
    }
  }
};

// CastPolicy::fp32_append_dtype DeviceType::CUDA
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_append_dtype, DeviceType::CUDA, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
    return (*F)(args..., out_type);
  }
};

// CastPolicy::promote General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::promote, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
    auto to_type = promote_type(get_lower_precision_fp_from_device_type(device_type), device_type, args...);
    return (*F)(cached_cast(to_type, args, device_type)...);
  }
};

// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating core/boxing/impl/WrapFunctionIntoFunctor.h)
template<CastPolicy policy,
         DeviceType device_type,
         class Registered, // The signature for which we're registering. The dispatcher's calling code invokes our
                           // registered functions with arguments matching Registered, so we register
                           // WrapFunction_::call methods with a matching signature to properly field those arguments.
                           // guts::function_traits below extracts return_type and parameter_types from Registered,
                           // which WrapFunction_ templates above use to declare their call methods.
         class Redispatch, // The signature for the function we're redispatching to. In most cases this is the same
                           // as Registered, but for some ops (for example, ops where we append a dtype) it's useful
                           // to redispatch to a function with a different signature.
         Redispatch* F>    // The actual function we're redispatching to.
struct WrapFunction final {
  using type = WrapFunction_<policy,
                             device_type,
                             Redispatch,
                             F,
                             typename guts::function_traits<Registered>::return_type,
                             typename guts::function_traits<Registered>::parameter_types>;
};
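
// For instance (a sketch of what the KERNEL macros below generate for aten::mm
// under the lower_precision_fp policy):
//   WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
//                Tensor (const Tensor&, const Tensor&),  // Registered
//                Tensor (const Tensor&, const Tensor&),  // Redispatch
//                &at::mm>::type::call
// names a static function with signature Tensor(const Tensor&, const Tensor&)
// that excludes the Autocast key, runs both args through cached_cast, and
// redispatches to at::mm.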

/*******************************
Banned functions
*******************************/

Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const c10::optional<Tensor>&, int64_t) {
  AT_ERROR("torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
           "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
           "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
           "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
           "safe to autocast.");
}

namespace {
/*****************************************************************************************************************
This section performs load-time registration for autocast wrappers.

It's debatable at what level operations should be patched. We'd like casts to be autograd-exposed
and precede autograd history recording, so that for lower_precision_fp ops, input tensors are saved for backward
in lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp can significantly reduce
a model's memory footprint.

Option 1 (strawman): Patch only at the level of explicit calls into cudnn/cublas (cudnn_convolution, etc),
because those are the code paths that are guaranteed to use Tensor Cores, therefore they're the ones that
will benefit most from lower_precision_fp. Potential pitfall: convolutions (and other ops) are wrapped in several
layers of at::* calls. If one of those happens to record autograd history, then we've lost the
opportunity to save inputs in lower_precision_fp.

Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd history
recording can't sneak in ahead of autocast. This mirrors Apex most closely.

I think Option 2 is the right answer for all ops, not just convolutions. Option 2 is what I implement here.
*****************************************************************************************************************/

/********************************************************************************************************************
Explicit registration for out-of-place ops

The stuff below could be codegenned. Ed said
> you are going to have to write the function definition at some point, I wouldn't try to get clever about it
Therefore, for the moment, this is all copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
********************************************************************************************************************/

#define ADD_NS(RAW_OP) at::RAW_OP

// Common cases where registration signature matches redispatch signature
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
#define KERNEL(OP, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" #OP), \
         &WrapFunction<CastPolicy::POLICY, DeviceType::CUDA, decltype(ATEN_FN(OP)), decltype(ATEN_FN(OP)), &ATEN_FN(OP)>::type::call);
#define KERNEL2(OP, OVERLOAD, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
         &WrapFunction<CastPolicy::POLICY, DeviceType::CUDA, decltype(ATEN_FN2(OP, OVERLOAD)), decltype(ATEN_FN2(OP, OVERLOAD)), &ATEN_FN2(OP, OVERLOAD)>::type::call);
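
// For example, KERNEL2(conv_transpose2d, input, lower_precision_fp) expands
// (modulo token pasting) to:
//   m.impl(TORCH_SELECTIVE_NAME("aten::conv_transpose2d.input"),
//          &WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
//                        decltype(ATEN_FN2(conv_transpose2d, input)),
//                        decltype(ATEN_FN2(conv_transpose2d, input)),
//                        &ATEN_FN2(conv_transpose2d, input)>::type::call);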

// Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
         &WrapFunction<CastPolicy::POLICY, DeviceType::CUDA, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call);

// KERNEL_CPU registration for AutocastCPU
#define KERNEL_CPU(OP, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" #OP), \
         &WrapFunction<CastPolicy::POLICY, DeviceType::CPU, decltype(ATEN_FN(OP)), decltype(ATEN_FN(OP)), &ATEN_FN(OP)>::type::call);
#define KERNEL_CPU2(OP, OVERLOAD, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
         &WrapFunction<CastPolicy::POLICY, DeviceType::CPU, decltype(ATEN_FN2(OP, OVERLOAD)), decltype(ATEN_FN2(OP, OVERLOAD)), &ATEN_FN2(OP, OVERLOAD)>::type::call);

/*****************************************
Explicit registration for out-of-place ops
*****************************************/
TORCH_LIBRARY_IMPL(_, Autocast, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
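
// The fallthrough above means ops without an explicit Autocast kernel skip this
// dispatch key entirely and run in whatever dtypes their inputs already have;
// only the ops registered below receive autocast treatment.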

TORCH_LIBRARY_IMPL(aten, Autocast, m) {
  // lower_precision_fp
  KERNEL2(_convolution, deprecated, lower_precision_fp)
  KERNEL(_convolution, lower_precision_fp)
  KERNEL(conv1d, lower_precision_fp)
  KERNEL(conv2d, lower_precision_fp)
  KERNEL(conv3d, lower_precision_fp)
  KERNEL(conv_tbc, lower_precision_fp)
  KERNEL(conv_transpose1d, lower_precision_fp)
  KERNEL2(conv_transpose2d, input, lower_precision_fp)
  KERNEL2(conv_transpose3d, input, lower_precision_fp)
  KERNEL(convolution, lower_precision_fp)
  KERNEL(cudnn_convolution, lower_precision_fp)
  KERNEL(cudnn_convolution_transpose, lower_precision_fp)
  KERNEL(prelu, lower_precision_fp)
  KERNEL(addmm, lower_precision_fp)
  KERNEL(addmv, lower_precision_fp)
  KERNEL(addr, lower_precision_fp)
  KERNEL(matmul, lower_precision_fp)
  KERNEL(einsum, lower_precision_fp)
  KERNEL(mm, lower_precision_fp)
  KERNEL(mv, lower_precision_fp)
  KERNEL(linear, lower_precision_fp)
  KERNEL(addbmm, lower_precision_fp)
  KERNEL(baddbmm, lower_precision_fp)
  KERNEL(bmm, lower_precision_fp)
  KERNEL(chain_matmul, lower_precision_fp)
  KERNEL(linalg_multi_dot, lower_precision_fp)
  KERNEL(_thnn_fused_lstm_cell, lower_precision_fp)
  KERNEL(_thnn_fused_gru_cell, lower_precision_fp)
  KERNEL(lstm_cell, lower_precision_fp)
  KERNEL(gru_cell, lower_precision_fp)
  KERNEL(rnn_tanh_cell, lower_precision_fp)
  KERNEL(rnn_relu_cell, lower_precision_fp)
  KERNEL(_scaled_dot_product_flash_attention, lower_precision_fp)
  KERNEL(scaled_dot_product_attention, lower_precision_fp)

  // fp32
  KERNEL(acos, fp32)
  KERNEL(asin, fp32)
  KERNEL(cosh, fp32)
  KERNEL(erfinv, fp32)
  KERNEL(exp, fp32)
  KERNEL(expm1, fp32)
  KERNEL(log, fp32)
  KERNEL(log10, fp32)
  KERNEL(log2, fp32)
  KERNEL(log1p, fp32)
  KERNEL(reciprocal, fp32)
  KERNEL(rsqrt, fp32)
  KERNEL(sinh, fp32)
  KERNEL(tan, fp32)
  KERNEL2(pow, Tensor_Scalar, fp32)
  KERNEL2(pow, Tensor_Tensor, fp32)
  KERNEL2(pow, Scalar, fp32)
  KERNEL(softplus, fp32)
  KERNEL(layer_norm, fp32)
  KERNEL(native_layer_norm, fp32)
  KERNEL(group_norm, fp32)
  KERNEL2(frobenius_norm, dim, fp32)
  KERNEL(nuclear_norm, fp32)
  KERNEL2(nuclear_norm, dim, fp32)
  KERNEL(cosine_similarity, fp32)
  KERNEL(poisson_nll_loss, fp32)
  KERNEL(cosine_embedding_loss, fp32)
  KERNEL(nll_loss, fp32)
  KERNEL(nll_loss2d, fp32)
  KERNEL(hinge_embedding_loss, fp32)
  KERNEL(kl_div, fp32)
  KERNEL(l1_loss, fp32)
  KERNEL(smooth_l1_loss, fp32)
  KERNEL(huber_loss, fp32)
  KERNEL(mse_loss, fp32)
  KERNEL(margin_ranking_loss, fp32)
  KERNEL(multilabel_margin_loss, fp32)
  KERNEL(soft_margin_loss, fp32)
  KERNEL(triplet_margin_loss, fp32)
  KERNEL(multi_margin_loss, fp32)
  KERNEL(binary_cross_entropy_with_logits, fp32)
  KERNEL(dist, fp32)
  KERNEL(pdist, fp32)
  KERNEL(cdist, fp32)
  KERNEL(renorm, fp32)
  KERNEL(logsumexp, fp32)

  // fp32_set_opt_dtype
  KERNEL(prod, fp32_set_opt_dtype)
  KERNEL2(prod, dim_int, fp32_set_opt_dtype)
  KERNEL2(prod, dim_Dimname, fp32_set_opt_dtype)
  KERNEL2(softmax, int, fp32_set_opt_dtype)
  KERNEL2(softmax, Dimname, fp32_set_opt_dtype)
  KERNEL2(log_softmax, int, fp32_set_opt_dtype)
  KERNEL2(log_softmax, Dimname, fp32_set_opt_dtype)
  KERNEL(cumprod, fp32_set_opt_dtype)
  KERNEL2(cumprod, dimname, fp32_set_opt_dtype)
  KERNEL(cumsum, fp32_set_opt_dtype)
  KERNEL2(cumsum, dimname, fp32_set_opt_dtype)
  KERNEL(linalg_vector_norm, fp32_set_opt_dtype)
  KERNEL(linalg_matrix_norm, fp32_set_opt_dtype)
  KERNEL2(linalg_matrix_norm, str_ord, fp32_set_opt_dtype)
  // commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
  // when autocasting.
  // KERNEL2(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
  // KERNEL2(norm, ScalarOpt_dim_dtype, fp32_set_opt_dtype)
  // KERNEL2(norm, names_ScalarOpt_dim_dtype, fp32_set_opt_dtype)
  KERNEL(sum, fp32_set_opt_dtype)
  KERNEL2(sum, dim_IntList, fp32_set_opt_dtype)
  KERNEL2(sum, dim_DimnameList, fp32_set_opt_dtype)

  // fp32_append_dtype
  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
  // norm does not implicitly promote, but be aware when adding new ops to this policy.
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, const Scalar&), Tensor (const Tensor &, const c10::optional<Scalar>&, ScalarType), fp32_append_dtype)
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool, ScalarType), fp32_append_dtype)
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool, ScalarType), fp32_append_dtype)

  // promote
  KERNEL(addcdiv, promote)
  KERNEL(addcmul, promote)
  KERNEL(atan2, promote)
  KERNEL(bilinear, promote)
  KERNEL(cross, promote)
  KERNEL(dot, promote)
  KERNEL(grid_sampler, promote)
  KERNEL(index_put, promote)
  KERNEL(tensordot, promote)
  KERNEL(scatter_add, promote)

  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
  // lower_precision_fp cast policy
  KERNEL_CPU(conv1d, lower_precision_fp)
  KERNEL_CPU2(conv1d, padding, lower_precision_fp)
  KERNEL_CPU(conv2d, lower_precision_fp)
  KERNEL_CPU2(conv2d, padding, lower_precision_fp)
  KERNEL_CPU(conv3d, lower_precision_fp)
  KERNEL_CPU2(conv3d, padding, lower_precision_fp)
  KERNEL_CPU(bmm, lower_precision_fp)
  KERNEL_CPU(mm, lower_precision_fp)
  KERNEL_CPU(baddbmm, lower_precision_fp)
  KERNEL_CPU(addmm, lower_precision_fp)
  KERNEL_CPU(addbmm, lower_precision_fp)
  KERNEL_CPU(linear, lower_precision_fp)
  KERNEL_CPU2(_convolution, deprecated, lower_precision_fp)
  KERNEL_CPU(matmul, lower_precision_fp)
  KERNEL_CPU(conv_tbc, lower_precision_fp)
  KERNEL_CPU(mkldnn_rnn_layer, lower_precision_fp)
  KERNEL_CPU(conv_transpose1d, lower_precision_fp)
  KERNEL_CPU2(conv_transpose2d, input, lower_precision_fp)
  KERNEL_CPU2(conv_transpose3d, input, lower_precision_fp)

  // fp32 cast policy
  KERNEL_CPU(avg_pool3d, fp32)
  KERNEL_CPU(binary_cross_entropy, fp32)
  KERNEL_CPU(grid_sampler, fp32)
  KERNEL_CPU(polar, fp32)
  KERNEL_CPU(prod, fp32)
  KERNEL_CPU2(prod, dim_int, fp32)
  KERNEL_CPU2(prod, dim_Dimname, fp32)
  KERNEL_CPU(quantile, fp32)
  KERNEL_CPU2(quantile, scalar, fp32)
  KERNEL_CPU(nanquantile, fp32)
  KERNEL_CPU2(nanquantile, scalar, fp32)
  KERNEL_CPU(stft, fp32)
  KERNEL_CPU2(stft, center, fp32)
  KERNEL_CPU(cdist, fp32)
  KERNEL_CPU(grid_sampler_2d, fp32)
  KERNEL_CPU(_grid_sampler_2d_cpu_fallback, fp32)
  KERNEL_CPU(grid_sampler_3d, fp32)
  KERNEL_CPU(trace, fp32)
  KERNEL_CPU(view_as_complex, fp32)
  KERNEL_CPU(cholesky, fp32)
  KERNEL_CPU(cholesky_inverse, fp32)
  KERNEL_CPU(cholesky_solve, fp32)
  KERNEL_CPU(inverse, fp32)
  KERNEL_CPU(lu_solve, fp32)
  KERNEL_CPU(orgqr, fp32)
  KERNEL_CPU(ormqr, fp32)
  KERNEL_CPU(pinverse, fp32)
  KERNEL_CPU(max_pool3d, fp32)
  KERNEL_CPU(max_unpool2d, fp32)
  KERNEL_CPU(max_unpool3d, fp32)
  KERNEL_CPU(adaptive_avg_pool3d, fp32)
  KERNEL_CPU(reflection_pad1d, fp32)
  KERNEL_CPU(reflection_pad2d, fp32)
  KERNEL_CPU(replication_pad1d, fp32)
  KERNEL_CPU(replication_pad2d, fp32)
  KERNEL_CPU(replication_pad3d, fp32)
  KERNEL_CPU(mse_loss, fp32)
  KERNEL_CPU(cosine_embedding_loss, fp32)
  KERNEL_CPU(nll_loss, fp32)
  KERNEL_CPU(nll_loss2d, fp32)
  KERNEL_CPU(hinge_embedding_loss, fp32)
  KERNEL_CPU(poisson_nll_loss, fp32)
  KERNEL_CPU(smooth_l1_loss, fp32)
  KERNEL_CPU(cross_entropy_loss, fp32)
  KERNEL_CPU(l1_loss, fp32)
  KERNEL_CPU(huber_loss, fp32)
  KERNEL_CPU(margin_ranking_loss, fp32)
  KERNEL_CPU(soft_margin_loss, fp32)
  KERNEL_CPU(triplet_margin_loss, fp32)
  KERNEL_CPU(multi_margin_loss, fp32)
  KERNEL_CPU2(ctc_loss, IntList, fp32)
  KERNEL_CPU2(ctc_loss, Tensor, fp32)
  KERNEL_CPU(kl_div, fp32)
  KERNEL_CPU(multilabel_margin_loss, fp32)
  KERNEL_CPU(binary_cross_entropy_with_logits, fp32)
  KERNEL_CPU(fft_fft, fp32)
  KERNEL_CPU(fft_ifft, fp32)
  KERNEL_CPU(fft_fft2, fp32)
  KERNEL_CPU(fft_ifft2, fp32)
  KERNEL_CPU(fft_fftn, fp32)
  KERNEL_CPU(fft_ifftn, fp32)
  KERNEL_CPU(fft_rfft, fp32)
  KERNEL_CPU(fft_irfft, fp32)
  KERNEL_CPU(fft_rfft2, fp32)
  KERNEL_CPU(fft_irfft2, fp32)
  KERNEL_CPU(fft_rfftn, fp32)
  KERNEL_CPU(fft_irfftn, fp32)
  KERNEL_CPU(fft_hfft, fp32)
  KERNEL_CPU(fft_ihfft, fp32)
  KERNEL_CPU(linalg_cond, fp32)
  KERNEL_CPU2(linalg_cond, p_str, fp32)
  KERNEL_CPU(linalg_matrix_rank, fp32)
  KERNEL_CPU2(linalg_matrix_rank, tol_tensor, fp32)
  KERNEL_CPU2(linalg_matrix_rank, atol_rtol_tensor, fp32)
  KERNEL_CPU2(linalg_matrix_rank, atol_rtol_float, fp32)
  KERNEL_CPU(linalg_solve, fp32)
  KERNEL_CPU(linalg_cholesky, fp32)
  KERNEL_CPU(linalg_svdvals, fp32)
  KERNEL_CPU(linalg_eigvals, fp32)
  KERNEL_CPU(linalg_eigvalsh, fp32)
  KERNEL_CPU(linalg_inv, fp32)
  KERNEL_CPU(linalg_householder_product, fp32)
  KERNEL_CPU(linalg_tensorinv, fp32)
  KERNEL_CPU(linalg_tensorsolve, fp32)
  KERNEL_CPU(fake_quantize_per_tensor_affine, fp32)
  KERNEL_CPU(geqrf, fp32)
  KERNEL_CPU(_lu_with_info, fp32)
  KERNEL_CPU(qr, fp32)
  KERNEL_CPU(svd, fp32)
  KERNEL_CPU(triangular_solve, fp32)
  KERNEL_CPU(fractional_max_pool2d, fp32)
  KERNEL_CPU(fractional_max_pool3d, fp32)
  KERNEL_CPU(adaptive_max_pool3d, fp32)
  KERNEL_CPU(multilabel_margin_loss_forward, fp32)
  KERNEL_CPU(linalg_qr, fp32)
  KERNEL_CPU(linalg_cholesky_ex, fp32)
  KERNEL_CPU(linalg_svd, fp32)
  KERNEL_CPU(linalg_eig, fp32)
  KERNEL_CPU(linalg_eigh, fp32)
  KERNEL_CPU(linalg_lstsq, fp32)
  KERNEL_CPU(linalg_inv_ex, fp32)

  // promote
  KERNEL_CPU(stack, promote)
  KERNEL_CPU(cat, promote)
  KERNEL_CPU(index_copy, promote)
  KERNEL_CPU2(index_copy, dimname, promote)
}
} // namespace
} // namespace autocast
} // namespace at