1#pragma once
2
3#include <stack>
4
5#include <c10/core/InferenceMode.h>
6#include <c10/core/impl/LocalDispatchKeySet.h>
7#include <c10/util/Exception.h>
8#include <c10/util/ThreadLocalDebugInfo.h>
9
10#include <ATen/FuncTorchTLS.h>
11#include <ATen/PythonTorchFunctionTLS.h>
12#include <ATen/SavedTensorHooks.h>
13#include <ATen/ThreadLocalPythonObjects.h>
14#include <ATen/record_function.h>
15#include <c10/core/impl/PythonDispatcherTLS.h>
16#include <c10/core/impl/TorchDispatchModeTLS.h>
17
18namespace at {
19
20// Thread local state contains values that are preserved across
21// thread boundaries (e.g. at::launch/JIT fork, autograd).
22// Note at::parallel_for doesn't preserve TLS across thread boundaries.
23class TORCH_API ThreadLocalState {
24 public:
25 // Saves the thread local variables' values and
26 // returns them as a ThreadLocalState
27 ThreadLocalState();
28
29 // set_grad_mode - force the value of the grad mode TLS in
30 // the current state object. This is used for example in the
31 // autograd engine.
32 void set_grad_mode(bool enabled);
33
34 // set_multithreading_enabled - force the value of the multithreadinmaximum
35 // threads TLS in
36 // the current state object. This is used for example in the
37 // autograd engine.
38 void set_multithreading_enabled(bool enabled);
39
40 // Sets thread local variables in the current thread,
41 // according to the thread boundary specified
42 static void setThreadLocalState(const ThreadLocalState& state);
43
44 private:
45 c10::impl::LocalDispatchKeySet dispatch_key_;
46
47 // ThreadLocalDebugInfo does not change after being created
48 // with DebugInfoGuard
49 std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
50
51 // RecordFunction TLS
52 RecordFunctionTLS rf_tls_;
53
54 // TLS for out-of-tree functorch
55 // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
56 // pointer (spoiler alert: it's due to the indirection)
57 // This needs to be a shared_ptr instead of a unique_ptr because
58 // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
59 // consider adding an explicit copy constructor for ThreadLocalState in the
60 // future but I didn't want to add one just for this.
61 std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
62
63 // TLS for AutogradModes
64 AutogradState autograd_tls_;
65
66 // TLS for enable_torch_dispatch_mode
67 c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
68
69 // TLS for enable_python_dispatcher
70 c10::impl::PyInterpreter* python_dispatcher_state_;
71
72 // TLS for __torch_function__ (mode and disable_torch_function)
73 at::impl::PythonTorchFunctionTLS python_torch_function_state_;
74
75 // TLS for saved tensors default hooks
76 at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
77
78 bool functionalization_reapply_views_state_;
79
80 // TLS for arbitrary python objects that is registered via hooks
81 at::impl::ThreadLocalPythonObjects saved_objects_;
82
83 friend class ThreadLocalStateGuard;
84};
85
86// Guard to set and reset the thread local state
87class TORCH_API ThreadLocalStateGuard {
88 public:
89 explicit ThreadLocalStateGuard(const ThreadLocalState& state)
90 : prev_state_(ThreadLocalState()) {
91 // set the given state across the thread boundary
92 ThreadLocalState::setThreadLocalState(state);
93 }
94
95 ~ThreadLocalStateGuard() {
96 // restore previously set variables
97 ThreadLocalState::setThreadLocalState(prev_state_);
98 }
99
100 private:
101 const ThreadLocalState prev_state_;
102};
103
104template <typename T>
105auto wrapPropagateTLSState(T callback) {
106 return [tls_state = ThreadLocalState(),
107 callback = std::move(callback)](auto&&... args) {
108 ThreadLocalStateGuard g(tls_state);
109 // Propagate value returned by callback().
110 return callback(std::forward<decltype(args)>(args)...);
111 };
112}
113
114} // namespace at
115