1 | #include <ATen/SavedTensorHooks.h> |
---|---|
2 | #include <c10/util/Exception.h> |
3 | #include <stack> |
4 | #include <utility> |
5 | |
6 | namespace at { |
7 | |
8 | namespace { |
9 | thread_local impl::SavedTensorDefaultHooksTLS tls; |
10 | |
11 | // This flag is set to true the first time default hooks are registered |
12 | // and left at true for the rest of the execution. |
13 | // It's an optimization so that users who never use default hooks don't need to |
14 | // read the thread_local variables pack_hook_ and unpack_hook_. |
15 | static bool is_initialized(false); |
16 | } |
17 | |
18 | static void assertSavedTensorHooksNotDisabled() { |
19 | TORCH_CHECK(SavedTensorDefaultHooks::is_enabled(), tls.disabled_error_message.value()); |
20 | } |
21 | |
22 | bool SavedTensorDefaultHooks::is_enabled() { |
23 | // See NOTE: [disabled_error_message invariant] |
24 | return !tls.disabled_error_message.has_value(); |
25 | } |
26 | |
27 | void SavedTensorDefaultHooks::disable(const std::string& message) { |
28 | tls.disabled_error_message = message; |
29 | if (!tls.stack.empty()) { |
30 | assertSavedTensorHooksNotDisabled(); |
31 | } |
32 | } |
33 | |
34 | void SavedTensorDefaultHooks::enable() { |
35 | tls.disabled_error_message = c10::nullopt; |
36 | } |
37 | |
38 | const c10::optional<std::string>& SavedTensorDefaultHooks::get_disabled_error_message() { |
39 | return tls.disabled_error_message; |
40 | } |
41 | |
42 | const impl::SavedTensorDefaultHooksTLS& SavedTensorDefaultHooks::get_tls_state() { |
43 | return tls; |
44 | } |
45 | |
46 | void SavedTensorDefaultHooks::set_tls_state(const impl::SavedTensorDefaultHooksTLS& state) { |
47 | tls = state; |
48 | } |
49 | |
50 | void SavedTensorDefaultHooks::lazy_initialize() { |
51 | is_initialized = true; |
52 | } |
53 | |
54 | void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) { |
55 | // Reference counting is handled by the caller of `push_hooks` |
56 | TORCH_INTERNAL_ASSERT(is_initialized); |
57 | TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr); |
58 | assertSavedTensorHooksNotDisabled(); |
59 | tls.stack.emplace(pack_hook, unpack_hook); |
60 | } |
61 | |
62 | void SavedTensorDefaultHooks::pop_hooks() { |
63 | // Reference counting is handled by the caller of `pop_hooks` |
64 | TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty()); |
65 | tls.stack.pop(); |
66 | } |
67 | |
68 | std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() { |
69 | if (!is_initialized || tls.stack.empty()) { |
70 | return std::make_pair(nullptr, nullptr); |
71 | } |
72 | return tls.stack.top(); |
73 | } |
74 | |
75 | std::stack<std::pair<PyObject*, PyObject*>> SavedTensorDefaultHooks::get_stack() { |
76 | return tls.stack; |
77 | } |
78 | |
79 | void SavedTensorDefaultHooks::set_stack(std::stack<std::pair<PyObject*, PyObject*>> stack_) { |
80 | tls.stack = std::move(stack_); |
81 | } |
82 | |
83 | } |
84 |