1#pragma once
2
3#include <c10/core/SafePyObject.h>
4#include <c10/core/SymNodeImpl.h>
5
6#include <torch/csrc/PyInterpreter.h>
7#include <torch/csrc/autograd/python_variable.h>
8#include <torch/csrc/utils/pybind.h>
9
10namespace torch {
11
12TORCH_PYTHON_API py::handle get_symint_class();
13TORCH_PYTHON_API py::handle get_symfloat_class();
14
15// NB: These functions must not be called too early, otherwise torch not setup.
16// Alternate design is to have torch "register" the object to us
17inline bool is_symint(py::handle obj) {
18 return py::isinstance(obj, get_symint_class());
19}
20inline bool is_symfloat(py::handle obj) {
21 return py::isinstance(obj, get_symfloat_class());
22}
23
24namespace impl {
25
26// This c10::SymNodeImpl simply backends to a Python object that
27// implements the API. The Python object is the source of truth,
28// this is just an adapter so C++ calls can get to the object.
29class PythonSymNodeImpl : public c10::SymNodeImpl {
30 public:
31 PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
32 pyobj_ = std::make_shared<c10::SafePyObject>(
33 pyobj.release().ptr(), getPyInterpreter());
34 };
35
36 c10::SymNode wrap_int(int64_t num) override {
37 py::gil_scoped_acquire acquire;
38 auto r = getPyObj().attr("wrap_int")(num);
39 return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
40 }
41
42 c10::SymNode wrap_float(double num) override {
43 py::gil_scoped_acquire acquire;
44 auto r = getPyObj().attr("wrap_float")(num);
45 return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
46 }
47
48 c10::SymNode wrap_bool(bool num) override {
49 py::gil_scoped_acquire acquire;
50 auto r = getPyObj().attr("wrap_bool")(num);
51 return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
52 }
53
54 c10::SymNode is_non_overlapping_and_dense(
55 c10::ArrayRef<c10::SymNode> sizes,
56 c10::ArrayRef<c10::SymNode> strides) override {
57 py::gil_scoped_acquire acquire;
58 auto r = getPyObj().attr("is_non_overlapping_and_dense")(sizes, strides);
59 return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
60 }
61
62 bool bool_() override {
63 py::gil_scoped_acquire acquire;
64 return getPyObj().attr("bool_")().is(py::handle(Py_True));
65 }
66
67 bool is_int() override {
68 py::gil_scoped_acquire acquire;
69 return getPyObj().attr("is_int")().is(py::handle(Py_True));
70 }
71
72 bool is_float() override {
73 py::gil_scoped_acquire acquire;
74 return getPyObj().attr("is_float")().is(py::handle(Py_True));
75 }
76
77 bool is_bool() override {
78 py::gil_scoped_acquire acquire;
79 return getPyObj().attr("is_bool")().is(py::handle(Py_True));
80 }
81
82 int64_t guard_int(const char* file, int64_t line) override {
83 py::gil_scoped_acquire acquire;
84 return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
85 }
86
87 double guard_float(const char* file, int64_t line) override {
88 py::gil_scoped_acquire acquire;
89 return getPyObj().attr("guard_float")(file, line).cast<double>();
90 }
91
92 bool guard_bool(const char* file, int64_t line) override {
93 py::gil_scoped_acquire acquire;
94 return getPyObj().attr("guard_bool")(file, line).cast<bool>();
95 }
96
97 int64_t int_() override {
98 py::gil_scoped_acquire acquire;
99 return getPyObj().attr("int_")().cast<int64_t>();
100 }
101
102 std::string str() override {
103 py::gil_scoped_acquire acquire;
104 return getPyObj().attr("str")().cast<std::string>();
105 }
106
107 c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
108 auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
109 TORCH_CHECK(pother);
110 py::gil_scoped_acquire acquire;
111 auto r = getPyObj().attr(fname)(pother->getPyObj());
112 return c10::make_intrusive<PythonSymNodeImpl>(r);
113 }
114
115 c10::SymNode dispatch_common_(const char* fname) {
116 py::gil_scoped_acquire acquire;
117 auto r = getPyObj().attr(fname)();
118 return c10::make_intrusive<PythonSymNodeImpl>(r);
119 }
120
121 c10::SymNode add(const c10::SymNode& other) override {
122 return dispatch_common_(__func__, other);
123 }
124
125 c10::SymNode sub(const c10::SymNode& other) override {
126 return dispatch_common_(__func__, other);
127 }
128
129 c10::SymNode mul(const c10::SymNode& other) override {
130 return dispatch_common_(__func__, other);
131 }
132
133 c10::SymNode truediv(const c10::SymNode& other) override {
134 return dispatch_common_(__func__, other);
135 }
136
137 c10::SymNode pow(const c10::SymNode& other) override {
138 return dispatch_common_(__func__, other);
139 }
140
141 c10::SymNode floordiv(const c10::SymNode& other) override {
142 return dispatch_common_(__func__, other);
143 }
144
145 c10::SymNode mod(const c10::SymNode& other) override {
146 return dispatch_common_(__func__, other);
147 }
148
149 c10::SymNode eq(const c10::SymNode& other) override {
150 return dispatch_common_(__func__, other);
151 }
152
153 c10::SymNode ne(const c10::SymNode& other) override {
154 return dispatch_common_(__func__, other);
155 }
156
157 c10::SymNode gt(const c10::SymNode& other) override {
158 return dispatch_common_(__func__, other);
159 }
160
161 c10::SymNode lt(const c10::SymNode& other) override {
162 return dispatch_common_(__func__, other);
163 }
164
165 c10::SymNode le(const c10::SymNode& other) override {
166 return dispatch_common_(__func__, other);
167 }
168
169 c10::SymNode ge(const c10::SymNode& other) override {
170 return dispatch_common_(__func__, other);
171 }
172
173 c10::SymNode sym_min(const c10::SymNode& other) override {
174 return dispatch_common_(__func__, other);
175 }
176 c10::SymNode sym_max(const c10::SymNode& other) override {
177 return dispatch_common_(__func__, other);
178 }
179
180 c10::SymNode sym_and(const c10::SymNode& other) override {
181 return dispatch_common_(__func__, other);
182 }
183
184 c10::SymNode sym_or(const c10::SymNode& other) override {
185 return dispatch_common_(__func__, other);
186 }
187
188 c10::SymNode sym_not() override {
189 return dispatch_common_(__func__);
190 }
191
192 c10::SymNode ceil() override {
193 return dispatch_common_(__func__);
194 }
195
196 c10::SymNode floor() override {
197 return dispatch_common_(__func__);
198 }
199
200 c10::SymNode neg() override {
201 return dispatch_common_(__func__);
202 }
203
204 c10::SymNode clone() override {
205 return dispatch_common_(__func__);
206 }
207
208 c10::SymNode sym_float() override {
209 return dispatch_common_(__func__);
210 }
211
212 py::handle getPyObj() {
213 return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
214 }
215 std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
216};
217
218} // namespace impl
219} // namespace torch
220