1#pragma once
2
3#include <c10/core/AutogradState.h>
4#include <c10/core/GradMode.h>
5#include <c10/core/impl/LocalDispatchKeySet.h>
6#include <c10/macros/Macros.h>
7
8namespace c10 {
9
10// A RAII, thread local (!) guard that enables or disables inference mode upon
11// construction, and sets it back to the original value upon destruction.
12struct C10_API InferenceMode {
13 // Note [Expected TLS state in InferenceMode]:
14 // InferenceMode: ADInplaceOrView not in
15 // raw_local_dispatch_key_set.included(),
16 // Autograd in raw_local_dispatch_key_set.excluded()
17 // GradMode is disabled.
18 // NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(),
19 // Autograd not in raw_local_dispatch_key_set.excluded()
20 // GradMode is enabled by default unless toggled manually
21 // through other APIs, e.g. NoGradGuard.
22 //
23 // Invariant:
24 // - ADInplaceOrView is never in the excluded set
25 // - Autograd is never in the included set
26 // - Setting InferenceMode will set GradMode accordingly, but not vice versa.
27 //
28 // 1. Why do we put ADInplaceOrView in included set outside InferenceMode?
29 //
30 // Inplace update to inference tensor outside InferenceMode is not
31 // allowed. See Note [Inplace update inference tensor] for more details.
32 // Without going through ADInplaceOrView kernel, we cannot throw error
33 // for `inference_tensor.add_(1)` case.
34 //
35 // 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode?
36 //
37 // For example:
38 // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
39 // torch::Tensor k = a + 2;
40 // {
41 // c10::InferenceMode guard(true);
42 // k.add_(2);
43 // }
44 // `k.add_(2)` still need to go through ADInplaceOrView kernel so that it's
45 // prepared for future autograd.
46 //
47 // 3. Why does setting InferenceMode also set GradMode?
48 //
49 // This is required since InferenceMode is a faster and more restricive
50 // version of NoGradGuard. All runtime checks using GradMode::is_enabled()
51 // are applicable to InferenceMode as well, e.g.
52 // `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
53 InferenceMode(bool enabled = true)
54 : prev_mode(AutogradState::get_tls_state()),
55 prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
56 // Enabling inference mode means disabling grad modes
57 // And disabling inference mode means enabling grad modes
58 AutogradState::set_tls_state(AutogradState(
59 /* grad_mode */ !enabled,
60 /* inference_mode */ enabled,
61 /* fw_grad_mode */ !enabled,
62 /* multithreading_enabled*/ !enabled));
63 DispatchKeySet included = enabled
64 ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
65 : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
66 DispatchKeySet excluded = enabled
67 ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
68 : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
69 c10::impl::PODLocalDispatchKeySet cur_keyset{};
70 cur_keyset.set_included(included);
71 cur_keyset.set_excluded(excluded);
72 c10::impl::_force_tls_local_dispatch_key_set(cur_keyset);
73 }
74
75 ~InferenceMode() {
76 AutogradState::set_tls_state(prev_mode);
77 c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
78 }
79 static bool is_enabled();
80
81 private:
82 AutogradState prev_mode;
83 c10::impl::LocalDispatchKeySet prev_keyset;
84};
85} // namespace c10
86