1 | #include <ATen/FuncTorchTLS.h> |
---|---|
2 | |
3 | namespace at { namespace functorch { |
4 | |
5 | namespace { |
6 | |
7 | thread_local std::unique_ptr<FuncTorchTLSBase> kFuncTorchTLS = nullptr; |
8 | |
9 | } |
10 | |
11 | std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS() { |
12 | if (kFuncTorchTLS == nullptr) { |
13 | return nullptr; |
14 | } |
15 | return kFuncTorchTLS->deepcopy(); |
16 | } |
17 | |
18 | void setFuncTorchTLS(const std::shared_ptr<const FuncTorchTLSBase>& state) { |
19 | if (state == nullptr) { |
20 | kFuncTorchTLS = nullptr; |
21 | return; |
22 | } |
23 | kFuncTorchTLS = state->deepcopy(); |
24 | } |
25 | |
26 | std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor() { |
27 | return kFuncTorchTLS; |
28 | } |
29 | |
30 | |
31 | }} |
32 |