1 | #pragma once |
2 | |
3 | #include <c10/core/SafePyObject.h> |
4 | #include <c10/core/SymNodeImpl.h> |
5 | |
6 | #include <torch/csrc/PyInterpreter.h> |
7 | #include <torch/csrc/autograd/python_variable.h> |
8 | #include <torch/csrc/utils/pybind.h> |
9 | |
10 | namespace torch { |
11 | |
// Return handles to the Python classes that is_symint / is_symfloat test
// against with py::isinstance. Defined out of line; presumably these are
// torch.SymInt / torch.SymFloat — TODO(review): confirm in the .cpp.
TORCH_PYTHON_API py::handle get_symint_class();
TORCH_PYTHON_API py::handle get_symfloat_class();
14 | |
15 | // NB: These functions must not be called too early, otherwise torch not setup. |
16 | // Alternate design is to have torch "register" the object to us |
17 | inline bool is_symint(py::handle obj) { |
18 | return py::isinstance(obj, get_symint_class()); |
19 | } |
20 | inline bool is_symfloat(py::handle obj) { |
21 | return py::isinstance(obj, get_symfloat_class()); |
22 | } |
23 | |
24 | namespace impl { |
25 | |
26 | // This c10::SymNodeImpl simply backends to a Python object that |
27 | // implements the API. The Python object is the source of truth, |
28 | // this is just an adapter so C++ calls can get to the object. |
29 | class PythonSymNodeImpl : public c10::SymNodeImpl { |
30 | public: |
31 | PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() { |
32 | pyobj_ = std::make_shared<c10::SafePyObject>( |
33 | pyobj.release().ptr(), getPyInterpreter()); |
34 | }; |
35 | |
36 | c10::SymNode wrap_int(int64_t num) override { |
37 | py::gil_scoped_acquire acquire; |
38 | auto r = getPyObj().attr("wrap_int" )(num); |
39 | return c10::make_intrusive<PythonSymNodeImpl>(std::move(r)); |
40 | } |
41 | |
42 | c10::SymNode wrap_float(double num) override { |
43 | py::gil_scoped_acquire acquire; |
44 | auto r = getPyObj().attr("wrap_float" )(num); |
45 | return c10::make_intrusive<PythonSymNodeImpl>(std::move(r)); |
46 | } |
47 | |
48 | c10::SymNode wrap_bool(bool num) override { |
49 | py::gil_scoped_acquire acquire; |
50 | auto r = getPyObj().attr("wrap_bool" )(num); |
51 | return c10::make_intrusive<PythonSymNodeImpl>(std::move(r)); |
52 | } |
53 | |
54 | c10::SymNode is_non_overlapping_and_dense( |
55 | c10::ArrayRef<c10::SymNode> sizes, |
56 | c10::ArrayRef<c10::SymNode> strides) override { |
57 | py::gil_scoped_acquire acquire; |
58 | auto r = getPyObj().attr("is_non_overlapping_and_dense" )(sizes, strides); |
59 | return c10::make_intrusive<PythonSymNodeImpl>(std::move(r)); |
60 | } |
61 | |
62 | bool bool_() override { |
63 | py::gil_scoped_acquire acquire; |
64 | return getPyObj().attr("bool_" )().is(py::handle(Py_True)); |
65 | } |
66 | |
67 | bool is_int() override { |
68 | py::gil_scoped_acquire acquire; |
69 | return getPyObj().attr("is_int" )().is(py::handle(Py_True)); |
70 | } |
71 | |
72 | bool is_float() override { |
73 | py::gil_scoped_acquire acquire; |
74 | return getPyObj().attr("is_float" )().is(py::handle(Py_True)); |
75 | } |
76 | |
77 | bool is_bool() override { |
78 | py::gil_scoped_acquire acquire; |
79 | return getPyObj().attr("is_bool" )().is(py::handle(Py_True)); |
80 | } |
81 | |
82 | int64_t guard_int(const char* file, int64_t line) override { |
83 | py::gil_scoped_acquire acquire; |
84 | return getPyObj().attr("guard_int" )(file, line).cast<int64_t>(); |
85 | } |
86 | |
87 | double guard_float(const char* file, int64_t line) override { |
88 | py::gil_scoped_acquire acquire; |
89 | return getPyObj().attr("guard_float" )(file, line).cast<double>(); |
90 | } |
91 | |
92 | bool guard_bool(const char* file, int64_t line) override { |
93 | py::gil_scoped_acquire acquire; |
94 | return getPyObj().attr("guard_bool" )(file, line).cast<bool>(); |
95 | } |
96 | |
97 | int64_t int_() override { |
98 | py::gil_scoped_acquire acquire; |
99 | return getPyObj().attr("int_" )().cast<int64_t>(); |
100 | } |
101 | |
102 | std::string str() override { |
103 | py::gil_scoped_acquire acquire; |
104 | return getPyObj().attr("str" )().cast<std::string>(); |
105 | } |
106 | |
107 | c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) { |
108 | auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get()); |
109 | TORCH_CHECK(pother); |
110 | py::gil_scoped_acquire acquire; |
111 | auto r = getPyObj().attr(fname)(pother->getPyObj()); |
112 | return c10::make_intrusive<PythonSymNodeImpl>(r); |
113 | } |
114 | |
115 | c10::SymNode dispatch_common_(const char* fname) { |
116 | py::gil_scoped_acquire acquire; |
117 | auto r = getPyObj().attr(fname)(); |
118 | return c10::make_intrusive<PythonSymNodeImpl>(r); |
119 | } |
120 | |
121 | c10::SymNode add(const c10::SymNode& other) override { |
122 | return dispatch_common_(__func__, other); |
123 | } |
124 | |
125 | c10::SymNode sub(const c10::SymNode& other) override { |
126 | return dispatch_common_(__func__, other); |
127 | } |
128 | |
129 | c10::SymNode mul(const c10::SymNode& other) override { |
130 | return dispatch_common_(__func__, other); |
131 | } |
132 | |
133 | c10::SymNode truediv(const c10::SymNode& other) override { |
134 | return dispatch_common_(__func__, other); |
135 | } |
136 | |
137 | c10::SymNode pow(const c10::SymNode& other) override { |
138 | return dispatch_common_(__func__, other); |
139 | } |
140 | |
141 | c10::SymNode floordiv(const c10::SymNode& other) override { |
142 | return dispatch_common_(__func__, other); |
143 | } |
144 | |
145 | c10::SymNode mod(const c10::SymNode& other) override { |
146 | return dispatch_common_(__func__, other); |
147 | } |
148 | |
149 | c10::SymNode eq(const c10::SymNode& other) override { |
150 | return dispatch_common_(__func__, other); |
151 | } |
152 | |
153 | c10::SymNode ne(const c10::SymNode& other) override { |
154 | return dispatch_common_(__func__, other); |
155 | } |
156 | |
157 | c10::SymNode gt(const c10::SymNode& other) override { |
158 | return dispatch_common_(__func__, other); |
159 | } |
160 | |
161 | c10::SymNode lt(const c10::SymNode& other) override { |
162 | return dispatch_common_(__func__, other); |
163 | } |
164 | |
165 | c10::SymNode le(const c10::SymNode& other) override { |
166 | return dispatch_common_(__func__, other); |
167 | } |
168 | |
169 | c10::SymNode ge(const c10::SymNode& other) override { |
170 | return dispatch_common_(__func__, other); |
171 | } |
172 | |
173 | c10::SymNode sym_min(const c10::SymNode& other) override { |
174 | return dispatch_common_(__func__, other); |
175 | } |
176 | c10::SymNode sym_max(const c10::SymNode& other) override { |
177 | return dispatch_common_(__func__, other); |
178 | } |
179 | |
180 | c10::SymNode sym_and(const c10::SymNode& other) override { |
181 | return dispatch_common_(__func__, other); |
182 | } |
183 | |
184 | c10::SymNode sym_or(const c10::SymNode& other) override { |
185 | return dispatch_common_(__func__, other); |
186 | } |
187 | |
188 | c10::SymNode sym_not() override { |
189 | return dispatch_common_(__func__); |
190 | } |
191 | |
192 | c10::SymNode ceil() override { |
193 | return dispatch_common_(__func__); |
194 | } |
195 | |
196 | c10::SymNode floor() override { |
197 | return dispatch_common_(__func__); |
198 | } |
199 | |
200 | c10::SymNode neg() override { |
201 | return dispatch_common_(__func__); |
202 | } |
203 | |
204 | c10::SymNode clone() override { |
205 | return dispatch_common_(__func__); |
206 | } |
207 | |
208 | c10::SymNode sym_float() override { |
209 | return dispatch_common_(__func__); |
210 | } |
211 | |
212 | py::handle getPyObj() { |
213 | return py::handle(pyobj_.get()->ptr(getPyInterpreter())); |
214 | } |
215 | std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr; |
216 | }; |
217 | |
218 | } // namespace impl |
219 | } // namespace torch |
220 | |