1 | #pragma once |
2 | |
3 | #include <c10/core/AutogradState.h> |
4 | #include <c10/core/GradMode.h> |
5 | #include <c10/core/impl/LocalDispatchKeySet.h> |
6 | #include <c10/macros/Macros.h> |
7 | |
8 | namespace c10 { |
9 | |
10 | // A RAII, thread local (!) guard that enables or disables inference mode upon |
11 | // construction, and sets it back to the original value upon destruction. |
12 | struct C10_API InferenceMode { |
13 | // Note [Expected TLS state in InferenceMode]: |
14 | // InferenceMode: ADInplaceOrView not in |
15 | // raw_local_dispatch_key_set.included(), |
16 | // Autograd in raw_local_dispatch_key_set.excluded() |
17 | // GradMode is disabled. |
18 | // NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(), |
19 | // Autograd not in raw_local_dispatch_key_set.excluded() |
20 | // GradMode is enabled by default unless toggled manually |
21 | // through other APIs, e.g. NoGradGuard. |
22 | // |
23 | // Invariant: |
24 | // - ADInplaceOrView is never in the excluded set |
25 | // - Autograd is never in the included set |
26 | // - Setting InferenceMode will set GradMode accordingly, but not vice versa. |
27 | // |
28 | // 1. Why do we put ADInplaceOrView in included set outside InferenceMode? |
29 | // |
30 | // Inplace update to inference tensor outside InferenceMode is not |
31 | // allowed. See Note [Inplace update inference tensor] for more details. |
32 | // Without going through ADInplaceOrView kernel, we cannot throw error |
33 | // for `inference_tensor.add_(1)` case. |
34 | // |
35 | // 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode? |
36 | // |
37 | // For example: |
38 | // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); |
39 | // torch::Tensor k = a + 2; |
40 | // { |
41 | // c10::InferenceMode guard(true); |
42 | // k.add_(2); |
43 | // } |
44 | // `k.add_(2)` still need to go through ADInplaceOrView kernel so that it's |
45 | // prepared for future autograd. |
46 | // |
47 | // 3. Why does setting InferenceMode also set GradMode? |
48 | // |
49 | // This is required since InferenceMode is a faster and more restricive |
50 | // version of NoGradGuard. All runtime checks using GradMode::is_enabled() |
51 | // are applicable to InferenceMode as well, e.g. |
52 | // `tensorTypeInCurrentExecutionContext` in interpreter.cpp. |
53 | InferenceMode(bool enabled = true) |
54 | : prev_mode(AutogradState::get_tls_state()), |
55 | prev_keyset(c10::impl::tls_local_dispatch_key_set()) { |
56 | // Enabling inference mode means disabling grad modes |
57 | // And disabling inference mode means enabling grad modes |
58 | AutogradState::set_tls_state(AutogradState( |
59 | /* grad_mode */ !enabled, |
60 | /* inference_mode */ enabled, |
61 | /* fw_grad_mode */ !enabled, |
62 | /* multithreading_enabled*/ !enabled)); |
63 | DispatchKeySet included = enabled |
64 | ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView) |
65 | : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView); |
66 | DispatchKeySet excluded = enabled |
67 | ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) |
68 | : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); |
69 | c10::impl::PODLocalDispatchKeySet cur_keyset{}; |
70 | cur_keyset.set_included(included); |
71 | cur_keyset.set_excluded(excluded); |
72 | c10::impl::_force_tls_local_dispatch_key_set(cur_keyset); |
73 | } |
74 | |
75 | ~InferenceMode() { |
76 | AutogradState::set_tls_state(prev_mode); |
77 | c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); |
78 | } |
79 | static bool is_enabled(); |
80 | |
81 | private: |
82 | AutogradState prev_mode; |
83 | c10::impl::LocalDispatchKeySet prev_keyset; |
84 | }; |
85 | } // namespace c10 |
86 | |