1#include <ATen/FuncTorchTLS.h>
2
3namespace at { namespace functorch {
4
5namespace {
6
7thread_local std::unique_ptr<FuncTorchTLSBase> kFuncTorchTLS = nullptr;
8
9}
10
11std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS() {
12 if (kFuncTorchTLS == nullptr) {
13 return nullptr;
14 }
15 return kFuncTorchTLS->deepcopy();
16}
17
18void setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase>& state) {
19 if (state == nullptr) {
20 kFuncTorchTLS = nullptr;
21 return;
22 }
23 kFuncTorchTLS = state->deepcopy();
24}
25
26std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor() {
27 return kFuncTorchTLS;
28}
29
30
31}}
32