1 | #include <ATen/SavedTensorHooks.h> |
2 | #include <torch/csrc/autograd/python_saved_variable_hooks.h> |
3 | |
4 | #include <torch/csrc/THP.h> |
5 | |
6 | namespace py = pybind11; |
7 | |
8 | namespace torch { |
9 | namespace autograd { |
10 | PySavedVariableHooks::PySavedVariableHooks( |
11 | py::function& pack_hook, |
12 | py::function& unpack_hook) |
13 | : // steals the reference (we will decref ourselves) |
14 | pack_hook_(pack_hook.release().ptr()), |
15 | unpack_hook_(unpack_hook.release().ptr()) {} |
16 | |
17 | // We don't use pybind for call_pack_hook and call_unpack_hook to avoid |
18 | // https://github.com/pytorch/pytorch/issues/34172 |
19 | void PySavedVariableHooks::call_pack_hook(const at::Tensor& tensor) { |
20 | py::gil_scoped_acquire acquire; |
21 | THPObjectPtr obj(THPVariable_Wrap(tensor)); |
22 | THPObjectPtr packed( |
23 | PyObject_CallFunctionObjArgs(pack_hook_, obj.get(), nullptr)); |
24 | if (!packed) { |
25 | throw python_error(); |
26 | } |
27 | data_ = packed.release(); |
28 | // obj is decrefed on exit, packed has their references stolen |
29 | // pack_hook_ and data_ will be manually decrefed when the saved variable is |
30 | // released |
31 | } |
32 | |
33 | at::Tensor PySavedVariableHooks::call_unpack_hook() { |
34 | py::gil_scoped_acquire acquire; |
35 | THPObjectPtr res(PyObject_CallFunctionObjArgs(unpack_hook_, data_, nullptr)); |
36 | if (!res) { |
37 | throw python_error(); |
38 | } |
39 | TORCH_CHECK_TYPE( |
40 | THPVariable_Check(res), |
41 | "Output of saved tensor unpack_hook expected to be a Tensor but got result of type " , |
42 | THPUtils_typename(res)); |
43 | return THPVariable_Unpack(res); |
44 | // res is decrefed on exit |
45 | // unpack_hook_ will be manually decrefed when the saved variable is released |
46 | } |
47 | |
48 | PySavedVariableHooks::~PySavedVariableHooks() { |
49 | // If python is already dead, leak the wrapped python objects |
50 | if (Py_IsInitialized()) { |
51 | py::gil_scoped_acquire gil; |
52 | Py_XDECREF(pack_hook_); |
53 | Py_XDECREF(unpack_hook_); |
54 | Py_XDECREF(data_); |
55 | } |
56 | } |
57 | |
58 | void PyDefaultSavedVariableHooks::push_hooks( |
59 | py::function& pack_hook, |
60 | py::function& unpack_hook) { |
61 | at::SavedTensorDefaultHooks::lazy_initialize(); |
62 | at::SavedTensorDefaultHooks::push_hooks( |
63 | pack_hook.release().ptr(), unpack_hook.release().ptr()); |
64 | } |
65 | |
66 | void PyDefaultSavedVariableHooks::pop_hooks() { |
67 | PyObject *pack_hook(nullptr), *unpack_hook(nullptr); |
68 | std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); |
69 | TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr); |
70 | if (Py_IsInitialized()) { |
71 | py::gil_scoped_acquire gil; |
72 | Py_XDECREF(pack_hook); |
73 | Py_XDECREF(unpack_hook); |
74 | } |
75 | at::SavedTensorDefaultHooks::pop_hooks(); |
76 | } |
77 | |
78 | std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() { |
79 | PyObject *pack_hook(nullptr), *unpack_hook(nullptr); |
80 | std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); |
81 | if (!pack_hook || !unpack_hook) { |
82 | return nullptr; |
83 | } |
84 | py::gil_scoped_acquire gil; |
85 | py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook); |
86 | py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook); |
87 | return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_); |
88 | } |
89 | |
90 | } // namespace autograd |
91 | } // namespace torch |
92 | |