1 | |
2 | #include <torch/csrc/jit/passes/autocast.h> |
3 | |
4 | #include <ATen/autocast_mode.h> |
5 | #include <c10/core/ScalarType.h> |
6 | #include <c10/util/Exception.h> |
7 | #include <c10/util/Optional.h> |
8 | #include <torch/csrc/jit/ir/ir.h> |
9 | #include <torch/csrc/jit/jit_log.h> |
10 | #include <torch/csrc/jit/passes/cuda_graph_fuser.h> |
11 | #include <torch/csrc/jit/passes/quantization/helper.h> |
12 | |
13 | #include <stack> |
14 | #include <unordered_set> |
15 | #include <vector> |
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | |
20 | namespace { |
21 | |
// Process-wide switch for the JIT autocast pass; toggled via
// setAutocastMode() and queried via autocastEnabled() below.
bool autocast_enabled = true;
23 | |
24 | struct AutocastContext { |
25 | bool gpu_enabled = false; |
26 | bool cpu_enabled = false; |
27 | c10::ScalarType gpu_scalar_type = c10::ScalarType::Undefined; |
28 | c10::ScalarType cpu_scalar_type = c10::ScalarType::Undefined; |
29 | |
30 | operator bool() const { |
31 | return gpu_enabled || cpu_enabled; |
32 | } |
33 | }; |
34 | |
// Associates the IR value of an autocast instance with the autocast state
// that its `with` scope establishes.
struct AutocastScope {
  // The prim::CreateObject value for the autocast object (set by
  // parseAutocast).
  Value* instance = nullptr;
  // The state in effect inside this scope.
  AutocastContext context;
  // Placeholder for merging with the enclosing scope's state; currently a
  // no-op (parseAutocast seeds `context` from the parent instead).
  void stack(const AutocastContext& parent_context) {}
};
40 | |
41 | bool isAutocastNode(Value* value) { |
42 | const auto class_name = getModuleName(value); |
43 | return class_name.has_value() && |
44 | (*class_name == "__torch__.torch.cuda.amp.autocast_mode.autocast" || |
45 | *class_name == "__torch__.torch.cpu.amp.autocast_mode.autocast" || |
46 | *class_name == "__torch__.torch.amp.autocast_mode.autocast" ); |
47 | } |
48 | |
// If we have an autocast instance, return it
//
// This is the pattern we're looking for (this is done after
// autocast.__init__() has been inlined)
//
// %4 : bool = prim::Constant[value=1]()
// %5 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
//  = prim::SetAttr[name="_enabled"](%5, %4)
//
// Notes:
//  1. There's no guarantee that the autocast instance is in the same block
//    as the prim::Enter() node
//  2. `prim::SetAttr` must follow `prim::CreateObject()` in the same block,
//    but there might be other nodes in between
//
// Returns nullopt when `value` is not an autocast instance at all; throws
// when it is an autocast but not in the supported static form (constant
// _enabled / device / fast_dtype attributes set on a fresh object).
c10::optional<AutocastScope> parseAutocast(
    Value* value,
    const AutocastContext& context) {
  if (!isAutocastNode(value)) {
    // Not an autocast...
    return c10::nullopt;
  }
  if (value->node()->kind() == prim::CreateObject) {
    AutocastScope scope;
    scope.instance = value;
    // Start from the enclosing context so the device we don't touch below
    // keeps its inherited state.
    scope.context = context;
    c10::optional<bool> enabled;
    std::string device;
    c10::ScalarType dtype = c10::ScalarType::Undefined;
    // Scan every use of the object for the SetAttr calls produced by the
    // inlined autocast.__init__(); all three attributes must be constants.
    for (Use use : value->uses()) {
      // TODO: support runtime flag
      if (use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "_enabled") {
        // Search for `prim::SetAttr[name="_enabled"]`
        auto ret = constant_as<bool>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast _enabled argument must be a constant");
        enabled = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "device") {
        // Search for `prim::SetAttr[name="device"]`
        auto ret = constant_as<std::string>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast device argument must be a constant");
        device = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "fast_dtype") {
        // Search for `prim::SetAttr[name="fast_dtype"]`
        auto ret = constant_as<c10::ScalarType>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value() && ret.value() != c10::ScalarType::Undefined,
            "Autocast dtype argument must be a constant and defined");
        dtype = ret.value();
      }
    }
    TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
    TORCH_CHECK(
        dtype != c10::ScalarType::Undefined,
        "Autocast missing fast_dtype attribute");
    TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
    // Only the fields for the requested device are overridden; the other
    // device's state is inherited from `context` above.
    if (device == "cuda") {
      scope.context.gpu_enabled = enabled.value();
      scope.context.gpu_scalar_type = dtype;
    } else if (device == "cpu") {
      scope.context.cpu_enabled = enabled.value();
      scope.context.cpu_scalar_type = dtype;
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized device for autocast pass: ", device);
    }
    return scope;
  } else {
    // We only support simple and static autocast expressions. For example,
    // the following should report an error (since the autocast would not
    // work as expected)
    //
    //    autocast_on = autocast(enabled=True)
    //    autocast_off = autocast(enabled=False)
    //    with autocast_on if condition else autocast_off:
    //        ...
    //
    // TODO: better error message
    //
    AT_ERROR("Unsupported autocast syntax");
  }

  // Unreachable: both branches above return or throw.
  return c10::nullopt;
}
139 | |
// Routes every distinct tensor input of `node` through a freshly inserted
// `cast_op` node (one of aten::_autocast_to_full_precision /
// aten::_autocast_to_reduced_precision), then rewires `node` to consume the
// casted values. No-op when autocasting is disabled for both devices.
void castTensorInputs(
    Node* node,
    Symbol cast_op,
    const AutocastContext& context) {
  if (!context) {
    return;
  }

  const auto graph = node->owningGraph();

  std::unordered_set<Value*> casted_inputs;
  // need to also keep the inputs in order, otherwise tracing fails
  // sanity checks because casting ops are inserted in random order
  std::vector<Value*> casted_inputs_ordered;
  for (auto input : node->inputs()) {
    // TODO: update cast_op signature to take dynamic context flags
    auto input_tensor_type = input->type()->cast<TensorType>();
    // Skip non-tensor inputs, and inputs already produced by the same cast
    // op (avoids stacking redundant casts).
    if (input_tensor_type && input->node()->kind() != cast_op) {
      auto has_inserted = casted_inputs.insert(input);
      if (has_inserted.second) {
        casted_inputs_ordered.push_back(input);
      }
    }
  }

  // New cast nodes are inserted immediately before `node`.
  WithInsertPoint insert_point(node);

  for (auto input : casted_inputs_ordered) {
    if (cast_op == aten::_autocast_to_full_precision) {
      // Full-precision cast takes only the per-device enabled flags.
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled))});
      node->replaceInputWith(input, new_input);
    } else if (cast_op == aten::_autocast_to_reduced_precision) {
      // Reduced-precision cast additionally takes the per-device target
      // dtypes.
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled)),
           graph->insertConstant(IValue(context.gpu_scalar_type)),
           graph->insertConstant(IValue(context.cpu_scalar_type))});
      node->replaceInputWith(input, new_input);
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized cast_op symbol: ", cast_op.toQualString());
    }
  }
}
190 | |
191 | bool hasExplicitDtypeArgument(Node* node) { |
192 | if (node->hasNamedInput("dtype" )) { |
193 | Value* dtype_arg = node->namedInput("dtype" ); |
194 | return dtype_arg->type()->kind() != TypeKind::NoneType; |
195 | } |
196 | return false; |
197 | } |
198 | |
199 | void castInputsToWidestType(Node* node, const AutocastContext& context) { |
200 | if (!context) { |
201 | return; |
202 | } |
203 | // Figure out the widest type |
204 | // (really, just looking for any float32 inputs) |
205 | // |
206 | // TODO: revisit this (do we need to consider float64 types?) |
207 | // |
208 | for (auto input : node->inputs()) { |
209 | if (auto tensor_type = input->type()->cast<TensorType>()) { |
210 | const auto dtype = tensor_type->scalarType(); |
211 | if (!dtype.has_value() || *dtype == at::ScalarType::Float) { |
212 | castTensorInputs(node, aten::_autocast_to_full_precision, context); |
213 | return; |
214 | } |
215 | } |
216 | } |
217 | } |
218 | |
219 | // Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to |
220 | // determine whether autocasting is enabled. With JIT-scripted functions, we |
221 | // actually need to return true if eager autocast OR jit autocast are enabled. |
222 | // |
223 | // In the case where JIT autocast is enabled, we replace |
224 | // %x : bool = aten::is_autocast_enabled() |
225 | // with a constant "True". |
226 | // |
227 | // More context on eager vs JIT autocasting: |
228 | // |
229 | // Autocasting actually has two settings: eager autocasting, and JIT |
230 | // autocasting. Eager autocasting is the thread-local setting that turns on |
231 | // the relevant bit in the dispatcher settings. JIT autocasting is the pass |
232 | // implemented in this file, which makes changes to the graph to insert casting |
233 | // ops in order to achieve the same behavior as eager autocasting. |
234 | // |
235 | // If eager autocasting is enabled at the time when a JIT-scripted function is |
236 | // invoked, then autocasting will occur regardless of what the JIT-autocasting |
237 | // settings are. |
238 | void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) { |
239 | if (!is_jit_enabled) { |
240 | return; |
241 | } |
242 | |
243 | auto graph = node->owningGraph(); |
244 | |
245 | WithInsertPoint insert_point(node); |
246 | |
247 | Value* true_constant = graph->insertConstant(IValue(true)); |
248 | node->output()->replaceAllUsesWith(true_constant); |
249 | node->destroy(); |
250 | } |
251 | |
// [Note: implicit type promotion in Autocast]
//
// Casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp, with
// a few exceptions, e.g. `aten::add`, which is needed to be put to promotion
// list for JIT autocast.
// The reason is that in eager amp, some binary ops promote inputs implicitly
// inside the operation, e.g. `aten::add` with fp16 & fp32 inputs would both be
// casted to fp32. In backward, autograd would cast dgrad to match their
// scalar_type in forward graph. So inputs with mismatched scalar_type would
// get the different dgrad.
// While in JIT, autodiff doesn't do this, so implicit cast is not visible to
// autodiff and backward dgrad for mismatched inputs would ended up with dgrads
// in the same scalar_type. This has caused downstream operations, which
// expects dgrad to be the same scalar type to throw mismatch error.
//
// TODO: Use the list from AMP eager directly
//
// Core of the pass: walks the nodes of `block`, tracking the autocast state
// established by prim::Enter/prim::Exit pairs on autocast objects, and
// rewrites each node according to its casting policy. Recurses into
// sub-blocks with the state current at that point.
void handleBlock(Block* block, AutocastContext initial_state) {
  // Stack of nested `with autocast(...)` scopes currently open in this block.
  std::stack<AutocastScope> autocast_stack;

  // Tri-state flag used to detect mixing of constructs that are not
  // supported together: opaque calls (set to true) vs explicit autocast
  // scopes (set to false). Seeing both triggers an internal assert.
  c10::optional<bool> incompatible_amp = c10::nullopt;

  // The current autocast enabled/disabled state
  auto current_state = [&] {
    return autocast_stack.empty() ? initial_state
                                  : autocast_stack.top().context;
  };

  for (Node* node : block->nodes()) {
    switch (node->kind()) {
      case prim::CallFunction:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        TORCH_INTERNAL_ASSERT(
            !incompatible_amp.has_value() || incompatible_amp.value(),
            "Calls are not expected with AMP & JIT");
        incompatible_amp = true;
        break;

      case prim::CallMethod:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
          const auto& name = node->s(attr::name);
          const auto& function = class_type->getMethod(name);
          if (!function.isGraphFunction()) {
            // Non-graph (builtin/native) methods can't be autocast-rewritten.
            TORCH_INTERNAL_ASSERT(
                !incompatible_amp.has_value() || incompatible_amp.value(),
                "Calls are not expected with AMP & JIT");
            incompatible_amp = true;
          }
        } else {
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || incompatible_amp.value(),
              "Unexpected prim::CallMethod form with AMP & JIT");
          incompatible_amp = true;
        }
        break;

      case prim::Enter:
        // Entering a `with autocast(...)` scope pushes a new state.
        if (auto autocast_scope =
                parseAutocast(node->input(), current_state())) {
          if (node->hasUses()) {
            // TODO: better error message
            AT_ERROR("`with autocast() as ...` is not supported");
          }
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.push(*autocast_scope);
        }
        break;

      case prim::Exit:
        // Leaving a scope must match the innermost open autocast instance.
        if (isAutocastNode(node->input(0))) {
          TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
          TORCH_INTERNAL_ASSERT(autocast_stack.top().instance == node->input());
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.pop();
        }
        break;

      // If JIT autocast is on for the queried device, fold the runtime
      // query to the constant `true` (see updateAutocastEnabledCheck).
      case aten::is_autocast_enabled:
        updateAutocastEnabledCheck(node, current_state().gpu_enabled);
        break;

      case aten::is_autocast_cpu_enabled:
        updateAutocastEnabledCheck(node, current_state().cpu_enabled);
        break;

      // CastPolicy::fp16 (cast all inputs to float16)
      case aten::_convolution:
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
      case aten::conv_tbc:
      case aten::conv_transpose1d:
      case aten::convolution:
      case aten::cudnn_convolution:
      case aten::cudnn_convolution_transpose:
      case aten::prelu:
      case aten::addmm:
      case aten::addmv:
      case aten::addr:
      case aten::matmul:
      case aten::mm:
      case aten::mv:
      case aten::linear:
      case aten::addbmm:
      case aten::baddbmm:
      case aten::bmm:
      case aten::chain_matmul:
      case aten::_thnn_fused_lstm_cell:
      case aten::_thnn_fused_gru_cell:
      case aten::lstm_cell:
      case aten::gru_cell:
      case aten::rnn_tanh_cell:
      case aten::rnn_relu_cell:
        // Mutable (in-place) ops are skipped: casting their inputs would
        // redirect the mutation to a temporary.
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_reduced_precision, current_state());
        }
        break;

      // CastPolicy::fp32 (cast all inputs to float32)
      case aten::native_layer_norm:
      case aten::acos:
      case aten::asin:
      case aten::cosh:
      case aten::erfinv:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log10:
      case aten::log2:
      case aten::log1p:
      case aten::reciprocal:
      case aten::rsqrt:
      case aten::sinh:
      case aten::tan:
      case aten::pow:
      case aten::softplus:
      case aten::gelu:
      case aten::layer_norm:
      case aten::group_norm:
      case aten::frobenius_norm:
      case aten::nuclear_norm:
      case aten::cosine_similarity:
      case aten::cosine_embedding_loss:
      case aten::nll_loss:
      case aten::nll_loss2d:
      case aten::hinge_embedding_loss:
      case aten::kl_div:
      case aten::l1_loss:
      case aten::smooth_l1_loss:
      case aten::mse_loss:
      case aten::margin_ranking_loss:
      case aten::multilabel_margin_loss:
      case aten::soft_margin_loss:
      case aten::triplet_margin_loss:
      case aten::multi_margin_loss:
      case aten::binary_cross_entropy_with_logits:
      case aten::dist:
      case aten::pdist:
      case aten::cdist:
      case aten::renorm:
      case aten::logsumexp:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // CastPolicy::fp32_set_opt_dtype
      case aten::prod:
      case aten::log_softmax:
      case aten::cumprod:
      case aten::cumsum:
      case aten::sum:
        // An explicit dtype= argument overrides the autocast policy.
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // cast softmax to fp32 only on GPU
      case aten::softmax:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          auto context = current_state();
          context.cpu_enabled = false;
          castTensorInputs(node, aten::_autocast_to_full_precision, context);
        }
        break;

      // CastPolicy::promote (promote inputs to the widest type)
      case aten::addcdiv:
      case aten::addcmul:
      case aten::atan2:
      case aten::bilinear:
      case aten::cat:
      case aten::cross:
      case aten::dot:
      case aten::equal:
      case aten::index_put:
      case aten::stack:
      case aten::tensordot:
      // add, sub, mul, div were added to autocast jit, because aten implicit
      // type promotion is not visible to JIT and could cause dtype mismatch on
      // backward
      // see [Note: implicit type promotion in Autocast]
      case aten::add:
      case aten::sub:
      case aten::mul:
      case aten::div:
        if (!node->schema().is_mutable()) {
          castInputsToWidestType(node, current_state());
        }
        break;

      // Banned in autocast, see binary_cross_entropy_banned()
      case aten::binary_cross_entropy:
        if (current_state()) {
          AT_ERROR("Unsafe to autocast");
        }
    }

    // process sub-blocks, if any
    for (Block* sub_block : node->blocks()) {
      handleBlock(sub_block, current_state());
    }
  }

  // Sanity check: make sure there's no unbalanced transition
  TORCH_INTERNAL_ASSERT(autocast_stack.empty());
}
508 | |
509 | } // namespace |
510 | |
511 | bool setAutocastMode(bool value) { |
512 | auto old_value = autocast_enabled; |
513 | autocast_enabled = value; |
514 | return old_value; |
515 | } |
516 | |
// Returns whether the JIT autocast pass is currently enabled.
bool autocastEnabled() {
  return autocast_enabled;
}
520 | |
// Entry point of the pass: when JIT autocast is enabled, rewrite `graph`
// starting from the eager-mode autocast state active on the calling thread
// (so eager `with autocast()` scopes surrounding the scripted call are
// respected as the initial context).
void Autocast(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("\nBefore Autocast: ", graph);
  if (autocastEnabled()) {
    // Seed from ATen's thread-local autocast settings.
    AutocastContext init = {
        at::autocast::is_enabled(),
        at::autocast::is_cpu_enabled(),
        at::autocast::get_autocast_gpu_dtype(),
        at::autocast::get_autocast_cpu_dtype()};
    handleBlock(graph->block(), init);
  }
  GRAPH_DUMP("\nAfter Autocast: ", graph);
}
533 | |
534 | } // namespace jit |
535 | } // namespace torch |
536 | |