1#include <ATen/PythonTorchFunctionTLS.h>
2#include <c10/core/TensorImpl.h>
3
4namespace at {
5namespace impl {
6
7static thread_local PythonTorchFunctionTLS pythonTorchFunctionState;
8
9void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
10 pythonTorchFunctionState.stack_.push_back(std::move(mode));
11}
12
13const std::shared_ptr<SafePyObject> PythonTorchFunctionTLS::pop_stack() {
14 TORCH_CHECK(!pythonTorchFunctionState.stack_.empty(), "trying to pop from empty mode stack");
15 auto out = pythonTorchFunctionState.stack_.back();
16 pythonTorchFunctionState.stack_.pop_back();
17 return out;
18}
19
20const std::shared_ptr<SafePyObject>& PythonTorchFunctionTLS::get_stack_at(int64_t idx) {
21 TORCH_CHECK(idx < static_cast<int64_t>(pythonTorchFunctionState.stack_.size()), "Tried to get stack at idx that's too big");
22 return pythonTorchFunctionState.stack_[idx];
23}
24
25int64_t PythonTorchFunctionTLS::stack_len() {
26 return pythonTorchFunctionState.stack_.size();
27}
28
29void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) {
30 pythonTorchFunctionState.disabled_state_ = disabled_state;
31}
32
33TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() {
34 return pythonTorchFunctionState.disabled_state_;
35}
36
37void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) {
38 pythonTorchFunctionState = state;
39}
40
41const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
42 return pythonTorchFunctionState;
43}
44
45bool torch_function_mode_enabled() {
46 return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
47 PythonTorchFunctionTLS::stack_len() > 0;
48}
49
50} // namespace impl
51} // namespace at
52