1#pragma once
2
#include <c10/macros/Macros.h>
#include <cstdint>
#include <memory>
5
6namespace at {
7namespace functorch {
8
// NOTE [functorch TLS in pytorch/pytorch]
//
// functorch lives out-of-tree. However, it has some TLS that needs to be
// propagated. The solution is to store a pointer to the TLS inside
// pytorch/pytorch and to subclass FuncTorchTLSBase inside functorch with
// whatever state functorch needs.
//
// We need to store a pointer because of this indirection:
// inside functorch, we create a subclass of FuncTorchTLSBase called
// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
// FuncTorchTLSBase itself has no metadata, because the metadata types (such
// as DynamicLayerStack) are not defined in pytorch/pytorch.
//
// Here in pytorch/pytorch, we pass around FuncTorchTLSBase*, but inside
// functorch, we assign a FuncTorchTLSImpl* to the FuncTorchTLSBase*.
// We can't pass around a FuncTorchTLSBase by value (without a pointer),
// because FuncTorchTLSImpl has more members than FuncTorchTLSBase and does
// not fit inside one: copying it into a FuncTorchTLSBase would slice off the
// functorch-specific state.
27struct TORCH_API FuncTorchTLSBase {
28 virtual ~FuncTorchTLSBase() = default;
29 virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
30
31 virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
32 virtual void checkSupportsInplaceRequiresGrad() const = 0;
33 virtual void checkSupportsRetainGrad() const = 0;
34};
35
36// returns deepcopy of the functorch tls
37TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
38
39// sets the functorch tls. always does a deep copy.
40TORCH_API void setFuncTorchTLS(
41 const std::shared_ptr<const FuncTorchTLSBase>& state);
42
43// get a mutable reference to the functorch tls
44TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
45
46} // namespace functorch
47} // namespace at
48