1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <c10/util/Optional.h> |
5 | #include <c10/util/python_stub.h> |
6 | #include <stack> |
7 | #include <string> |
8 | |
9 | #include <utility> |
10 | |
11 | namespace at { |
12 | |
13 | namespace impl { |
14 | |
15 | struct TORCH_API SavedTensorDefaultHooksTLS { |
16 | // PyObject is defined in c10/util/python_stub.h |
17 | std::stack<std::pair<PyObject*, PyObject*>> stack; |
18 | |
19 | // See NOTE: [Disabling SavedTensorDefaultHooks] for context |
20 | // NOTE: [disabled_error_message invariant] |
21 | // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled |
22 | // We did this for efficiency (so we didn't have to keep a separate bool |
23 | // around) |
24 | c10::optional<std::string> disabled_error_message; |
25 | }; |
26 | |
27 | } // namespace impl |
28 | |
29 | struct TORCH_API SavedTensorDefaultHooks { |
30 | static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook); |
31 | static void pop_hooks(); |
32 | static std::pair<PyObject*, PyObject*> get_hooks(); |
33 | static void lazy_initialize(); |
34 | static std::stack<std::pair<PyObject*, PyObject*>> get_stack(); |
35 | static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>); |
36 | |
37 | static const impl::SavedTensorDefaultHooksTLS& get_tls_state(); |
38 | static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls); |
39 | |
40 | // NOTE: [Disabling SavedTensorDefaultHooks] |
41 | // A developer of a PyTorch feature may choose to disable SavedTensorDefault |
42 | // hooks, especially if their feature does not work with it. If they are |
43 | // disabled, then the following will raise an error: |
44 | // - Attempting to push_hooks |
45 | // - calling disable(message) with a non-zero stack (from get_stack) size |
46 | static void disable(const std::string& error_message); |
47 | static void enable(); |
48 | static bool is_enabled(); |
49 | static const c10::optional<std::string>& get_disabled_error_message(); |
50 | }; |
51 | |
52 | } // namespace at |
53 | |