1#include <ATen/ThreadLocalState.h>
2
3#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
4#include <ATen/core/grad_mode.h>
5#endif
6
7#include <ATen/record_function.h>
8#include <ATen/SavedTensorHooks.h>
9#include <ATen/FunctionalTensorWrapper.h>
10
11namespace at {
12
13ThreadLocalState::ThreadLocalState()
14 : dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
15 debug_info_(c10::ThreadLocalDebugInfo::current()),
16 rf_tls_(at::get_record_function_tls_()), functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
17 autograd_tls_(c10::AutogradState::get_tls_state()),
18 torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
19 python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
20 saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
21 saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {}
22
23void ThreadLocalState::set_grad_mode(bool enabled) {
24 autograd_tls_.set_grad_mode(enabled);
25}
26
27void ThreadLocalState::set_multithreading_enabled(bool enabled) {
28 autograd_tls_.set_multithreading_enabled(enabled);
29}
30
// Installs every piece of thread-local state captured in `state` onto the
// calling thread. There is one setter call per member that the
// ThreadLocalState constructor captures.
//
// @param state  A snapshot previously produced by the ThreadLocalState
//               constructor (possibly modified via set_grad_mode() /
//               set_multithreading_enabled()).
//
// NOTE(review): keep the statement order as-is. The dispatch key set is
// force-restored after several other setters, and the note below depends on
// that restore happening in this same function. If any earlier setter also
// touches the local dispatch key set TLS, the later force-restore would
// overwrite it. Confirm before reordering.
/* static */
void ThreadLocalState::setThreadLocalState(
    const ThreadLocalState& state) {
  // Note that setting the InferenceMode TLS in this function is ONLY ok because we always
  // restore the dispatch key set TLS at the same time.
  // Restore autograd flags (grad mode, multithreading, ...).
  c10::AutogradState::set_tls_state(state.autograd_tls_);

  // Restore the TorchDispatchMode state.
  c10::impl::TorchDispatchModeTLS::set_state(state.torch_dispatch_mode_state_);

  // Restore the Python __torch_function__ state.
  at::impl::PythonTorchFunctionTLS::set_state(state.python_torch_function_state_);

  // Restore RecordFunction (profiling/observer) TLS.
  at::set_record_function_tls_(state.rf_tls_);

  // Restore the default saved-tensor hooks.
  at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_);

  // Restore the Python dispatcher state.
  c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);

  // Install the captured debug info as the current one.
  c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);

  // Restore the local dispatch key set. Per the note at the top of this
  // function, this is what makes restoring the autograd (InferenceMode) TLS
  // above acceptable.
  c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);

  // Restore functorch TLS.
  functorch::setFuncTorchTLS(state.functorch_tls_);

  // Restore the functionalization "reapply views" setting.
  at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);

  // Restore thread-local Python objects.
  at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
}
58
59} // namespace at
60