1 | #include <torch/csrc/autograd/python_hook.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <pybind11/pybind11.h> |
5 | #include <torch/csrc/Exceptions.h> |
6 | #include <torch/csrc/THP.h> |
7 | #include <torch/csrc/autograd/python_variable.h> |
8 | #include <torch/csrc/utils/object_ptr.h> |
9 | #include <torch/csrc/utils/pybind.h> |
10 | #include <torch/csrc/utils/python_strings.h> |
11 | |
12 | #include <sstream> |
13 | |
14 | using torch::autograd::Variable; |
15 | using torch::autograd::variable_list; |
16 | |
17 | static PyObject* wrap_variables(const variable_list& c_variables); |
18 | static variable_list unwrap_variables(PyObject* py_variables); |
19 | static std::string hook_name(PyObject* hook); |
20 | static void check_result(PyObject* original, PyObject* result, PyObject* hook); |
21 | static void check_single_result( |
22 | PyObject* original, |
23 | PyObject* result, |
24 | PyObject* hook); |
25 | |
26 | namespace torch { |
27 | namespace autograd { |
28 | |
29 | namespace { |
30 | |
31 | // This function is called in 3 different cases: |
32 | // 1) TensorPreHook |
33 | // 2) PreHook |
34 | // 3) PostHook |
35 | // |
36 | // Depending on the case, args and res can hold different types of objects: |
37 | // |
38 | // args: |
39 | // TensorPreHook (Tensor,) |
40 | // PreHook ((Tensor, ...),) (grad_outputs,) |
41 | // PostHook ((Tensor, ...), (Tensor, ...)) (grad_inputs, grad_outputs) |
42 | // |
43 | // res: |
44 | // TensorPreHook Tensor |
45 | // PreHook ((Tensor, ...),) (grad_outputs,) |
46 | // PostHook ((Tensor, ...),) (grad_inputs,) |
47 | // |
48 | // This function returns True if any hook returned non-None value, and False |
49 | // otherwise. |
50 | bool _call_hooks(PyObject* dict, PyObject* args) { |
51 | // Note: [Extend Hook Lifetime] |
52 | // Hold a reference to hooks till we iterate over them. |
53 | // This is to handle the case when hook calls `handle.remove` inside it |
54 | // and it's refcount goes to `0`, Python is free to GC it. |
55 | // We hold onto a stale pointer and subsequent call to |
56 | // `check_single_result`, which tries to fetch the `hook`'s name segfaults. |
57 | // So, we use `PyDict_Values` which returns a new reference to the values |
58 | // i.e. we hold the reference to the hooks till we have iterated over them. |
59 | // Reference: https://github.com/pytorch/pytorch/issues/58354 |
60 | auto hooks = THPObjectPtr{PyDict_Values(dict)}; |
61 | bool is_modified = false; |
62 | const auto len = PyList_Size(hooks); |
63 | for (Py_ssize_t idx = 0; idx < len; ++idx) { |
64 | const auto hook = PyList_GetItem(hooks, idx); |
65 | |
66 | THPObjectPtr res(PyObject_CallObject(hook, args)); |
67 | if (!res) |
68 | throw python_error(); |
69 | if (res == Py_None) |
70 | continue; |
71 | |
72 | PyObject* args0 = PyTuple_GetItem(args, 0); |
73 | if (res == args0) |
74 | continue; |
75 | |
76 | if (PyTuple_CheckExact(args0)) { |
77 | check_result(args0, res, hook); |
78 | } else { |
79 | check_single_result(args0, res, hook); |
80 | } |
81 | PyTuple_SetItem(args, 0, res.release()); |
82 | |
83 | is_modified = true; |
84 | } |
85 | return is_modified; |
86 | } |
87 | |
88 | } // namespace |
89 | |
90 | PyFunctionTensorPreHook::PyFunctionTensorPreHook(PyObject* dict, int value_idx) |
91 | : dict(dict), value_idx(value_idx) { |
92 | Py_INCREF(dict); |
93 | } |
94 | |
95 | PyFunctionTensorPreHook::~PyFunctionTensorPreHook() { |
96 | // If python is already dead, leak the wrapped python objects |
97 | if (Py_IsInitialized()) { |
98 | pybind11::gil_scoped_acquire gil; |
99 | Py_DECREF(dict); |
100 | } |
101 | } |
102 | |
103 | auto PyFunctionTensorPreHook::operator()(const variable_list& values) |
104 | -> variable_list { |
105 | pybind11::gil_scoped_acquire gil; |
106 | THPObjectPtr value(THPVariable_Wrap(values.at(value_idx))); |
107 | if (!value) |
108 | throw python_error(); |
109 | THPObjectPtr tup(PyTuple_New(1)); |
110 | PyTuple_SET_ITEM(tup.get(), 0, value.release()); |
111 | bool is_tup_modified = _call_hooks(dict, tup.get()); |
112 | variable_list results(values); |
113 | if (is_tup_modified) { |
114 | results[value_idx] = THPVariable_Unpack(PyTuple_GetItem(tup.get(), 0)); |
115 | } |
116 | return results; |
117 | } |
118 | |
119 | PyFunctionPreHook::PyFunctionPreHook(PyObject* dict) : dict(dict) { |
120 | Py_INCREF(dict); |
121 | } |
122 | |
123 | PyFunctionPreHook::~PyFunctionPreHook() { |
124 | // If python is already dead, leak the wrapped python objects |
125 | if (Py_IsInitialized()) { |
126 | pybind11::gil_scoped_acquire gil; |
127 | Py_DECREF(dict); |
128 | } |
129 | } |
130 | |
131 | auto PyFunctionPreHook::operator()(const variable_list& grad_outputs_) |
132 | -> variable_list { |
133 | pybind11::gil_scoped_acquire gil; |
134 | THPObjectPtr grad_outputs(wrap_variables(grad_outputs_)); |
135 | THPObjectPtr tup(PyTuple_New(1)); |
136 | PyTuple_SET_ITEM(tup.get(), 0, grad_outputs.release()); |
137 | _call_hooks(dict, tup.get()); |
138 | return unwrap_variables(PyTuple_GetItem(tup.get(), 0)); |
139 | } |
140 | |
141 | PyFunctionPostHook::PyFunctionPostHook(PyObject* dict) : dict(dict) { |
142 | Py_INCREF(dict); |
143 | } |
144 | |
145 | PyFunctionPostHook::~PyFunctionPostHook() { |
146 | // If python is already dead, leak the wrapped python objects |
147 | if (Py_IsInitialized()) { |
148 | pybind11::gil_scoped_acquire gil; |
149 | Py_DECREF(dict); |
150 | } |
151 | } |
152 | |
153 | auto PyFunctionPostHook::operator()( |
154 | const variable_list& _outputs, /* grad_inputs */ |
155 | const variable_list& _inputs /* grad_outputs */) -> variable_list { |
156 | pybind11::gil_scoped_acquire gil; |
157 | THPObjectPtr grad_inputs(wrap_variables(_outputs)); |
158 | THPObjectPtr grad_outputs(wrap_variables(_inputs)); |
159 | THPObjectPtr tup(PyTuple_New(2)); |
160 | PyTuple_SET_ITEM(tup.get(), 0, grad_inputs.release()); |
161 | PyTuple_SET_ITEM(tup.get(), 1, grad_outputs.release()); |
162 | _call_hooks(dict, tup.get()); |
163 | return unwrap_variables(PyTuple_GetItem(tup.get(), 0)); |
164 | } |
165 | |
166 | } // namespace autograd |
167 | } // namespace torch |
168 | |
169 | static PyObject* wrap_variables(const variable_list& c_variables) { |
170 | size_t num_vars = c_variables.size(); |
171 | THPObjectPtr tuple(PyTuple_New(num_vars)); |
172 | if (!tuple) |
173 | throw python_error(); |
174 | for (const auto i : c10::irange(num_vars)) { |
175 | THPObjectPtr var(THPVariable_Wrap(c_variables[i])); |
176 | if (!var) |
177 | throw python_error(); |
178 | PyTuple_SET_ITEM(tuple.get(), i, var.release()); |
179 | } |
180 | return tuple.release(); |
181 | } |
182 | |
183 | static variable_list unwrap_variables(PyObject* py_variables) { |
184 | variable_list results(PyTuple_GET_SIZE(py_variables)); |
185 | for (const auto i : c10::irange(results.size())) { |
186 | PyObject* item = PyTuple_GET_ITEM(py_variables, i); |
187 | if (item == Py_None) { |
188 | continue; |
189 | } else if (THPVariable_Check(item)) { |
190 | results[i] = THPVariable_Unpack(item); |
191 | } else { |
192 | // this should never happen, but just in case... |
193 | std::stringstream ss; |
194 | ss << "expected variable but got " << Py_TYPE(item)->tp_name; |
195 | throw std::runtime_error(ss.str()); |
196 | } |
197 | } |
198 | return results; |
199 | } |
200 | |
201 | static void check_result(PyObject* prev, PyObject* result, PyObject* hook) { |
202 | if (!PyTuple_Check(result)) { |
203 | PyErr_Format( |
204 | PyExc_TypeError, |
205 | "expected tuple, but hook returned '%s'" , |
206 | THPUtils_typename(result)); |
207 | throw python_error(); |
208 | } |
209 | |
210 | auto prev_size = PyTuple_GET_SIZE(prev); |
211 | auto result_size = PyTuple_GET_SIZE(result); |
212 | if (prev_size != result_size) { |
213 | std::stringstream ss; |
214 | auto name = hook_name(hook); |
215 | ss << "hook '" << name << "' has returned an incorrect number " ; |
216 | ss << "of values (got " << result_size << ", but expected " ; |
217 | ss << prev_size << ")" ; |
218 | throw std::runtime_error(ss.str()); |
219 | } |
220 | |
221 | for (const auto i : c10::irange(prev_size)) { |
222 | check_single_result( |
223 | PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook); |
224 | } |
225 | } |
226 | |
227 | static void check_single_result( |
228 | PyObject* _original, |
229 | PyObject* _result, |
230 | PyObject* hook) { |
231 | if (_result == Py_None) |
232 | return; |
233 | |
234 | if (_original == Py_None) { |
235 | throw std::runtime_error( |
236 | "can't replace a None gradient with a non-None value" ); |
237 | } |
238 | |
239 | if (!PyObject_IsInstance(_result, THPVariableClass)) { |
240 | PyErr_Format( |
241 | PyExc_TypeError, |
242 | "expected Variable, but hook returned '%s'" , |
243 | THPUtils_typename(_result)); |
244 | throw python_error(); |
245 | } |
246 | |
247 | const auto& original = THPVariable_Unpack(_original); |
248 | const auto& result = THPVariable_Unpack(_result); |
249 | |
250 | torch::autograd::check_variable_result(original, result, hook_name(hook)); |
251 | } |
252 | |
253 | static std::string hook_name(PyObject* hook) { |
254 | if (PyObject_HasAttrString(hook, "__name__" )) { |
255 | THPObjectPtr name(PyObject_GetAttrString(hook, "__name__" )); |
256 | if (!name) |
257 | throw python_error(); |
258 | |
259 | if (name && THPUtils_checkString(name.get())) { |
260 | return THPUtils_unpackString(name.get()); |
261 | } |
262 | } |
263 | return "<unknown>" ; |
264 | } |
265 | |