1 | #pragma once |
---|---|
2 | |
#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>

#include <cstdint>
#include <memory>
#include <vector>
5 | |
6 | namespace at { |
7 | namespace impl { |
8 | |
/// Which parts of `__torch_function__` handling are currently disabled.
/// Kept as an unscoped enum so existing unqualified uses (e.g. `ENABLED`)
/// continue to compile. NOTE(review): the per-enumerator meanings below are
/// inferred from the names — confirm against the implementation.
enum TorchFunctionDisabledState {
  ENABLED = 0, // nothing disabled
  SUBCLASSES_DISABLED = 1, // presumably: only subclass handling disabled
  ALL_DISABLED = 2, // presumably: all torch function handling disabled
};
10 | |
11 | struct TORCH_API PythonTorchFunctionTLS { |
12 | static void set_disabled_state(TorchFunctionDisabledState disabled_state_); |
13 | static TorchFunctionDisabledState get_disabled_state(); |
14 | |
15 | static void push_onto_stack(std::shared_ptr<SafePyObject> mode); |
16 | static const std::shared_ptr<SafePyObject> pop_stack(); |
17 | static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx); |
18 | static int64_t stack_len(); |
19 | |
20 | static const PythonTorchFunctionTLS& get_state(); |
21 | static void set_state(const PythonTorchFunctionTLS& state); |
22 | |
23 | private: |
24 | // The mode TLS is split into |
25 | // - disabled_state, which says which part of torch function are disabled |
26 | // - stack_, which is a vector of modes representing the stack of user |
27 | // defined modes |
28 | TorchFunctionDisabledState disabled_state_ = |
29 | TorchFunctionDisabledState::ENABLED; |
30 | std::vector<std::shared_ptr<c10::SafePyObject>> stack_; |
31 | }; |
32 | |
33 | TORCH_API bool torch_function_mode_enabled(); |
34 | |
35 | } // namespace impl |
36 | } // namespace at |
37 |