#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_cpp_function.h>

#include <torch/csrc/python_headers.h>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>

#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
20
21using namespace torch::autograd;
22
23namespace torch {
24namespace autograd {
25
26namespace {
27
28PyObject* THPCppFunction_call(
29 PyObject* self,
30 PyObject* args,
31 PyObject* kwargs) {
32 if (kwargs && PyDict_Size(kwargs) != 0) {
33 return PyErr_Format(PyExc_TypeError, "keyword arguments are not supported");
34 }
35
36 int num_inputs = PyTuple_GET_SIZE(args);
37 int num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs();
38 if (num_inputs != num_inputs_required) {
39 return PyErr_Format(
40 PyExc_TypeError,
41 "expected %d arguments, got %d instead",
42 num_inputs_required,
43 num_inputs);
44 }
45 variable_list vars(num_inputs);
46 for (int i = 0; i != num_inputs; ++i) {
47 PyObject* arg = PyTuple_GET_ITEM(args, i);
48 if (arg == Py_None) {
49 continue;
50 }
51 if (!THPVariable_Check(arg)) {
52 return PyErr_Format(PyExc_TypeError, "argument %d is not a Variable", i);
53 }
54 vars[i] = THPVariable_Unpack(arg);
55 }
56
57 variable_list output;
58
59 HANDLE_TH_ERRORS {
60 pybind11::gil_scoped_release nogil;
61 output = (*((THPCppFunction*)self)->cdata)(std::move(vars));
62 }
63 END_HANDLE_TH_ERRORS
64
65 int num_outputs = output.size();
66 if (num_outputs == 1) {
67 // assume we want to unpack one element tuples for now
68 return THPVariable_Wrap(output[0]);
69 }
70
71 THPObjectPtr tuple(PyTuple_New(num_outputs));
72 for (int i = 0; i != num_outputs; ++i) {
73 PyTuple_SET_ITEM(tuple.get(), i, THPVariable_Wrap(output[i]));
74 }
75 return tuple.release();
76}
77
78int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
79 auto& fn = *((THPCppFunction*)self)->cdata;
80 for (const auto& hook : fn.tensor_pre_hooks()) {
81 if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
82 Py_VISIT(pyhook->dict);
83 }
84 }
85 // NOTE [retains_grad_hook PyObject traversal]
86 // In theory this shouldn't be necessary, because retains_grad_hooks should
87 // not contain any PyFunctionTensorPreHooks. The alternative is to have a
88 // check that actually guarantees this.
89 for (const auto& pair : fn.retains_grad_hooks()) {
90 if (auto pyhook =
91 dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
92 Py_VISIT(pyhook->dict);
93 }
94 }
95 for (const auto& hook : fn.pre_hooks()) {
96 if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
97 Py_VISIT(pyhook->dict);
98 }
99 }
100 for (const auto& hook : fn.post_hooks()) {
101 if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
102 Py_VISIT(pyhook->dict);
103 }
104 }
105 return 0;
106}
107
108int THPCppFunction_clear(PyObject* self) {
109 auto f = (THPCppFunction*)self;
110 // Remove the weak ref of the c++ object if it exist
111 if (f->cdata) {
112 f->cdata->set_pyobj(nullptr);
113 }
114 f->cdata.reset();
115 return 0;
116}
117
118void THPCppFunction_dealloc(PyObject* self) {
119 PyObject_GC_UnTrack(self);
120 THPCppFunction_clear(self);
121 ((THPCppFunction*)self)->cdata.~shared_ptr();
122 Py_TYPE(self)->tp_free(self);
123}
124
125} // namespace
126
127PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook) {
128 const auto num_next = self->cdata->num_outputs();
129 THPObjectPtr py_functions(PyTuple_New(num_next));
130 if (!py_functions)
131 return nullptr;
132 for (const auto i : c10::irange(num_next)) {
133 auto& c_tuple = self->cdata->next_edge(i);
134 THPObjectPtr tuple(PyTuple_New(2));
135 if (!tuple)
136 return nullptr;
137 PyObject* py_fn = functionToPyObject(c_tuple.function);
138 if (!py_fn)
139 return nullptr;
140 PyTuple_SET_ITEM(tuple.get(), 0, py_fn);
141 PyObject* py_idx = THPUtils_packUInt32(c_tuple.input_nr);
142 if (!py_idx)
143 return nullptr;
144 PyTuple_SET_ITEM(tuple.get(), 1, py_idx);
145 PyTuple_SET_ITEM(py_functions.get(), i, tuple.release());
146 }
147 return py_functions.release();
148}
149
150PyObject* THPCppFunction_metadata(THPCppFunction* self, void* _unused) {
151 auto* metadata =
152 static_cast<PyAnomalyMetadata*>(self->cdata->metadata())->dict();
153
154 Py_XINCREF(metadata);
155 return metadata;
156}
157
158PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void* unused) {
159 Py_RETURN_TRUE;
160}
161
162PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) {
163 if (!THPVariable_Check(_var)) {
164 return PyErr_Format(
165 PyExc_TypeError, "_register_hook_dict expected a variable");
166 }
167 auto var = (THPVariable*)_var;
168 auto& fn = *((THPCppFunction*)self)->cdata;
169 std::unique_ptr<FunctionPreHook> hook(new PyFunctionTensorPreHook(
170 var->backward_hooks, THPVariable_Unpack(var).output_nr()));
171 fn.add_tensor_pre_hook(std::move(hook));
172 Py_RETURN_NONE;
173}
174
175PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) {
176 auto& fn = *((THPCppFunction*)self)->cdata;
177 return registerFunctionHook(fn, hook);
178}
179
180PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook) {
181 auto& fn = *((THPCppFunction*)self)->cdata;
182 return registerFunctionPreHook(fn, hook);
183}
184
185PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs) {
186 auto& fn = *((THPCppFunction*)self)->cdata;
187 return THPUtils_packString(fn.name());
188}
189
190// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
191static struct PyMethodDef default_methods[] = {
192 THP_FUNCTION_DEFAULT_METHODS,
193 {nullptr}};
194
195// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
196static struct PyGetSetDef default_properties[] = {
197 THP_FUNCTION_DEFAULT_PROPERTIES,
198 {nullptr}};
199
200PyTypeObject* _initFunctionPyTypeObject(
201 PyTypeObject& type,
202 const char* name,
203 PyGetSetDef* function_properties,
204 PyMethodDef* function_methods) {
205 type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
206 type.tp_name = name;
207 type.tp_basicsize = sizeof(THPCppFunction);
208 type.tp_call = THPCppFunction_call;
209 type.tp_methods = function_methods ? function_methods : default_methods;
210 type.tp_getset =
211 function_properties ? function_properties : default_properties;
212 type.tp_dealloc = THPCppFunction_dealloc;
213 type.tp_traverse = THPCppFunction_traverse;
214 type.tp_clear = THPCppFunction_clear;
215 if (PyType_Ready(&type) < 0) {
216 auto msg = std::string("Unable to instantiate PyTypeObject for ") + name;
217 throw std::runtime_error(msg);
218 }
219 return &type;
220}
221
222static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types_map;
223static std::unordered_set<PyTypeObject*> cpp_function_types_set;
224
225struct DefaultFunctionType {
226 DefaultFunctionType() : type() {
227 _initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
228 Py_INCREF(&type);
229 }
230
231 PyTypeObject type;
232};
233
234PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
235 static DefaultFunctionType default_type;
236
237 if (!cdata) {
238 Py_RETURN_NONE;
239 }
240
241 if (auto pfw = dynamic_cast<PyNode*>(cdata.get())) {
242 PyObject* obj = pfw->obj;
243 Py_INCREF(obj);
244 return obj;
245 }
246
247 if (cdata->pyobj()) {
248 Py_INCREF(cdata->pyobj());
249 } else {
250 auto& fn = *cdata;
251 auto it = cpp_function_types_map.find(std::type_index(typeid(fn)));
252 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
253 PyTypeObject* type;
254 if (it == cpp_function_types_map.end()) {
255 type = &default_type.type;
256 } else {
257 type = (PyTypeObject*)it->second.get();
258 }
259
260 THPObjectPtr obj(type->tp_alloc(type, 0));
261 if (!obj)
262 return nullptr;
263 THPCppFunction* f = (THPCppFunction*)obj.get();
264 new (&f->cdata) std::shared_ptr<Node>(cdata);
265
266 // No INCREF here as we only have a weak reference
267 cdata->set_pyobj(obj.release());
268 }
269
270 return cdata->pyobj();
271}
272
273void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
274 Py_INCREF((PyObject*)pytype);
275 cpp_function_types_map[std::type_index(type)] =
276 THPObjectPtr((PyObject*)pytype);
277 cpp_function_types_set.insert(pytype);
278}
279
280bool THPCppFunction_Check(PyObject* obj) {
281 THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
282 if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
283 cpp_function_types_set.end()) {
284 return false;
285 } else {
286 return true;
287 }
288}
289
290PyObject* callRegisterFn(PyObject* dict, PyObject* hook) {
291 THPObjectPtr register_fn(
292 PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
293 if (!register_fn) {
294 return nullptr;
295 }
296 THPObjectPtr res(
297 PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
298 if (!res) {
299 return nullptr;
300 }
301 return res.release();
302}
303
304PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
305 PyObject* dict = Py_None;
306 for (const auto& hook : fn.post_hooks()) {
307 if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
308 dict = pyhook->dict;
309 break;
310 }
311 }
312 THPObjectPtr res{callRegisterFn(dict, hook)};
313 if (!res) {
314 return nullptr;
315 }
316 if (dict == Py_None) {
317 dict = PyTuple_GET_ITEM(res.get(), 0);
318 std::unique_ptr<FunctionPostHook> hook(new PyFunctionPostHook(dict));
319 fn.add_post_hook(std::move(hook));
320 }
321
322 PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
323 Py_INCREF(handle);
324 return handle;
325}
326
327// This is almost a copy of the function above except post -> pre
328PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
329 PyObject* dict = Py_None;
330 for (const auto& hook : fn.pre_hooks()) {
331 if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
332 dict = pyhook->dict;
333 break;
334 }
335 }
336 THPObjectPtr res{callRegisterFn(dict, hook)};
337 if (!res) {
338 return nullptr;
339 }
340 if (dict == Py_None) {
341 dict = PyTuple_GET_ITEM(res.get(), 0);
342 std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook(dict));
343 fn.add_pre_hook(std::move(hook));
344 }
345
346 PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
347 Py_INCREF(handle);
348 return handle;
349}
350
351} // namespace autograd
352} // namespace torch
353