1 | #pragma once |
2 | |
3 | #include <torch/csrc/python_headers.h> |
4 | #include <memory> |
5 | #include <typeinfo> |
6 | |
7 | #include <torch/csrc/Exceptions.h> |
8 | #include <torch/csrc/autograd/function.h> |
9 | #include <torch/csrc/utils/object_ptr.h> |
10 | |
11 | namespace torch { |
12 | namespace autograd { |
13 | |
14 | struct THPCppFunction { |
15 | PyObject_HEAD std::shared_ptr<Node> cdata; |
16 | }; |
17 | |
18 | template <typename Ctor> |
19 | PyObject* CppFunction_pynew( |
20 | PyTypeObject* type, |
21 | PyObject* args, |
22 | PyObject* kwds) { |
23 | THPObjectPtr obj(type->tp_alloc(type, 0)); |
24 | if (!obj) |
25 | return nullptr; |
26 | THPCppFunction* f = (THPCppFunction*)obj.get(); |
27 | HANDLE_TH_ERRORS |
28 | new (&f->cdata) std::shared_ptr<Node>(Ctor()(args)); |
29 | END_HANDLE_TH_ERRORS |
30 | if (!f->cdata) { |
31 | return nullptr; |
32 | } |
33 | return obj.release(); |
34 | } |
35 | |
// Default method-table entries shared by the C++ autograd function Python
// types. Each entry is {name, C function, METH_* calling-convention flag,
// docstring (none)}. The expansion is four brace-initialized entries with no
// trailing comma and no terminating sentinel; the table that splices this in
// has to supply both.
#define THP_FUNCTION_DEFAULT_METHODS                                          \
  {(char*)"_register_hook_dict", THPCppFunction_register_hook_dict, METH_O,   \
   nullptr},                                                                  \
  {(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr},    \
  {(char*)"register_prehook", THPCppFunction_register_prehook, METH_O,        \
   nullptr},                                                                  \
  {(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr}
49 | |
// Default property-table entries shared by the C++ autograd function Python
// types. Each entry is {name, getter, setter (none), docstring (none),
// closure (none)}, so all three properties are read-only. The expansion is
// three brace-initialized entries with no trailing comma and no terminating
// sentinel; the table that splices this in has to supply both.
#define THP_FUNCTION_DEFAULT_PROPERTIES                                       \
  {(char*)"next_functions", (getter)THPCppFunction_next_functions, nullptr,   \
   nullptr, nullptr},                                                         \
  {(char*)"requires_grad", (getter)THPCppFunction_requires_grad, nullptr,     \
   nullptr, nullptr},                                                         \
  {(char*)"metadata", (getter)THPCppFunction_metadata, nullptr, nullptr,      \
   nullptr}
65 | |
// Property getters named by THP_FUNCTION_DEFAULT_PROPERTIES. They are only
// declared here; the definitions are not part of this header.
// NOTE(review): the second parameter of THPCppFunction_next_functions is
// `PyObject* hook`, while the sibling getters take `void* _unused`. All three
// are installed through a `(getter)` cast, so this one presumably should
// mirror the others -- confirm against the definition before changing it.
PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook);
PyObject* THPCppFunction_metadata(THPCppFunction* self, void* _unused);
PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void* _unused);
// Methods named by THP_FUNCTION_DEFAULT_METHODS. The first three are
// registered there with METH_O (one positional argument).
PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);

// Registered in THP_FUNCTION_DEFAULT_METHODS with METH_NOARGS.
PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
74 | |
// Shared type-initialization step used by createForwardFunctionPyTypeObject,
// which sets tp_new and then returns this function's result unchanged.
// The definition is not visible here; presumably it fills in the remaining
// slots of `type` from `name`, `function_properties` and `function_methods`
// and readies the type -- TODO confirm against the implementation.
PyTypeObject* _initFunctionPyTypeObject(
    PyTypeObject& type,
    const char* name,
    PyGetSetDef* function_properties,
    PyMethodDef* function_methods);

// Presumably attaches the Python callable `hook` to `fn` as a hook and
// returns a Python object for the registration -- defined elsewhere, confirm.
PyObject* registerFunctionHook(Node& fn, PyObject* hook);

// Pre-hook counterpart of registerFunctionHook (by name; defined elsewhere,
// exact semantics to be confirmed).
PyObject* registerFunctionPreHook(Node& fn, PyObject* hook);
84 | |
85 | template <typename Ctor> |
86 | PyTypeObject* createForwardFunctionPyTypeObject( |
87 | PyTypeObject& type, |
88 | const char* name, |
89 | PyGetSetDef* function_properties = nullptr, |
90 | PyMethodDef* function_methods = nullptr) { |
91 | type.tp_new = &CppFunction_pynew<Ctor>; |
92 | return _initFunctionPyTypeObject( |
93 | type, name, function_properties, function_methods); |
94 | } |
95 | |
// Presumably records `pytype` as the Python type to use for C++ functions
// whose dynamic type is `type` -- defined elsewhere, confirm.
void registerCppFunction(const std::type_info& type, PyTypeObject* pytype);
// Presumably returns a Python object wrapping `cdata`; ownership of the
// returned reference and null handling are not visible here -- TODO confirm.
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);

// Predicate on a Python object; by name it tests whether `obj` is a
// THPCppFunction instance (definition not visible here).
bool THPCppFunction_Check(PyObject* obj);
100 | |
101 | } // namespace autograd |
102 | } // namespace torch |
103 | |