#include <ATen/SavedTensorHooks.h>
#include <c10/util/Exception.h>
#include <atomic>
#include <stack>
#include <utility>
5
6namespace at {
7
8namespace {
9 thread_local impl::SavedTensorDefaultHooksTLS tls;
10
11 // This flag is set to true the first time default hooks are registered
12 // and left at true for the rest of the execution.
13 // It's an optimization so that users who never use default hooks don't need to
14 // read the thread_local variables pack_hook_ and unpack_hook_.
15 static bool is_initialized(false);
16}
17
18static void assertSavedTensorHooksNotDisabled() {
19 TORCH_CHECK(SavedTensorDefaultHooks::is_enabled(), tls.disabled_error_message.value());
20}
21
22bool SavedTensorDefaultHooks::is_enabled() {
23 // See NOTE: [disabled_error_message invariant]
24 return !tls.disabled_error_message.has_value();
25}
26
27void SavedTensorDefaultHooks::disable(const std::string& message) {
28 tls.disabled_error_message = message;
29 if (!tls.stack.empty()) {
30 assertSavedTensorHooksNotDisabled();
31 }
32}
33
34void SavedTensorDefaultHooks::enable() {
35 tls.disabled_error_message = c10::nullopt;
36}
37
38const c10::optional<std::string>& SavedTensorDefaultHooks::get_disabled_error_message() {
39 return tls.disabled_error_message;
40}
41
42const impl::SavedTensorDefaultHooksTLS& SavedTensorDefaultHooks::get_tls_state() {
43 return tls;
44}
45
46void SavedTensorDefaultHooks::set_tls_state(const impl::SavedTensorDefaultHooksTLS& state) {
47 tls = state;
48}
49
50void SavedTensorDefaultHooks::lazy_initialize() {
51 is_initialized = true;
52}
53
54void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
55 // Reference counting is handled by the caller of `push_hooks`
56 TORCH_INTERNAL_ASSERT(is_initialized);
57 TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
58 assertSavedTensorHooksNotDisabled();
59 tls.stack.emplace(pack_hook, unpack_hook);
60}
61
62void SavedTensorDefaultHooks::pop_hooks() {
63 // Reference counting is handled by the caller of `pop_hooks`
64 TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty());
65 tls.stack.pop();
66}
67
68std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
69 if (!is_initialized || tls.stack.empty()) {
70 return std::make_pair(nullptr, nullptr);
71 }
72 return tls.stack.top();
73}
74
75std::stack<std::pair<PyObject*, PyObject*>> SavedTensorDefaultHooks::get_stack() {
76 return tls.stack;
77}
78
79void SavedTensorDefaultHooks::set_stack(std::stack<std::pair<PyObject*, PyObject*>> stack_) {
80 tls.stack = std::move(stack_);
81}
82
83}
84