#include <ATen/ATen.h>
#include <torch/library.h>
#include <ATen/NativeFunctions.h>
#include <ATen/autocast_mode.h>
#include <ATen/Operators.h>

#include <c10/util/intrusive_ptr.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

#include <iostream>
#include <exception>
#include <mutex>

namespace at {
namespace autocast {

bool is_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
}

void set_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled);
}

bool is_cpu_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCPU);
}

void set_cpu_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU, !new_enabled);
}

bool is_xpu_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastXPU);
}

void set_xpu_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastXPU, !new_enabled);
}

bool is_hpu_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastHPU);
}

void set_hpu_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
}
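
// Illustrative usage sketch (not part of this file): the Python-side
// torch.autocast context manager is expected to drive these thread-local
// toggles roughly as follows; the exact binding layer lives outside this file.
//
//   bool prev = at::autocast::is_enabled();
//   at::autocast::set_enabled(true);   // on __enter__
//   // ... run ops; eligible ones are intercepted by the wrappers below ...
//   at::autocast::set_enabled(prev);   // on __exit__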

namespace {
// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below).
//
// After discussion with @ezyang, the cache uses the following structure:
// The key is the fp32 source tensor's TensorImpl*, a proxy for a Tensor uuid that's
// unchanged across shallow copies.
// The value is a tuple with a weakref to the source tensor's TensorImpl as the first
// element and the casted tensor as the second element.
//
// The weakref keeps the source's TensorImpl from being deleted. We need this because we're
// using the source TensorImpl* as the key. If it were deleted, another random Tensor could
// be allocated whose TensorImpl* happened to have the same value. This TensorImpl* would
// then mistakenly hit in the cache: a rare, intermittent, unpredictable bug.
//
// I'm not using the weak_intrusive_ptr as the key because it's more difficult to compare
// directly against incoming TensorImpl*s.
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;
std::unordered_map<TensorImpl*, val_type> cached_casts;
std::mutex cached_casts_mutex;

// nesting tracks the nesting depth of the Python-side context manager.
// When the autocast context manager exits to a nesting level that's outside
// any instance of autocast (which should occur at the end of each forward pass),
// it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region.
thread_local int nesting = 0;

// autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;

// autocast_xpu_dtype is the lower_precision_fp used by AutocastXPU.
thread_local at::ScalarType autocast_xpu_dtype = at::kBFloat16;

// autocast_hpu_dtype is the lower_precision_fp used by AutocastHPU.
thread_local at::ScalarType autocast_hpu_dtype = at::kBFloat16;

// Should we enable the cache inside autocast?
thread_local bool cache_enabled = true;

// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
} // namespace

void clear_cache() {
  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
  cached_casts.clear();
}

int increment_nesting() {
  return ++nesting;
}

int decrement_nesting() {
  return --nesting;
}
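
// Usage sketch (assumed; mirrors the comment on `nesting` above): the Python
// context manager increments on enter and, on exit, clears the cast cache once
// the outermost autocast region has been left, roughly:
//   increment_nesting();               // __enter__
//   // ... autocast-enabled forward pass ...
//   if (decrement_nesting() == 0) {    // __exit__
//     clear_cache();
//   }
// The actual sequencing is defined by the Python-side autocast implementation.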

at::ScalarType get_autocast_gpu_dtype() {
  return autocast_gpu_dtype;
}

at::ScalarType get_autocast_cpu_dtype() {
  return autocast_cpu_dtype;
}

at::ScalarType get_autocast_xpu_dtype() {
  return autocast_xpu_dtype;
}

at::ScalarType get_autocast_hpu_dtype() {
  return autocast_hpu_dtype;
}

void set_autocast_cpu_dtype(at::ScalarType dtype) {
  TORCH_CHECK(
      dtype == at::kBFloat16,
      "Currently, AutocastCPU only supports BFloat16 as the autocast_cpu_dtype");
  autocast_cpu_dtype = dtype;
}

void set_autocast_gpu_dtype(at::ScalarType dtype) {
  autocast_gpu_dtype = dtype;
}

void set_autocast_xpu_dtype(at::ScalarType dtype) {
  autocast_xpu_dtype = dtype;
}

void set_autocast_hpu_dtype(at::ScalarType dtype) {
  autocast_hpu_dtype = dtype;
}

bool is_autocast_cache_enabled() {
  return cache_enabled;
}

void set_autocast_cache_enabled(bool enabled) {
  cache_enabled = enabled;
}

// Overload to catch Tensor args
// TODO (possible optimization):
// Move cast_cache to an inline function in a header with cached_casts declared as
// extern thread_local in the header.
Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_type) {
  if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
    // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
    // See cached_casts declaration above for detailed strategy.
    bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                          arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                          arg.is_leaf() && !arg.is_view() && cache_enabled);
    if (can_try_cache) {
      const std::lock_guard<std::mutex> lock(cached_casts_mutex);
      auto it = cached_casts.find(arg.unsafeGetTensorImpl());
      if (it != cached_casts.end()) {
        return std::get<1>(it->second);
      } else {
        auto casted_arg = arg.to(to_type);
        cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
        return casted_arg;
      }
    } else {
      return arg.to(to_type);
    }
  } else {
    return arg;
  }
}
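
// Behavior sketch for cached_cast (illustrative, assuming the default CUDA
// lower_precision_fp of kHalf, cache_enabled == true, and an fp32 CUDA leaf
// that requires grad, e.g. a model weight):
//   auto w = at::randn({16, 16}, at::device(at::kCUDA).requires_grad(true));
//   auto a = cached_cast(at::kHalf, w, DeviceType::CUDA);  // casts to fp16 and caches
//   auto b = cached_cast(at::kHalf, w, DeviceType::CUDA);  // cache hit: returns the same fp16 tensor
// Inputs that are not fp32 leaves (or that are views, or when the cache is
// disabled) are cast without caching, and ineligible args pass through unchanged.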

// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
  lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before running the op.
                          // Currently, lower_precision_fp is fp16 for AutocastCUDA, and is defined by
                          // the user (default bf16) for AutocastCPU.
  fp32, // Cast all inputs to at::kFloat before running the op.
  fp32_set_opt_dtype, // Treats functions (like softmax) that
                      //  1. we'd like to run in fp32 and
                      //  2. have a c10::optional<ScalarType> arg that controls the output type.
                      // fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
                      // don't touch it, otherwise, set it to at::kFloat.
  fp32_append_dtype, // Treats functions (like norm) that
                     //  1. we'd like to run in fp32 and
                     //  2. have some overloads that accept an output type and other overloads that don't.
                     // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
                     // The wrapper policy is: append at::kFloat to the args, and redispatch to the
                     // type-aware overload.
  promote, // Run in the widest dtype among several args.
};

/********************************************************************************************************
Templates to provide wrapper functions

I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to extract args and return type.
(see also https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)

This strategy uses an exterior "WrapFunction" that extracts arguments on behalf of
(in my case several specializations of) an interior "WrapFunction_".
Interior WrapFunction_ specializations are defined for each CastPolicy.
********************************************************************************************************/

// Base template for WrapFunction_, which is specialized to contain a "call" method for each CastPolicy
template<CastPolicy policy, DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class ArgList> struct WrapFunction_ {};

// CastPolicy::lower_precision_fp General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::lower_precision_fp, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(get_lower_precision_fp_from_device_type(device_type), args, device_type)...);
  }
};

// CastPolicy::fp32 General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(at::kFloat, args, device_type)...);
  }
};

// CastPolicy::fp32_set_opt_dtype DeviceType::CUDA
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, DeviceType::CUDA, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    if (firstarg_is_eligible(args...)) {
      return (*F)(set_opt_dtype(at::kFloat, args)...);
    } else {
      // If ineligible, calls F with unaltered args. Does not set opt dtype, because setting
      // opt dtype explicitly may interfere with internal implicit promotion decisions.
      return (*F)(args...);
    }
  }
};

// CastPolicy::fp32_append_dtype DeviceType::CUDA
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_append_dtype, DeviceType::CUDA, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
    return (*F)(args..., out_type);
  }
};

// CastPolicy::promote General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::promote, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
    auto to_type = promote_type(get_lower_precision_fp_from_device_type(device_type), device_type, args...);
    return (*F)(cached_cast(to_type, args, device_type)...);
  }
};

// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating core/boxing/impl/WrapFunctionIntoFunctor.h)
template<CastPolicy policy,
         DeviceType device_type,
         class Registered, // The signature for which we're registering. The dispatcher's calling code invokes our
                           // registered functions with arguments matching Registered, so we register
                           // WrapFunction_::call methods with a matching signature to properly field those arguments.
                           // guts::function_traits below extracts return_type and parameter_types from Registered,
                           // which WrapFunction_ templates above use to declare their call methods.
         class Redispatch, // The signature for the function we're redispatching to. In most cases this is the same
                           // as Registered, but for some ops (for example, ops where we append a dtype) it's useful
                           // to redispatch to a function with a different signature.
         Redispatch* F>    // The actual function we're redispatching to.
struct WrapFunction final {
  using type = WrapFunction_<policy,
                             device_type,
                             Redispatch,
                             F,
                             typename guts::function_traits<Registered>::return_type,
                             typename guts::function_traits<Registered>::parameter_types>;
};
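
// Illustrative expansion (sketch): for Ret = Tensor and Args = (const Tensor&, const Tensor&),
// the lower_precision_fp specialization's call() behaves like
//   Tensor call(const Tensor& a, const Tensor& b) {
//     c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
//     return (*F)(cached_cast(at::kHalf, a, DeviceType::CUDA),
//                 cached_cast(at::kHalf, b, DeviceType::CUDA));
//   }
// where at::kHalf stands in for get_lower_precision_fp_from_device_type(CUDA)
// (the default autocast_gpu_dtype) and F is the redispatch target.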

/*******************************
Banned functions
*******************************/

Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const c10::optional<Tensor>&, int64_t) {
  AT_ERROR("torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
           "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
           "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
           "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
           "safe to autocast.");
}

namespace {
/*****************************************************************************************************************
This section performs load-time registration for autocast wrappers.

It's debatable at what level operations should be patched. We'd like casts to be autograd-exposed
and precede autograd history recording, so that for lower_precision_fp ops, input tensors are saved for backward
in lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp can significantly reduce
a model's memory footprint.

Option 1 (strawman): Patch only at the level of explicit calls into cudnn/cublas (cudnn_convolution, etc),
because those are the code paths that are guaranteed to use Tensor Cores, therefore they're the ones that
will benefit most from lower_precision_fp. Potential pitfall: convolutions (and other ops) are wrapped in several
layers of at::* calls. If one of those happens to record autograd history, then we've lost the
opportunity to save inputs in lower_precision_fp.

Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd history
recording can't sneak in ahead of autocast. This mirrors Apex most closely.

I think Option 2 is the right answer for all ops, not just convolutions. Option 2 is what I implement here.
*****************************************************************************************************************/

/********************************************************************************************************************
Explicit registration for out-of-place ops

The stuff below could be codegenned. Ed said
> you are going to have to write the function definition at some point, I wouldn't try to get clever about it
Therefore, for the moment, this is all copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
********************************************************************************************************************/

#define ADD_NS(RAW_OP) at::RAW_OP

// Common cases where the registration signature matches the redispatch signature
// (that's why the same signature, decltype(ATEN_FN(OP)), appears twice in the WrapFunction instantiation)
#define KERNEL(OP, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" #OP), \
    &WrapFunction<CastPolicy::POLICY, DeviceType::CUDA, decltype(ATEN_FN(OP)), decltype(ATEN_FN(OP)), &ATEN_FN(OP)>::type::call);
#define KERNEL2(OP, OVERLOAD, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
    &WrapFunction<CastPolicy::POLICY, DeviceType::CUDA, decltype(ATEN_FN2(OP, OVERLOAD)), decltype(ATEN_FN2(OP, OVERLOAD)), &ATEN_FN2(OP, OVERLOAD)>::type::call);

// Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
    &WrapFunction<CastPolicy::POLICY, DeviceType::CUDA, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call);

// KERNEL_CPU registration for AutocastCPU
#define KERNEL_CPU(OP, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" #OP), \
    &WrapFunction<CastPolicy::POLICY, DeviceType::CPU, decltype(ATEN_FN(OP)), decltype(ATEN_FN(OP)), &ATEN_FN(OP)>::type::call);
#define KERNEL_CPU2(OP, OVERLOAD, POLICY) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
    &WrapFunction<CastPolicy::POLICY, DeviceType::CPU, decltype(ATEN_FN2(OP, OVERLOAD)), decltype(ATEN_FN2(OP, OVERLOAD)), &ATEN_FN2(OP, OVERLOAD)>::type::call);
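
// For reference, an example expansion (after token pasting and string
// concatenation) of the macros above:
//   KERNEL(mm, lower_precision_fp)
// becomes
//   m.impl(TORCH_SELECTIVE_NAME("aten::mm"),
//     &WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
//                   decltype(ATEN_FN(mm)), decltype(ATEN_FN(mm)), &ATEN_FN(mm)>::type::call);
// KERNEL2 and KERNEL_CPU2 do the same for a named overload, e.g. "aten::pow.Tensor_Scalar".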

/*****************************************
Explicit registration for out-of-place ops
*****************************************/
TORCH_LIBRARY_IMPL(_, Autocast, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, Autocast, m) {
  // lower_precision_fp
  KERNEL2(_convolution, deprecated, lower_precision_fp)
  KERNEL(_convolution, lower_precision_fp)
  KERNEL(conv1d, lower_precision_fp)
  KERNEL(conv2d, lower_precision_fp)
  KERNEL(conv3d, lower_precision_fp)
  KERNEL(conv_tbc, lower_precision_fp)
  KERNEL(conv_transpose1d, lower_precision_fp)
  KERNEL2(conv_transpose2d, input, lower_precision_fp)
  KERNEL2(conv_transpose3d, input, lower_precision_fp)
  KERNEL(convolution, lower_precision_fp)
  KERNEL(cudnn_convolution, lower_precision_fp)
  KERNEL(cudnn_convolution_transpose, lower_precision_fp)
  KERNEL(prelu, lower_precision_fp)
  KERNEL(addmm, lower_precision_fp)
  KERNEL(addmv, lower_precision_fp)
  KERNEL(addr, lower_precision_fp)
  KERNEL(matmul, lower_precision_fp)
  KERNEL(einsum, lower_precision_fp)
  KERNEL(mm, lower_precision_fp)
  KERNEL(mv, lower_precision_fp)
  KERNEL(linear, lower_precision_fp)
  KERNEL(addbmm, lower_precision_fp)
  KERNEL(baddbmm, lower_precision_fp)
  KERNEL(bmm, lower_precision_fp)
  KERNEL(chain_matmul, lower_precision_fp)
  KERNEL(linalg_multi_dot, lower_precision_fp)
  KERNEL(_thnn_fused_lstm_cell, lower_precision_fp)
  KERNEL(_thnn_fused_gru_cell, lower_precision_fp)
  KERNEL(lstm_cell, lower_precision_fp)
  KERNEL(gru_cell, lower_precision_fp)
  KERNEL(rnn_tanh_cell, lower_precision_fp)
  KERNEL(rnn_relu_cell, lower_precision_fp)
  KERNEL(_scaled_dot_product_flash_attention, lower_precision_fp)
  KERNEL(scaled_dot_product_attention, lower_precision_fp)

  // fp32
  KERNEL(acos, fp32)
  KERNEL(asin, fp32)
  KERNEL(cosh, fp32)
  KERNEL(erfinv, fp32)
  KERNEL(exp, fp32)
  KERNEL(expm1, fp32)
  KERNEL(log, fp32)
  KERNEL(log10, fp32)
  KERNEL(log2, fp32)
  KERNEL(log1p, fp32)
  KERNEL(reciprocal, fp32)
  KERNEL(rsqrt, fp32)
  KERNEL(sinh, fp32)
  KERNEL(tan, fp32)
  KERNEL2(pow, Tensor_Scalar, fp32)
  KERNEL2(pow, Tensor_Tensor, fp32)
  KERNEL2(pow, Scalar, fp32)
  KERNEL(softplus, fp32)
  KERNEL(layer_norm, fp32)
  KERNEL(native_layer_norm, fp32)
  KERNEL(group_norm, fp32)
  KERNEL2(frobenius_norm, dim, fp32)
  KERNEL(nuclear_norm, fp32)
  KERNEL2(nuclear_norm, dim, fp32)
  KERNEL(cosine_similarity, fp32)
  KERNEL(poisson_nll_loss, fp32)
  KERNEL(cosine_embedding_loss, fp32)
  KERNEL(nll_loss, fp32)
  KERNEL(nll_loss2d, fp32)
  KERNEL(hinge_embedding_loss, fp32)
  KERNEL(kl_div, fp32)
  KERNEL(l1_loss, fp32)
  KERNEL(smooth_l1_loss, fp32)
  KERNEL(huber_loss, fp32)
  KERNEL(mse_loss, fp32)
  KERNEL(margin_ranking_loss, fp32)
  KERNEL(multilabel_margin_loss, fp32)
  KERNEL(soft_margin_loss, fp32)
  KERNEL(triplet_margin_loss, fp32)
  KERNEL(multi_margin_loss, fp32)
  KERNEL(binary_cross_entropy_with_logits, fp32)
  KERNEL(dist, fp32)
  KERNEL(pdist, fp32)
  KERNEL(cdist, fp32)
  KERNEL(renorm, fp32)
  KERNEL(logsumexp, fp32)
  // fp32_set_opt_dtype
  KERNEL(prod, fp32_set_opt_dtype)
  KERNEL2(prod, dim_int, fp32_set_opt_dtype)
  KERNEL2(prod, dim_Dimname, fp32_set_opt_dtype)
  KERNEL2(softmax, int, fp32_set_opt_dtype)
  KERNEL2(softmax, Dimname, fp32_set_opt_dtype)
  KERNEL2(log_softmax, int, fp32_set_opt_dtype)
  KERNEL2(log_softmax, Dimname, fp32_set_opt_dtype)
  KERNEL(cumprod, fp32_set_opt_dtype)
  KERNEL2(cumprod, dimname, fp32_set_opt_dtype)
  KERNEL(cumsum, fp32_set_opt_dtype)
  KERNEL2(cumsum, dimname, fp32_set_opt_dtype)
  KERNEL(linalg_vector_norm, fp32_set_opt_dtype)
  KERNEL(linalg_matrix_norm, fp32_set_opt_dtype)
  KERNEL2(linalg_matrix_norm, str_ord, fp32_set_opt_dtype)
  // commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
  // when autocasting.
  // KERNEL2(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
  // KERNEL2(norm, ScalarOpt_dim_dtype, fp32_set_opt_dtype)
  // KERNEL2(norm, names_ScalarOpt_dim_dtype, fp32_set_opt_dtype)
  KERNEL(sum, fp32_set_opt_dtype)
  KERNEL2(sum, dim_IntList, fp32_set_opt_dtype)
  KERNEL2(sum, dim_DimnameList, fp32_set_opt_dtype)
  // fp32_append_dtype
  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
  // norm does not implicitly promote, but be aware when adding new ops to this policy.
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, const Scalar&), Tensor (const Tensor &, const c10::optional<Scalar>&, ScalarType), fp32_append_dtype)
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool, ScalarType), fp32_append_dtype)
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool, ScalarType), fp32_append_dtype)
  // promote
  KERNEL(addcdiv, promote)
  KERNEL(addcmul, promote)
  KERNEL(atan2, promote)
  KERNEL(bilinear, promote)
  KERNEL(cross, promote)
  KERNEL(dot, promote)
  KERNEL(grid_sampler, promote)
  KERNEL(index_put, promote)
  KERNEL(tensordot, promote)
  KERNEL(scatter_add, promote)

  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
  // lower_precision_fp cast policy
  KERNEL_CPU(conv1d, lower_precision_fp)
  KERNEL_CPU2(conv1d, padding, lower_precision_fp)
  KERNEL_CPU(conv2d, lower_precision_fp)
  KERNEL_CPU2(conv2d, padding, lower_precision_fp)
  KERNEL_CPU(conv3d, lower_precision_fp)
  KERNEL_CPU2(conv3d, padding, lower_precision_fp)
  KERNEL_CPU(bmm, lower_precision_fp)
  KERNEL_CPU(mm, lower_precision_fp)
  KERNEL_CPU(baddbmm, lower_precision_fp)
  KERNEL_CPU(addmm, lower_precision_fp)
  KERNEL_CPU(addbmm, lower_precision_fp)
  KERNEL_CPU(linear, lower_precision_fp)
  KERNEL_CPU2(_convolution, deprecated, lower_precision_fp)
  KERNEL_CPU(matmul, lower_precision_fp)
  KERNEL_CPU(conv_tbc, lower_precision_fp)
  KERNEL_CPU(mkldnn_rnn_layer, lower_precision_fp)
  KERNEL_CPU(conv_transpose1d, lower_precision_fp)
  KERNEL_CPU2(conv_transpose2d, input, lower_precision_fp)
  KERNEL_CPU2(conv_transpose3d, input, lower_precision_fp)

  // fp32 cast policy
  KERNEL_CPU(avg_pool3d, fp32)
  KERNEL_CPU(binary_cross_entropy, fp32)
  KERNEL_CPU(grid_sampler, fp32)
  KERNEL_CPU(polar, fp32)
  KERNEL_CPU(prod, fp32)
  KERNEL_CPU2(prod, dim_int, fp32)
  KERNEL_CPU2(prod, dim_Dimname, fp32)
  KERNEL_CPU(quantile, fp32)
  KERNEL_CPU2(quantile, scalar, fp32)
  KERNEL_CPU(nanquantile, fp32)
  KERNEL_CPU2(nanquantile, scalar, fp32)
  KERNEL_CPU(stft, fp32)
  KERNEL_CPU2(stft, center, fp32)
  KERNEL_CPU(cdist, fp32)
  KERNEL_CPU(grid_sampler_2d, fp32)
  KERNEL_CPU(_grid_sampler_2d_cpu_fallback, fp32)
  KERNEL_CPU(grid_sampler_3d, fp32)
  KERNEL_CPU(trace, fp32)
  KERNEL_CPU(view_as_complex, fp32)
  KERNEL_CPU(cholesky, fp32)
  KERNEL_CPU(cholesky_inverse, fp32)
  KERNEL_CPU(cholesky_solve, fp32)
  KERNEL_CPU(inverse, fp32)
  KERNEL_CPU(lu_solve, fp32)
  KERNEL_CPU(orgqr, fp32)
  KERNEL_CPU(ormqr, fp32)
  KERNEL_CPU(pinverse, fp32)
  KERNEL_CPU(max_pool3d, fp32)
  KERNEL_CPU(max_unpool2d, fp32)
  KERNEL_CPU(max_unpool3d, fp32)
  KERNEL_CPU(adaptive_avg_pool3d, fp32)
  KERNEL_CPU(reflection_pad1d, fp32)
  KERNEL_CPU(reflection_pad2d, fp32)
  KERNEL_CPU(replication_pad1d, fp32)
  KERNEL_CPU(replication_pad2d, fp32)
  KERNEL_CPU(replication_pad3d, fp32)
  KERNEL_CPU(mse_loss, fp32)
  KERNEL_CPU(cosine_embedding_loss, fp32)
  KERNEL_CPU(nll_loss, fp32)
  KERNEL_CPU(nll_loss2d, fp32)
  KERNEL_CPU(hinge_embedding_loss, fp32)
  KERNEL_CPU(poisson_nll_loss, fp32)
  KERNEL_CPU(smooth_l1_loss, fp32)
  KERNEL_CPU(cross_entropy_loss, fp32)
  KERNEL_CPU(l1_loss, fp32)
  KERNEL_CPU(huber_loss, fp32)
  KERNEL_CPU(margin_ranking_loss, fp32)
  KERNEL_CPU(soft_margin_loss, fp32)
  KERNEL_CPU(triplet_margin_loss, fp32)
  KERNEL_CPU(multi_margin_loss, fp32)
  KERNEL_CPU2(ctc_loss, IntList, fp32)
  KERNEL_CPU2(ctc_loss, Tensor, fp32)
  KERNEL_CPU(kl_div, fp32)
  KERNEL_CPU(multilabel_margin_loss, fp32)
  KERNEL_CPU(binary_cross_entropy_with_logits, fp32)
  KERNEL_CPU(fft_fft, fp32)
  KERNEL_CPU(fft_ifft, fp32)
  KERNEL_CPU(fft_fft2, fp32)
  KERNEL_CPU(fft_ifft2, fp32)
  KERNEL_CPU(fft_fftn, fp32)
  KERNEL_CPU(fft_ifftn, fp32)
  KERNEL_CPU(fft_rfft, fp32)
  KERNEL_CPU(fft_irfft, fp32)
  KERNEL_CPU(fft_rfft2, fp32)
  KERNEL_CPU(fft_irfft2, fp32)
  KERNEL_CPU(fft_rfftn, fp32)
  KERNEL_CPU(fft_irfftn, fp32)
  KERNEL_CPU(fft_hfft, fp32)
  KERNEL_CPU(fft_ihfft, fp32)
  KERNEL_CPU(linalg_cond, fp32)
  KERNEL_CPU2(linalg_cond, p_str, fp32)
  KERNEL_CPU(linalg_matrix_rank, fp32)
  KERNEL_CPU2(linalg_matrix_rank, tol_tensor, fp32)
  KERNEL_CPU2(linalg_matrix_rank, atol_rtol_tensor, fp32)
  KERNEL_CPU2(linalg_matrix_rank, atol_rtol_float, fp32)
  KERNEL_CPU(linalg_solve, fp32)
  KERNEL_CPU(linalg_cholesky, fp32)
  KERNEL_CPU(linalg_svdvals, fp32)
  KERNEL_CPU(linalg_eigvals, fp32)
  KERNEL_CPU(linalg_eigvalsh, fp32)
  KERNEL_CPU(linalg_inv, fp32)
  KERNEL_CPU(linalg_householder_product, fp32)
  KERNEL_CPU(linalg_tensorinv, fp32)
  KERNEL_CPU(linalg_tensorsolve, fp32)
  KERNEL_CPU(fake_quantize_per_tensor_affine, fp32)
  KERNEL_CPU(geqrf, fp32)
  KERNEL_CPU(_lu_with_info, fp32)
  KERNEL_CPU(qr, fp32)
  KERNEL_CPU(svd, fp32)
  KERNEL_CPU(triangular_solve, fp32)
  KERNEL_CPU(fractional_max_pool2d, fp32)
  KERNEL_CPU(fractional_max_pool3d, fp32)
  KERNEL_CPU(adaptive_max_pool3d, fp32)
  KERNEL_CPU(multilabel_margin_loss_forward, fp32)
  KERNEL_CPU(linalg_qr, fp32)
  KERNEL_CPU(linalg_cholesky_ex, fp32)
  KERNEL_CPU(linalg_svd, fp32)
  KERNEL_CPU(linalg_eig, fp32)
  KERNEL_CPU(linalg_eigh, fp32)
  KERNEL_CPU(linalg_lstsq, fp32)
  KERNEL_CPU(linalg_inv_ex, fp32)

  // promote
  KERNEL_CPU(stack, promote)
  KERNEL_CPU(cat, promote)
  KERNEL_CPU(index_copy, promote)
  KERNEL_CPU2(index_copy, dimname, promote)
}
} // namespace
} // namespace autocast
} // namespace at