#include <torch/csrc/jit/passes/autocast.h>

#include <ATen/autocast_mode.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <torch/csrc/jit/passes/quantization/helper.h>

#include <stack>
#include <unordered_set>
#include <vector>

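// This pass makes TorchScript graphs follow eager-mode autocast (AMP)
// semantics: it recognizes autocast scopes (prim::Enter / prim::Exit pairs
// created from torch.cuda.amp.autocast / torch.cpu.amp.autocast /
// torch.amp.autocast instances) and rewrites the ops inside them by inserting
// explicit aten::_autocast_to_reduced_precision /
// aten::_autocast_to_full_precision casts on their tensor inputs, following
// the per-op casting policies implemented in handleBlock() below.
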
namespace torch {
namespace jit {

namespace {

bool autocast_enabled = true;

struct AutocastContext {
  bool gpu_enabled = false;
  bool cpu_enabled = false;
  c10::ScalarType gpu_scalar_type = c10::ScalarType::Undefined;
  c10::ScalarType cpu_scalar_type = c10::ScalarType::Undefined;

  operator bool() const {
    return gpu_enabled || cpu_enabled;
  }
};

struct AutocastScope {
  Value* instance = nullptr;
  AutocastContext context;
  void stack(const AutocastContext& parent_context) {}
};

bool isAutocastNode(Value* value) {
  const auto class_name = getModuleName(value);
  return class_name.has_value() &&
      (*class_name == "__torch__.torch.cuda.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.cpu.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.amp.autocast_mode.autocast");
}
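
// For reference, the kind of scripted code this pass targets looks roughly
// like the following (illustrative sketch only):
//
//   @torch.jit.script
//   def fn(a, b):
//       with torch.cuda.amp.autocast(enabled=True):
//           return torch.mm(a, b)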

// If we have an autocast instance, return it
//
// This is the pattern we're looking for (this is done after
// autocast.__init__() has been inlined)
//
// %4 : bool = prim::Constant[value=1]()
// %5 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
//  = prim::SetAttr[name="_enabled"](%5, %4)
//
// Notes:
//  1. There's no guarantee that the autocast instance is in the same block
//     as the prim::Enter() node
//  2. `prim::SetAttr` must follow `prim::CreateObject()` in the same block,
//     but there might be other nodes in between
//
c10::optional<AutocastScope> parseAutocast(
    Value* value,
    const AutocastContext& context) {
  if (!isAutocastNode(value)) {
    // Not an autocast...
    return c10::nullopt;
  }
  if (value->node()->kind() == prim::CreateObject) {
    AutocastScope scope;
    scope.instance = value;
    scope.context = context;
    c10::optional<bool> enabled;
    std::string device;
    c10::ScalarType dtype = c10::ScalarType::Undefined;
    for (Use use : value->uses()) {
      // TODO: support runtime flag
      if (use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "_enabled") {
        // Search for `prim::SetAttr[name="_enabled"]`
        auto ret = constant_as<bool>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast _enabled argument must be a constant");
        enabled = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "device") {
        // Search for `prim::SetAttr[name="device"]`
        auto ret = constant_as<std::string>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast device argument must be a constant");
        device = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "fast_dtype") {
        // Search for `prim::SetAttr[name="fast_dtype"]`
        auto ret = constant_as<c10::ScalarType>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value() && ret.value() != c10::ScalarType::Undefined,
            "Autocast dtype argument must be a constant and defined");
        dtype = ret.value();
      }
    }
    TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
    TORCH_CHECK(
        dtype != c10::ScalarType::Undefined,
        "Autocast missing fast_dtype attribute");
    TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
    if (device == "cuda") {
      scope.context.gpu_enabled = enabled.value();
      scope.context.gpu_scalar_type = dtype;
    } else if (device == "cpu") {
      scope.context.cpu_enabled = enabled.value();
      scope.context.cpu_scalar_type = dtype;
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized device for autocast pass: ", device);
    }
    return scope;
  } else {
    // We only support simple and static autocast expressions. For example,
    // the following should report an error (since the autocast would not
    // work as expected)
    //
    //   autocast_on = autocast(enabled=True)
    //   autocast_off = autocast(enabled=False)
    //   with autocast_on if condition else autocast_off:
    //       ...
    //
    // TODO: better error message
    //
    AT_ERROR("Unsupported autocast syntax");
  }

  return c10::nullopt;
}

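// Insert `cast_op` (aten::_autocast_to_reduced_precision or
// aten::_autocast_to_full_precision) in front of `node` for each distinct
// tensor input, then rewire the node to consume the casted values. Inputs
// that are already produced by the same cast op are left alone.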
void castTensorInputs(
    Node* node,
    Symbol cast_op,
    const AutocastContext& context) {
  if (!context) {
    return;
  }

  const auto graph = node->owningGraph();

  std::unordered_set<Value*> casted_inputs;
  // We also need to keep the inputs in order; otherwise tracing sanity checks
  // fail, because the casting ops would be inserted in arbitrary (set) order.
  std::vector<Value*> casted_inputs_ordered;
  for (auto input : node->inputs()) {
    // TODO: update cast_op signature to take dynamic context flags
    auto input_tensor_type = input->type()->cast<TensorType>();
    if (input_tensor_type && input->node()->kind() != cast_op) {
      auto has_inserted = casted_inputs.insert(input);
      if (has_inserted.second) {
        casted_inputs_ordered.push_back(input);
      }
    }
  }

  WithInsertPoint insert_point(node);

  for (auto input : casted_inputs_ordered) {
    if (cast_op == aten::_autocast_to_full_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled))});
      node->replaceInputWith(input, new_input);
    } else if (cast_op == aten::_autocast_to_reduced_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled)),
           graph->insertConstant(IValue(context.gpu_scalar_type)),
           graph->insertConstant(IValue(context.cpu_scalar_type))});
      node->replaceInputWith(input, new_input);
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized cast_op symbol: ", cast_op.toQualString());
    }
  }
}
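
// The rewrite above has roughly the following effect on the graph
// (illustrative sketch only; value names are made up):
//
//   before:  %out : Tensor = aten::mm(%a, %b)
//   after:   %a.1 : Tensor = aten::_autocast_to_reduced_precision(
//                %a, %gpu_enabled, %cpu_enabled, %gpu_dtype, %cpu_dtype)
//            %b.1 : Tensor = aten::_autocast_to_reduced_precision(
//                %b, %gpu_enabled, %cpu_enabled, %gpu_dtype, %cpu_dtype)
//            %out : Tensor = aten::mm(%a.1, %b.1)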

bool hasExplicitDtypeArgument(Node* node) {
  if (node->hasNamedInput("dtype")) {
    Value* dtype_arg = node->namedInput("dtype");
    return dtype_arg->type()->kind() != TypeKind::NoneType;
  }
  return false;
}

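// Helper for the promote policy: if any tensor input is fp32 (or has an
// unknown dtype), cast all tensor inputs to full precision; otherwise leave
// the node untouched.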
void castInputsToWidestType(Node* node, const AutocastContext& context) {
  if (!context) {
    return;
  }
  // Figure out the widest type
  // (really, just looking for any float32 inputs)
  //
  // TODO: revisit this (do we need to consider float64 types?)
  //
  for (auto input : node->inputs()) {
    if (auto tensor_type = input->type()->cast<TensorType>()) {
      const auto dtype = tensor_type->scalarType();
      if (!dtype.has_value() || *dtype == at::ScalarType::Float) {
        castTensorInputs(node, aten::_autocast_to_full_precision, context);
        return;
      }
    }
  }
}

// Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to
// determine whether autocasting is enabled. With JIT-scripted functions, we
// actually need to return true if eager autocast OR JIT autocast is enabled.
//
// In the case where JIT autocast is enabled, we replace
//    %x : bool = aten::is_autocast_enabled()
// with a constant "True".
//
// More context on eager vs JIT autocasting:
//
// Autocasting actually has two settings: eager autocasting and JIT
// autocasting. Eager autocasting is the thread-local setting that turns on
// the relevant bit in the dispatcher settings. JIT autocasting is the pass
// implemented in this file, which makes changes to the graph to insert casting
// ops in order to achieve the same behavior as eager autocasting.
//
// If eager autocasting is enabled at the time when a JIT-scripted function is
// invoked, then autocasting will occur regardless of what the JIT-autocasting
// settings are.
void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) {
  if (!is_jit_enabled) {
    return;
  }

  auto graph = node->owningGraph();

  WithInsertPoint insert_point(node);

  Value* true_constant = graph->insertConstant(IValue(true));
  node->output()->replaceAllUsesWith(true_constant);
  node->destroy();
}

// [Note: implicit type promotion in Autocast]
//
// The casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp,
// with a few exceptions, e.g. `aten::add`, which needs to be put on the
// promotion list for JIT autocast.
// The reason is that in eager AMP, some binary ops promote their inputs
// implicitly inside the operation, e.g. `aten::add` with fp16 & fp32 inputs
// casts both to fp32. In backward, autograd casts each dgrad to match the
// scalar_type of the corresponding input in the forward graph, so inputs with
// mismatched scalar_types receive dgrads of different scalar_types.
// JIT autodiff, however, does not do this: the implicit cast is not visible
// to autodiff, so the backward dgrads for mismatched inputs would end up with
// the same scalar_type. This has caused downstream operations, which expect
// the dgrads to match the forward scalar_types, to throw mismatch errors.
//
// TODO: Use the list from AMP eager directly
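//
// handleBlock walks a block (and, recursively, its sub-blocks), tracking the
// autocast scopes that are currently in effect, and applies the per-op
// casting policies below.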
void handleBlock(Block* block, AutocastContext initial_state) {
  std::stack<AutocastScope> autocast_stack;

  c10::optional<bool> incompatible_amp = c10::nullopt;

  // The current autocast enabled/disabled state
  auto current_state = [&] {
    return autocast_stack.empty() ? initial_state
                                  : autocast_stack.top().context;
  };

  for (Node* node : block->nodes()) {
    switch (node->kind()) {
      case prim::CallFunction:
        // TODO: limit this to AMP-related nodes
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        TORCH_INTERNAL_ASSERT(
            !incompatible_amp.has_value() || incompatible_amp.value(),
            "Calls are not expected with AMP & JIT");
        incompatible_amp = true;
        break;

      case prim::CallMethod:
        // TODO: limit this to AMP-related nodes
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
          const auto& name = node->s(attr::name);
          const auto& function = class_type->getMethod(name);
          if (!function.isGraphFunction()) {
            TORCH_INTERNAL_ASSERT(
                !incompatible_amp.has_value() || incompatible_amp.value(),
                "Calls are not expected with AMP & JIT");
            incompatible_amp = true;
          }
        } else {
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || incompatible_amp.value(),
              "Unexpected prim::CallMethod form with AMP & JIT");
          incompatible_amp = true;
        }
        break;

      case prim::Enter:
        if (auto autocast_scope =
                parseAutocast(node->input(), current_state())) {
          if (node->hasUses()) {
            // TODO: better error message
            AT_ERROR("`with autocast() as ...` is not supported");
          }
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.push(*autocast_scope);
        }
        break;

      case prim::Exit:
        if (isAutocastNode(node->input(0))) {
          TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
          TORCH_INTERNAL_ASSERT(autocast_stack.top().instance == node->input());
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.pop();
        }
        break;

      case aten::is_autocast_enabled:
        updateAutocastEnabledCheck(node, current_state().gpu_enabled);
        break;

      case aten::is_autocast_cpu_enabled:
        updateAutocastEnabledCheck(node, current_state().cpu_enabled);
        break;

      // CastPolicy::fp16 (cast all inputs to float16)
      case aten::_convolution:
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
      case aten::conv_tbc:
      case aten::conv_transpose1d:
      case aten::convolution:
      case aten::cudnn_convolution:
      case aten::cudnn_convolution_transpose:
      case aten::prelu:
      case aten::addmm:
      case aten::addmv:
      case aten::addr:
      case aten::matmul:
      case aten::mm:
      case aten::mv:
      case aten::linear:
      case aten::addbmm:
      case aten::baddbmm:
      case aten::bmm:
      case aten::chain_matmul:
      case aten::_thnn_fused_lstm_cell:
      case aten::_thnn_fused_gru_cell:
      case aten::lstm_cell:
      case aten::gru_cell:
      case aten::rnn_tanh_cell:
      case aten::rnn_relu_cell:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_reduced_precision, current_state());
        }
        break;

      // CastPolicy::fp32 (cast all inputs to float32)
      case aten::native_layer_norm:
      case aten::acos:
      case aten::asin:
      case aten::cosh:
      case aten::erfinv:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log10:
      case aten::log2:
      case aten::log1p:
      case aten::reciprocal:
      case aten::rsqrt:
      case aten::sinh:
      case aten::tan:
      case aten::pow:
      case aten::softplus:
      case aten::gelu:
      case aten::layer_norm:
      case aten::group_norm:
      case aten::frobenius_norm:
      case aten::nuclear_norm:
      case aten::cosine_similarity:
      case aten::cosine_embedding_loss:
      case aten::nll_loss:
      case aten::nll_loss2d:
      case aten::hinge_embedding_loss:
      case aten::kl_div:
      case aten::l1_loss:
      case aten::smooth_l1_loss:
      case aten::mse_loss:
      case aten::margin_ranking_loss:
      case aten::multilabel_margin_loss:
      case aten::soft_margin_loss:
      case aten::triplet_margin_loss:
      case aten::multi_margin_loss:
      case aten::binary_cross_entropy_with_logits:
      case aten::dist:
      case aten::pdist:
      case aten::cdist:
      case aten::renorm:
      case aten::logsumexp:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // CastPolicy::fp32_set_opt_dtype
      case aten::prod:
      case aten::log_softmax:
      case aten::cumprod:
      case aten::cumsum:
      case aten::sum:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // cast softmax to fp32 only on GPU
      case aten::softmax:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          auto context = current_state();
          context.cpu_enabled = false;
          castTensorInputs(node, aten::_autocast_to_full_precision, context);
        }
        break;

      // CastPolicy::promote (promote inputs to the widest type)
      case aten::addcdiv:
      case aten::addcmul:
      case aten::atan2:
      case aten::bilinear:
      case aten::cat:
      case aten::cross:
      case aten::dot:
      case aten::equal:
      case aten::index_put:
      case aten::stack:
      case aten::tensordot:
      // add, sub, mul, div were added to autocast jit, because aten implicit
      // type promotion is not visible to JIT and could cause dtype mismatch on
      // backward
      // see [Note: implicit type promotion in Autocast]
      case aten::add:
      case aten::sub:
      case aten::mul:
      case aten::div:
        if (!node->schema().is_mutable()) {
          castInputsToWidestType(node, current_state());
        }
        break;

      // Banned in autocast, see binary_cross_entropy_banned()
      case aten::binary_cross_entropy:
        if (current_state()) {
          AT_ERROR("Unsafe to autocast");
        }
    }

    // process sub-blocks, if any
    for (Block* sub_block : node->blocks()) {
      handleBlock(sub_block, current_state());
    }
  }

  // Sanity check: make sure there's no unbalanced transition
  TORCH_INTERNAL_ASSERT(autocast_stack.empty());
}

} // namespace

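// A minimal usage sketch from C++ (illustrative only; graph construction and
// error handling omitted):
//
//   std::shared_ptr<Graph> graph = ...;   // a scripted function's graph
//   torch::jit::setAutocastMode(true);    // enable the JIT autocast pass
//   torch::jit::Autocast(graph);          // insert the casting ops in place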
bool setAutocastMode(bool value) {
  auto old_value = autocast_enabled;
  autocast_enabled = value;
  return old_value;
}

bool autocastEnabled() {
  return autocast_enabled;
}

void Autocast(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("\nBefore Autocast: ", graph);
  if (autocastEnabled()) {
    AutocastContext init = {
        at::autocast::is_enabled(),
        at::autocast::is_cpu_enabled(),
        at::autocast::get_autocast_gpu_dtype(),
        at::autocast::get_autocast_cpu_dtype()};
    handleBlock(graph->block(), init);
  }
  GRAPH_DUMP("\nAfter Autocast: ", graph);
}

} // namespace jit
} // namespace torch