1 | #pragma once |
2 | |
#include <c10/macros/Macros.h>
#include <cstdint>
#include <memory>
5 | |
6 | namespace at { |
7 | namespace functorch { |
8 | |
9 | // NOTE [functorch TLS in pytorch/pytorch] |
10 | // |
11 | // functorch lives out-of-tree. However, it has some TLS that needs to be |
12 | // propagated. The solution for that is we store a pointer to the TLS |
13 | // inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to |
14 | // include whatever functorch needs. |
15 | // |
16 | // We need to store a pointer due to the indirection: |
17 | // inside functorch, we will create a subclass of FunctorchTLSBase called |
18 | // FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack. |
19 | // FuncTorchTLSBase doesn't have any metadata because it hasn't been defined |
20 | // yet. |
21 | // |
22 | // Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside |
23 | // functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*. |
24 | // We can't directly pass around FunctorchTLSBase (without a pointer) because |
25 | // FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having |
26 | // more elements. |
27 | struct TORCH_API FuncTorchTLSBase { |
28 | virtual ~FuncTorchTLSBase() = default; |
29 | virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0; |
30 | |
31 | virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0; |
32 | virtual void checkSupportsInplaceRequiresGrad() const = 0; |
33 | virtual void checkSupportsRetainGrad() const = 0; |
34 | }; |
35 | |
36 | // returns deepcopy of the functorch tls |
37 | TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS(); |
38 | |
39 | // sets the functorch tls. always does a deep copy. |
40 | TORCH_API void setFuncTorchTLS( |
41 | const std::shared_ptr<const FuncTorchTLSBase>& state); |
42 | |
43 | // get a mutable reference to the functorch tls |
44 | TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor(); |
45 | |
46 | } // namespace functorch |
47 | } // namespace at |
48 | |