1 | #include <ATen/ThreadLocalState.h> |
2 | |
3 | #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) |
4 | #include <ATen/core/grad_mode.h> |
5 | #endif |
6 | |
7 | #include <ATen/record_function.h> |
8 | #include <ATen/SavedTensorHooks.h> |
9 | #include <ATen/FunctionalTensorWrapper.h> |
10 | |
11 | namespace at { |
12 | |
13 | ThreadLocalState::ThreadLocalState() |
14 | : dispatch_key_(c10::impl::tls_local_dispatch_key_set()), |
15 | debug_info_(c10::ThreadLocalDebugInfo::current()), |
16 | rf_tls_(at::get_record_function_tls_()), functorch_tls_(functorch::getCopyOfFuncTorchTLS()), |
17 | autograd_tls_(c10::AutogradState::get_tls_state()), |
18 | torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()), |
19 | python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()), |
20 | saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()), |
21 | saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {} |
22 | |
23 | void ThreadLocalState::set_grad_mode(bool enabled) { |
24 | autograd_tls_.set_grad_mode(enabled); |
25 | } |
26 | |
27 | void ThreadLocalState::set_multithreading_enabled(bool enabled) { |
28 | autograd_tls_.set_multithreading_enabled(enabled); |
29 | } |
30 | |
31 | /* static */ |
32 | void ThreadLocalState::setThreadLocalState( |
33 | const ThreadLocalState& state) { |
34 | // Note that setting the InferenceMode TLS in this function is ONLY ok because we always |
35 | // restore the dispatch key set TLS at the same time. |
36 | c10::AutogradState::set_tls_state(state.autograd_tls_); |
37 | |
38 | c10::impl::TorchDispatchModeTLS::set_state(state.torch_dispatch_mode_state_); |
39 | |
40 | at::impl::PythonTorchFunctionTLS::set_state(state.python_torch_function_state_); |
41 | |
42 | at::set_record_function_tls_(state.rf_tls_); |
43 | |
44 | at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_); |
45 | |
46 | c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_); |
47 | |
48 | c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_); |
49 | |
50 | c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_); |
51 | |
52 | functorch::setFuncTorchTLS(state.functorch_tls_); |
53 | |
54 | at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_); |
55 | |
56 | at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_); |
57 | } |
58 | |
59 | } // namespace at |
60 | |