1#pragma once
2
3#include <torch/csrc/python_headers.h>
4#include <memory>
5#include <typeinfo>
6
7#include <torch/csrc/Exceptions.h>
8#include <torch/csrc/autograd/function.h>
9#include <torch/csrc/utils/object_ptr.h>
10
11namespace torch {
12namespace autograd {
13
14struct THPCppFunction {
15 PyObject_HEAD std::shared_ptr<Node> cdata;
16};
17
18template <typename Ctor>
19PyObject* CppFunction_pynew(
20 PyTypeObject* type,
21 PyObject* args,
22 PyObject* kwds) {
23 THPObjectPtr obj(type->tp_alloc(type, 0));
24 if (!obj)
25 return nullptr;
26 THPCppFunction* f = (THPCppFunction*)obj.get();
27 HANDLE_TH_ERRORS
28 new (&f->cdata) std::shared_ptr<Node>(Ctor()(args));
29 END_HANDLE_TH_ERRORS
30 if (!f->cdata) {
31 return nullptr;
32 }
33 return obj.release();
34}
35
// Default PyMethodDef entries shared by every C++-backed autograd function
// type. Expands to four comma-separated {name, function, flags, doc}
// initializers with no trailing comma and no {nullptr} sentinel. It is
// therefore meant to be expanded inside a PyMethodDef array initializer that
// supplies its own terminator.
//   _register_hook_dict, register_hook, register_prehook -> METH_O
//   name                                                  -> METH_NOARGS
// The C functions named here are declared further down in this header; that
// is fine because a macro body is only compiled where it is expanded.
#define THP_FUNCTION_DEFAULT_METHODS          \
  {(char*)"_register_hook_dict",              \
   THPCppFunction_register_hook_dict,         \
   METH_O,                                    \
   nullptr},                                  \
  {(char*)"register_hook",                    \
   THPCppFunction_register_hook,              \
   METH_O,                                    \
   nullptr},                                  \
  {(char*)"register_prehook",                 \
   THPCppFunction_register_prehook,           \
   METH_O,                                    \
   nullptr},                                  \
  {(char*)"name",                             \
   THPCppFunction_name,                       \
   METH_NOARGS,                               \
   nullptr}
49
// Default PyGetSetDef entries shared by every C++-backed autograd function
// type. Expands to three comma-separated {name, get, set, doc, closure}
// initializers with no trailing comma and no sentinel. Every entry is
// read-only: setter, doc and closure are all nullptr. The getters are
// declared with a THPCppFunction* first parameter rather than PyObject*,
// which is why each one is cast to `getter` explicitly.
#define THP_FUNCTION_DEFAULT_PROPERTIES       \
  {(char*)"next_functions",                   \
   (getter)THPCppFunction_next_functions,     \
   nullptr,                                   \
   nullptr,                                   \
   nullptr},                                  \
  {(char*)"requires_grad",                    \
   (getter)THPCppFunction_requires_grad,      \
   nullptr,                                   \
   nullptr,                                   \
   nullptr},                                  \
  {(char*)"metadata",                         \
   (getter)THPCppFunction_metadata,           \
   nullptr,                                   \
   nullptr,                                   \
   nullptr}
65
66PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook);
67PyObject* THPCppFunction_metadata(THPCppFunction* self, void* _unused);
68PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void* _unused);
69PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
70PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
71PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);
72
73PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
74
75PyTypeObject* _initFunctionPyTypeObject(
76 PyTypeObject& type,
77 const char* name,
78 PyGetSetDef* function_properties,
79 PyMethodDef* function_methods);
80
81PyObject* registerFunctionHook(Node& fn, PyObject* hook);
82
83PyObject* registerFunctionPreHook(Node& fn, PyObject* hook);
84
85template <typename Ctor>
86PyTypeObject* createForwardFunctionPyTypeObject(
87 PyTypeObject& type,
88 const char* name,
89 PyGetSetDef* function_properties = nullptr,
90 PyMethodDef* function_methods = nullptr) {
91 type.tp_new = &CppFunction_pynew<Ctor>;
92 return _initFunctionPyTypeObject(
93 type, name, function_properties, function_methods);
94}
95
96void registerCppFunction(const std::type_info& type, PyTypeObject* pytype);
97PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);
98
99bool THPCppFunction_Check(PyObject* obj);
100
101} // namespace autograd
102} // namespace torch
103