1#pragma once
2
#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>

#include <cstdint>
#include <memory>
#include <vector>
5
6namespace at {
7namespace impl {
8
// Values for the "disabled state" part of the torch function mode TLS, i.e.
// which part of torch function is currently disabled.
// NOTE(review): the per-enumerator meanings below are inferred from the names
// only -- confirm against the code that consumes them.
enum TorchFunctionDisabledState {
  ENABLED = 0, // presumably: nothing is disabled
  SUBCLASSES_DISABLED = 1, // presumably: disabled for subclasses only
  ALL_DISABLED = 2, // presumably: disabled entirely
};
10
11struct TORCH_API PythonTorchFunctionTLS {
12 static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
13 static TorchFunctionDisabledState get_disabled_state();
14
15 static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
16 static const std::shared_ptr<SafePyObject> pop_stack();
17 static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
18 static int64_t stack_len();
19
20 static const PythonTorchFunctionTLS& get_state();
21 static void set_state(const PythonTorchFunctionTLS& state);
22
23 private:
24 // The mode TLS is split into
25 // - disabled_state, which says which part of torch function are disabled
26 // - stack_, which is a vector of modes representing the stack of user
27 // defined modes
28 TorchFunctionDisabledState disabled_state_ =
29 TorchFunctionDisabledState::ENABLED;
30 std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
31};
32
33TORCH_API bool torch_function_mode_enabled();
34
35} // namespace impl
36} // namespace at
37