1#include <torch/csrc/autograd/python_hook.h>
2
3#include <c10/util/irange.h>
4#include <pybind11/pybind11.h>
5#include <torch/csrc/Exceptions.h>
6#include <torch/csrc/THP.h>
7#include <torch/csrc/autograd/python_variable.h>
8#include <torch/csrc/utils/object_ptr.h>
9#include <torch/csrc/utils/pybind.h>
10#include <torch/csrc/utils/python_strings.h>
11
12#include <sstream>
13
14using torch::autograd::Variable;
15using torch::autograd::variable_list;
16
17static PyObject* wrap_variables(const variable_list& c_variables);
18static variable_list unwrap_variables(PyObject* py_variables);
19static std::string hook_name(PyObject* hook);
20static void check_result(PyObject* original, PyObject* result, PyObject* hook);
21static void check_single_result(
22 PyObject* original,
23 PyObject* result,
24 PyObject* hook);
25
26namespace torch {
27namespace autograd {
28
29namespace {
30
31// This function is called in 3 different cases:
32// 1) TensorPreHook
33// 2) PreHook
34// 3) PostHook
35//
36// Depending on the case, args and res can hold different types of objects:
37//
38// args:
39// TensorPreHook (Tensor,)
40// PreHook ((Tensor, ...),) (grad_outputs,)
41// PostHook ((Tensor, ...), (Tensor, ...)) (grad_inputs, grad_outputs)
42//
43// res:
44// TensorPreHook Tensor
45// PreHook ((Tensor, ...),) (grad_outputs,)
46// PostHook ((Tensor, ...),) (grad_inputs,)
47//
48// This function returns True if any hook returned non-None value, and False
49// otherwise.
50bool _call_hooks(PyObject* dict, PyObject* args) {
51 // Note: [Extend Hook Lifetime]
52 // Hold a reference to hooks till we iterate over them.
53 // This is to handle the case when hook calls `handle.remove` inside it
54 // and it's refcount goes to `0`, Python is free to GC it.
55 // We hold onto a stale pointer and subsequent call to
56 // `check_single_result`, which tries to fetch the `hook`'s name segfaults.
57 // So, we use `PyDict_Values` which returns a new reference to the values
58 // i.e. we hold the reference to the hooks till we have iterated over them.
59 // Reference: https://github.com/pytorch/pytorch/issues/58354
60 auto hooks = THPObjectPtr{PyDict_Values(dict)};
61 bool is_modified = false;
62 const auto len = PyList_Size(hooks);
63 for (Py_ssize_t idx = 0; idx < len; ++idx) {
64 const auto hook = PyList_GetItem(hooks, idx);
65
66 THPObjectPtr res(PyObject_CallObject(hook, args));
67 if (!res)
68 throw python_error();
69 if (res == Py_None)
70 continue;
71
72 PyObject* args0 = PyTuple_GetItem(args, 0);
73 if (res == args0)
74 continue;
75
76 if (PyTuple_CheckExact(args0)) {
77 check_result(args0, res, hook);
78 } else {
79 check_single_result(args0, res, hook);
80 }
81 PyTuple_SetItem(args, 0, res.release());
82
83 is_modified = true;
84 }
85 return is_modified;
86}
87
88} // namespace
89
90PyFunctionTensorPreHook::PyFunctionTensorPreHook(PyObject* dict, int value_idx)
91 : dict(dict), value_idx(value_idx) {
92 Py_INCREF(dict);
93}
94
95PyFunctionTensorPreHook::~PyFunctionTensorPreHook() {
96 // If python is already dead, leak the wrapped python objects
97 if (Py_IsInitialized()) {
98 pybind11::gil_scoped_acquire gil;
99 Py_DECREF(dict);
100 }
101}
102
103auto PyFunctionTensorPreHook::operator()(const variable_list& values)
104 -> variable_list {
105 pybind11::gil_scoped_acquire gil;
106 THPObjectPtr value(THPVariable_Wrap(values.at(value_idx)));
107 if (!value)
108 throw python_error();
109 THPObjectPtr tup(PyTuple_New(1));
110 PyTuple_SET_ITEM(tup.get(), 0, value.release());
111 bool is_tup_modified = _call_hooks(dict, tup.get());
112 variable_list results(values);
113 if (is_tup_modified) {
114 results[value_idx] = THPVariable_Unpack(PyTuple_GetItem(tup.get(), 0));
115 }
116 return results;
117}
118
119PyFunctionPreHook::PyFunctionPreHook(PyObject* dict) : dict(dict) {
120 Py_INCREF(dict);
121}
122
123PyFunctionPreHook::~PyFunctionPreHook() {
124 // If python is already dead, leak the wrapped python objects
125 if (Py_IsInitialized()) {
126 pybind11::gil_scoped_acquire gil;
127 Py_DECREF(dict);
128 }
129}
130
131auto PyFunctionPreHook::operator()(const variable_list& grad_outputs_)
132 -> variable_list {
133 pybind11::gil_scoped_acquire gil;
134 THPObjectPtr grad_outputs(wrap_variables(grad_outputs_));
135 THPObjectPtr tup(PyTuple_New(1));
136 PyTuple_SET_ITEM(tup.get(), 0, grad_outputs.release());
137 _call_hooks(dict, tup.get());
138 return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
139}
140
141PyFunctionPostHook::PyFunctionPostHook(PyObject* dict) : dict(dict) {
142 Py_INCREF(dict);
143}
144
145PyFunctionPostHook::~PyFunctionPostHook() {
146 // If python is already dead, leak the wrapped python objects
147 if (Py_IsInitialized()) {
148 pybind11::gil_scoped_acquire gil;
149 Py_DECREF(dict);
150 }
151}
152
153auto PyFunctionPostHook::operator()(
154 const variable_list& _outputs, /* grad_inputs */
155 const variable_list& _inputs /* grad_outputs */) -> variable_list {
156 pybind11::gil_scoped_acquire gil;
157 THPObjectPtr grad_inputs(wrap_variables(_outputs));
158 THPObjectPtr grad_outputs(wrap_variables(_inputs));
159 THPObjectPtr tup(PyTuple_New(2));
160 PyTuple_SET_ITEM(tup.get(), 0, grad_inputs.release());
161 PyTuple_SET_ITEM(tup.get(), 1, grad_outputs.release());
162 _call_hooks(dict, tup.get());
163 return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
164}
165
166} // namespace autograd
167} // namespace torch
168
169static PyObject* wrap_variables(const variable_list& c_variables) {
170 size_t num_vars = c_variables.size();
171 THPObjectPtr tuple(PyTuple_New(num_vars));
172 if (!tuple)
173 throw python_error();
174 for (const auto i : c10::irange(num_vars)) {
175 THPObjectPtr var(THPVariable_Wrap(c_variables[i]));
176 if (!var)
177 throw python_error();
178 PyTuple_SET_ITEM(tuple.get(), i, var.release());
179 }
180 return tuple.release();
181}
182
183static variable_list unwrap_variables(PyObject* py_variables) {
184 variable_list results(PyTuple_GET_SIZE(py_variables));
185 for (const auto i : c10::irange(results.size())) {
186 PyObject* item = PyTuple_GET_ITEM(py_variables, i);
187 if (item == Py_None) {
188 continue;
189 } else if (THPVariable_Check(item)) {
190 results[i] = THPVariable_Unpack(item);
191 } else {
192 // this should never happen, but just in case...
193 std::stringstream ss;
194 ss << "expected variable but got " << Py_TYPE(item)->tp_name;
195 throw std::runtime_error(ss.str());
196 }
197 }
198 return results;
199}
200
201static void check_result(PyObject* prev, PyObject* result, PyObject* hook) {
202 if (!PyTuple_Check(result)) {
203 PyErr_Format(
204 PyExc_TypeError,
205 "expected tuple, but hook returned '%s'",
206 THPUtils_typename(result));
207 throw python_error();
208 }
209
210 auto prev_size = PyTuple_GET_SIZE(prev);
211 auto result_size = PyTuple_GET_SIZE(result);
212 if (prev_size != result_size) {
213 std::stringstream ss;
214 auto name = hook_name(hook);
215 ss << "hook '" << name << "' has returned an incorrect number ";
216 ss << "of values (got " << result_size << ", but expected ";
217 ss << prev_size << ")";
218 throw std::runtime_error(ss.str());
219 }
220
221 for (const auto i : c10::irange(prev_size)) {
222 check_single_result(
223 PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook);
224 }
225}
226
227static void check_single_result(
228 PyObject* _original,
229 PyObject* _result,
230 PyObject* hook) {
231 if (_result == Py_None)
232 return;
233
234 if (_original == Py_None) {
235 throw std::runtime_error(
236 "can't replace a None gradient with a non-None value");
237 }
238
239 if (!PyObject_IsInstance(_result, THPVariableClass)) {
240 PyErr_Format(
241 PyExc_TypeError,
242 "expected Variable, but hook returned '%s'",
243 THPUtils_typename(_result));
244 throw python_error();
245 }
246
247 const auto& original = THPVariable_Unpack(_original);
248 const auto& result = THPVariable_Unpack(_result);
249
250 torch::autograd::check_variable_result(original, result, hook_name(hook));
251}
252
253static std::string hook_name(PyObject* hook) {
254 if (PyObject_HasAttrString(hook, "__name__")) {
255 THPObjectPtr name(PyObject_GetAttrString(hook, "__name__"));
256 if (!name)
257 throw python_error();
258
259 if (name && THPUtils_checkString(name.get())) {
260 return THPUtils_unpackString(name.get());
261 }
262 }
263 return "<unknown>";
264}
265