1 | #pragma once |
---|---|
2 | |
3 | #include <stack> |
4 | |
5 | #include <c10/core/InferenceMode.h> |
6 | #include <c10/core/impl/LocalDispatchKeySet.h> |
7 | #include <c10/util/Exception.h> |
8 | #include <c10/util/ThreadLocalDebugInfo.h> |
9 | |
10 | #include <ATen/FuncTorchTLS.h> |
11 | #include <ATen/PythonTorchFunctionTLS.h> |
12 | #include <ATen/SavedTensorHooks.h> |
13 | #include <ATen/ThreadLocalPythonObjects.h> |
14 | #include <ATen/record_function.h> |
15 | #include <c10/core/impl/PythonDispatcherTLS.h> |
16 | #include <c10/core/impl/TorchDispatchModeTLS.h> |
17 | |
18 | namespace at { |
19 | |
20 | // Thread local state contains values that are preserved across |
21 | // thread boundaries (e.g. at::launch/JIT fork, autograd). |
22 | // Note at::parallel_for doesn't preserve TLS across thread boundaries. |
23 | class TORCH_API ThreadLocalState { |
24 | public: |
25 | // Saves the thread local variables' values and |
26 | // returns them as a ThreadLocalState |
27 | ThreadLocalState(); |
28 | |
29 | // set_grad_mode - force the value of the grad mode TLS in |
30 | // the current state object. This is used for example in the |
31 | // autograd engine. |
32 | void set_grad_mode(bool enabled); |
33 | |
34 | // set_multithreading_enabled - force the value of the multithreadinmaximum |
35 | // threads TLS in |
36 | // the current state object. This is used for example in the |
37 | // autograd engine. |
38 | void set_multithreading_enabled(bool enabled); |
39 | |
40 | // Sets thread local variables in the current thread, |
41 | // according to the thread boundary specified |
42 | static void setThreadLocalState(const ThreadLocalState& state); |
43 | |
44 | private: |
45 | c10::impl::LocalDispatchKeySet dispatch_key_; |
46 | |
47 | // ThreadLocalDebugInfo does not change after being created |
48 | // with DebugInfoGuard |
49 | std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_; |
50 | |
51 | // RecordFunction TLS |
52 | RecordFunctionTLS rf_tls_; |
53 | |
54 | // TLS for out-of-tree functorch |
55 | // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a |
56 | // pointer (spoiler alert: it's due to the indirection) |
57 | // This needs to be a shared_ptr instead of a unique_ptr because |
58 | // ThreadLocalState is copy-able and does indeed get copied. Maybe we can |
59 | // consider adding an explicit copy constructor for ThreadLocalState in the |
60 | // future but I didn't want to add one just for this. |
61 | std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_; |
62 | |
63 | // TLS for AutogradModes |
64 | AutogradState autograd_tls_; |
65 | |
66 | // TLS for enable_torch_dispatch_mode |
67 | c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_; |
68 | |
69 | // TLS for enable_python_dispatcher |
70 | c10::impl::PyInterpreter* python_dispatcher_state_; |
71 | |
72 | // TLS for __torch_function__ (mode and disable_torch_function) |
73 | at::impl::PythonTorchFunctionTLS python_torch_function_state_; |
74 | |
75 | // TLS for saved tensors default hooks |
76 | at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; |
77 | |
78 | bool functionalization_reapply_views_state_; |
79 | |
80 | // TLS for arbitrary python objects that is registered via hooks |
81 | at::impl::ThreadLocalPythonObjects saved_objects_; |
82 | |
83 | friend class ThreadLocalStateGuard; |
84 | }; |
85 | |
86 | // Guard to set and reset the thread local state |
87 | class TORCH_API ThreadLocalStateGuard { |
88 | public: |
89 | explicit ThreadLocalStateGuard(const ThreadLocalState& state) |
90 | : prev_state_(ThreadLocalState()) { |
91 | // set the given state across the thread boundary |
92 | ThreadLocalState::setThreadLocalState(state); |
93 | } |
94 | |
95 | ~ThreadLocalStateGuard() { |
96 | // restore previously set variables |
97 | ThreadLocalState::setThreadLocalState(prev_state_); |
98 | } |
99 | |
100 | private: |
101 | const ThreadLocalState prev_state_; |
102 | }; |
103 | |
104 | template <typename T> |
105 | auto wrapPropagateTLSState(T callback) { |
106 | return [tls_state = ThreadLocalState(), |
107 | callback = std::move(callback)](auto&&... args) { |
108 | ThreadLocalStateGuard g(tls_state); |
109 | // Propagate value returned by callback(). |
110 | return callback(std::forward<decltype(args)>(args)...); |
111 | }; |
112 | } |
113 | |
114 | } // namespace at |
115 |