1#pragma once
2
3#include <c10/macros/Export.h>
4#include <c10/util/Optional.h>
5#include <c10/util/python_stub.h>
6#include <stack>
7#include <string>
8
9#include <utility>
10
11namespace at {
12
13namespace impl {
14
15struct TORCH_API SavedTensorDefaultHooksTLS {
16 // PyObject is defined in c10/util/python_stub.h
17 std::stack<std::pair<PyObject*, PyObject*>> stack;
18
19 // See NOTE: [Disabling SavedTensorDefaultHooks] for context
20 // NOTE: [disabled_error_message invariant]
21 // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
22 // We did this for efficiency (so we didn't have to keep a separate bool
23 // around)
24 c10::optional<std::string> disabled_error_message;
25};
26
27} // namespace impl
28
29struct TORCH_API SavedTensorDefaultHooks {
30 static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
31 static void pop_hooks();
32 static std::pair<PyObject*, PyObject*> get_hooks();
33 static void lazy_initialize();
34 static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
35 static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);
36
37 static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
38 static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
39
40 // NOTE: [Disabling SavedTensorDefaultHooks]
41 // A developer of a PyTorch feature may choose to disable SavedTensorDefault
42 // hooks, especially if their feature does not work with it. If they are
43 // disabled, then the following will raise an error:
44 // - Attempting to push_hooks
45 // - calling disable(message) with a non-zero stack (from get_stack) size
46 static void disable(const std::string& error_message);
47 static void enable();
48 static bool is_enabled();
49 static const c10::optional<std::string>& get_disabled_error_message();
50};
51
52} // namespace at
53