1#include <ATen/SavedTensorHooks.h>
2#include <torch/csrc/autograd/python_saved_variable_hooks.h>
3
4#include <torch/csrc/THP.h>
5
6namespace py = pybind11;
7
8namespace torch {
9namespace autograd {
10PySavedVariableHooks::PySavedVariableHooks(
11 py::function& pack_hook,
12 py::function& unpack_hook)
13 : // steals the reference (we will decref ourselves)
14 pack_hook_(pack_hook.release().ptr()),
15 unpack_hook_(unpack_hook.release().ptr()) {}
16
17// We don't use pybind for call_pack_hook and call_unpack_hook to avoid
18// https://github.com/pytorch/pytorch/issues/34172
19void PySavedVariableHooks::call_pack_hook(const at::Tensor& tensor) {
20 py::gil_scoped_acquire acquire;
21 THPObjectPtr obj(THPVariable_Wrap(tensor));
22 THPObjectPtr packed(
23 PyObject_CallFunctionObjArgs(pack_hook_, obj.get(), nullptr));
24 if (!packed) {
25 throw python_error();
26 }
27 data_ = packed.release();
28 // obj is decrefed on exit, packed has their references stolen
29 // pack_hook_ and data_ will be manually decrefed when the saved variable is
30 // released
31}
32
33at::Tensor PySavedVariableHooks::call_unpack_hook() {
34 py::gil_scoped_acquire acquire;
35 THPObjectPtr res(PyObject_CallFunctionObjArgs(unpack_hook_, data_, nullptr));
36 if (!res) {
37 throw python_error();
38 }
39 TORCH_CHECK_TYPE(
40 THPVariable_Check(res),
41 "Output of saved tensor unpack_hook expected to be a Tensor but got result of type ",
42 THPUtils_typename(res));
43 return THPVariable_Unpack(res);
44 // res is decrefed on exit
45 // unpack_hook_ will be manually decrefed when the saved variable is released
46}
47
48PySavedVariableHooks::~PySavedVariableHooks() {
49 // If python is already dead, leak the wrapped python objects
50 if (Py_IsInitialized()) {
51 py::gil_scoped_acquire gil;
52 Py_XDECREF(pack_hook_);
53 Py_XDECREF(unpack_hook_);
54 Py_XDECREF(data_);
55 }
56}
57
58void PyDefaultSavedVariableHooks::push_hooks(
59 py::function& pack_hook,
60 py::function& unpack_hook) {
61 at::SavedTensorDefaultHooks::lazy_initialize();
62 at::SavedTensorDefaultHooks::push_hooks(
63 pack_hook.release().ptr(), unpack_hook.release().ptr());
64}
65
66void PyDefaultSavedVariableHooks::pop_hooks() {
67 PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
68 std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
69 TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
70 if (Py_IsInitialized()) {
71 py::gil_scoped_acquire gil;
72 Py_XDECREF(pack_hook);
73 Py_XDECREF(unpack_hook);
74 }
75 at::SavedTensorDefaultHooks::pop_hooks();
76}
77
78std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
79 PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
80 std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
81 if (!pack_hook || !unpack_hook) {
82 return nullptr;
83 }
84 py::gil_scoped_acquire gil;
85 py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook);
86 py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook);
87 return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
88}
89
90} // namespace autograd
91} // namespace torch
92