#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_cpp_function.h>

#include <torch/csrc/python_headers.h>
#include <cstdio>
#include <memory>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>

#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
20 | |
21 | using namespace torch::autograd; |
22 | |
23 | namespace torch { |
24 | namespace autograd { |
25 | |
26 | namespace { |
27 | |
28 | PyObject* THPCppFunction_call( |
29 | PyObject* self, |
30 | PyObject* args, |
31 | PyObject* kwargs) { |
32 | if (kwargs && PyDict_Size(kwargs) != 0) { |
33 | return PyErr_Format(PyExc_TypeError, "keyword arguments are not supported" ); |
34 | } |
35 | |
36 | int num_inputs = PyTuple_GET_SIZE(args); |
37 | int num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs(); |
38 | if (num_inputs != num_inputs_required) { |
39 | return PyErr_Format( |
40 | PyExc_TypeError, |
41 | "expected %d arguments, got %d instead" , |
42 | num_inputs_required, |
43 | num_inputs); |
44 | } |
45 | variable_list vars(num_inputs); |
46 | for (int i = 0; i != num_inputs; ++i) { |
47 | PyObject* arg = PyTuple_GET_ITEM(args, i); |
48 | if (arg == Py_None) { |
49 | continue; |
50 | } |
51 | if (!THPVariable_Check(arg)) { |
52 | return PyErr_Format(PyExc_TypeError, "argument %d is not a Variable" , i); |
53 | } |
54 | vars[i] = THPVariable_Unpack(arg); |
55 | } |
56 | |
57 | variable_list output; |
58 | |
59 | HANDLE_TH_ERRORS { |
60 | pybind11::gil_scoped_release nogil; |
61 | output = (*((THPCppFunction*)self)->cdata)(std::move(vars)); |
62 | } |
63 | END_HANDLE_TH_ERRORS |
64 | |
65 | int num_outputs = output.size(); |
66 | if (num_outputs == 1) { |
67 | // assume we want to unpack one element tuples for now |
68 | return THPVariable_Wrap(output[0]); |
69 | } |
70 | |
71 | THPObjectPtr tuple(PyTuple_New(num_outputs)); |
72 | for (int i = 0; i != num_outputs; ++i) { |
73 | PyTuple_SET_ITEM(tuple.get(), i, THPVariable_Wrap(output[i])); |
74 | } |
75 | return tuple.release(); |
76 | } |
77 | |
78 | int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) { |
79 | auto& fn = *((THPCppFunction*)self)->cdata; |
80 | for (const auto& hook : fn.tensor_pre_hooks()) { |
81 | if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) { |
82 | Py_VISIT(pyhook->dict); |
83 | } |
84 | } |
85 | // NOTE [retains_grad_hook PyObject traversal] |
86 | // In theory this shouldn't be necessary, because retains_grad_hooks should |
87 | // not contain any PyFunctionTensorPreHooks. The alternative is to have a |
88 | // check that actually guarantees this. |
89 | for (const auto& pair : fn.retains_grad_hooks()) { |
90 | if (auto pyhook = |
91 | dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) { |
92 | Py_VISIT(pyhook->dict); |
93 | } |
94 | } |
95 | for (const auto& hook : fn.pre_hooks()) { |
96 | if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) { |
97 | Py_VISIT(pyhook->dict); |
98 | } |
99 | } |
100 | for (const auto& hook : fn.post_hooks()) { |
101 | if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) { |
102 | Py_VISIT(pyhook->dict); |
103 | } |
104 | } |
105 | return 0; |
106 | } |
107 | |
108 | int THPCppFunction_clear(PyObject* self) { |
109 | auto f = (THPCppFunction*)self; |
110 | // Remove the weak ref of the c++ object if it exist |
111 | if (f->cdata) { |
112 | f->cdata->set_pyobj(nullptr); |
113 | } |
114 | f->cdata.reset(); |
115 | return 0; |
116 | } |
117 | |
118 | void THPCppFunction_dealloc(PyObject* self) { |
119 | PyObject_GC_UnTrack(self); |
120 | THPCppFunction_clear(self); |
121 | ((THPCppFunction*)self)->cdata.~shared_ptr(); |
122 | Py_TYPE(self)->tp_free(self); |
123 | } |
124 | |
125 | } // namespace |
126 | |
127 | PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook) { |
128 | const auto num_next = self->cdata->num_outputs(); |
129 | THPObjectPtr py_functions(PyTuple_New(num_next)); |
130 | if (!py_functions) |
131 | return nullptr; |
132 | for (const auto i : c10::irange(num_next)) { |
133 | auto& c_tuple = self->cdata->next_edge(i); |
134 | THPObjectPtr tuple(PyTuple_New(2)); |
135 | if (!tuple) |
136 | return nullptr; |
137 | PyObject* py_fn = functionToPyObject(c_tuple.function); |
138 | if (!py_fn) |
139 | return nullptr; |
140 | PyTuple_SET_ITEM(tuple.get(), 0, py_fn); |
141 | PyObject* py_idx = THPUtils_packUInt32(c_tuple.input_nr); |
142 | if (!py_idx) |
143 | return nullptr; |
144 | PyTuple_SET_ITEM(tuple.get(), 1, py_idx); |
145 | PyTuple_SET_ITEM(py_functions.get(), i, tuple.release()); |
146 | } |
147 | return py_functions.release(); |
148 | } |
149 | |
150 | PyObject* THPCppFunction_metadata(THPCppFunction* self, void* _unused) { |
151 | auto* metadata = |
152 | static_cast<PyAnomalyMetadata*>(self->cdata->metadata())->dict(); |
153 | |
154 | Py_XINCREF(metadata); |
155 | return metadata; |
156 | } |
157 | |
158 | PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void* unused) { |
159 | Py_RETURN_TRUE; |
160 | } |
161 | |
162 | PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) { |
163 | if (!THPVariable_Check(_var)) { |
164 | return PyErr_Format( |
165 | PyExc_TypeError, "_register_hook_dict expected a variable" ); |
166 | } |
167 | auto var = (THPVariable*)_var; |
168 | auto& fn = *((THPCppFunction*)self)->cdata; |
169 | std::unique_ptr<FunctionPreHook> hook(new PyFunctionTensorPreHook( |
170 | var->backward_hooks, THPVariable_Unpack(var).output_nr())); |
171 | fn.add_tensor_pre_hook(std::move(hook)); |
172 | Py_RETURN_NONE; |
173 | } |
174 | |
175 | PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) { |
176 | auto& fn = *((THPCppFunction*)self)->cdata; |
177 | return registerFunctionHook(fn, hook); |
178 | } |
179 | |
180 | PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook) { |
181 | auto& fn = *((THPCppFunction*)self)->cdata; |
182 | return registerFunctionPreHook(fn, hook); |
183 | } |
184 | |
185 | PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs) { |
186 | auto& fn = *((THPCppFunction*)self)->cdata; |
187 | return THPUtils_packString(fn.name()); |
188 | } |
189 | |
190 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
191 | static struct PyMethodDef default_methods[] = { |
192 | THP_FUNCTION_DEFAULT_METHODS, |
193 | {nullptr}}; |
194 | |
195 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
196 | static struct PyGetSetDef default_properties[] = { |
197 | THP_FUNCTION_DEFAULT_PROPERTIES, |
198 | {nullptr}}; |
199 | |
200 | PyTypeObject* _initFunctionPyTypeObject( |
201 | PyTypeObject& type, |
202 | const char* name, |
203 | PyGetSetDef* function_properties, |
204 | PyMethodDef* function_methods) { |
205 | type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC; |
206 | type.tp_name = name; |
207 | type.tp_basicsize = sizeof(THPCppFunction); |
208 | type.tp_call = THPCppFunction_call; |
209 | type.tp_methods = function_methods ? function_methods : default_methods; |
210 | type.tp_getset = |
211 | function_properties ? function_properties : default_properties; |
212 | type.tp_dealloc = THPCppFunction_dealloc; |
213 | type.tp_traverse = THPCppFunction_traverse; |
214 | type.tp_clear = THPCppFunction_clear; |
215 | if (PyType_Ready(&type) < 0) { |
216 | auto msg = std::string("Unable to instantiate PyTypeObject for " ) + name; |
217 | throw std::runtime_error(msg); |
218 | } |
219 | return &type; |
220 | } |
221 | |
222 | static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types_map; |
223 | static std::unordered_set<PyTypeObject*> cpp_function_types_set; |
224 | |
225 | struct DefaultFunctionType { |
226 | DefaultFunctionType() : type() { |
227 | _initFunctionPyTypeObject(type, "CppFunction" , nullptr, nullptr); |
228 | Py_INCREF(&type); |
229 | } |
230 | |
231 | PyTypeObject type; |
232 | }; |
233 | |
234 | PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) { |
235 | static DefaultFunctionType default_type; |
236 | |
237 | if (!cdata) { |
238 | Py_RETURN_NONE; |
239 | } |
240 | |
241 | if (auto pfw = dynamic_cast<PyNode*>(cdata.get())) { |
242 | PyObject* obj = pfw->obj; |
243 | Py_INCREF(obj); |
244 | return obj; |
245 | } |
246 | |
247 | if (cdata->pyobj()) { |
248 | Py_INCREF(cdata->pyobj()); |
249 | } else { |
250 | auto& fn = *cdata; |
251 | auto it = cpp_function_types_map.find(std::type_index(typeid(fn))); |
252 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
253 | PyTypeObject* type; |
254 | if (it == cpp_function_types_map.end()) { |
255 | type = &default_type.type; |
256 | } else { |
257 | type = (PyTypeObject*)it->second.get(); |
258 | } |
259 | |
260 | THPObjectPtr obj(type->tp_alloc(type, 0)); |
261 | if (!obj) |
262 | return nullptr; |
263 | THPCppFunction* f = (THPCppFunction*)obj.get(); |
264 | new (&f->cdata) std::shared_ptr<Node>(cdata); |
265 | |
266 | // No INCREF here as we only have a weak reference |
267 | cdata->set_pyobj(obj.release()); |
268 | } |
269 | |
270 | return cdata->pyobj(); |
271 | } |
272 | |
273 | void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) { |
274 | Py_INCREF((PyObject*)pytype); |
275 | cpp_function_types_map[std::type_index(type)] = |
276 | THPObjectPtr((PyObject*)pytype); |
277 | cpp_function_types_set.insert(pytype); |
278 | } |
279 | |
280 | bool THPCppFunction_Check(PyObject* obj) { |
281 | THPObjectPtr type = THPObjectPtr(PyObject_Type(obj)); |
282 | if (cpp_function_types_set.find((PyTypeObject*)type.get()) == |
283 | cpp_function_types_set.end()) { |
284 | return false; |
285 | } else { |
286 | return true; |
287 | } |
288 | } |
289 | |
290 | PyObject* callRegisterFn(PyObject* dict, PyObject* hook) { |
291 | THPObjectPtr register_fn( |
292 | PyObject_GetAttrString(THPFunctionClass, "_register_hook" )); |
293 | if (!register_fn) { |
294 | return nullptr; |
295 | } |
296 | THPObjectPtr res( |
297 | PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr)); |
298 | if (!res) { |
299 | return nullptr; |
300 | } |
301 | return res.release(); |
302 | } |
303 | |
304 | PyObject* registerFunctionHook(Node& fn, PyObject* hook) { |
305 | PyObject* dict = Py_None; |
306 | for (const auto& hook : fn.post_hooks()) { |
307 | if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) { |
308 | dict = pyhook->dict; |
309 | break; |
310 | } |
311 | } |
312 | THPObjectPtr res{callRegisterFn(dict, hook)}; |
313 | if (!res) { |
314 | return nullptr; |
315 | } |
316 | if (dict == Py_None) { |
317 | dict = PyTuple_GET_ITEM(res.get(), 0); |
318 | std::unique_ptr<FunctionPostHook> hook(new PyFunctionPostHook(dict)); |
319 | fn.add_post_hook(std::move(hook)); |
320 | } |
321 | |
322 | PyObject* handle = PyTuple_GET_ITEM(res.get(), 1); |
323 | Py_INCREF(handle); |
324 | return handle; |
325 | } |
326 | |
327 | // This is almost a copy of the function above except post -> pre |
328 | PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) { |
329 | PyObject* dict = Py_None; |
330 | for (const auto& hook : fn.pre_hooks()) { |
331 | if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) { |
332 | dict = pyhook->dict; |
333 | break; |
334 | } |
335 | } |
336 | THPObjectPtr res{callRegisterFn(dict, hook)}; |
337 | if (!res) { |
338 | return nullptr; |
339 | } |
340 | if (dict == Py_None) { |
341 | dict = PyTuple_GET_ITEM(res.get(), 0); |
342 | std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(dict)); |
343 | fn.add_pre_hook(std::move(hook)); |
344 | } |
345 | |
346 | PyObject* handle = PyTuple_GET_ITEM(res.get(), 1); |
347 | Py_INCREF(handle); |
348 | return handle; |
349 | } |
350 | |
351 | } // namespace autograd |
352 | } // namespace torch |
353 | |