1 | #include <torch/csrc/DynamicTypes.h> |
2 | #include <torch/csrc/THP.h> |
3 | #include <torch/csrc/autograd/variable.h> |
4 | #include <torch/csrc/python_headers.h> |
5 | #include <torch/csrc/utils/invalid_arguments.h> |
6 | #include <torch/csrc/utils/python_strings.h> |
7 | #include <torch/csrc/utils/python_symnode.h> |
8 | #include <torch/csrc/utils/python_tuples.h> |
9 | |
10 | #include <torch/csrc/Export.h> |
11 | |
12 | #include <algorithm> |
13 | #include <cstdarg> |
14 | #include <iterator> |
15 | #include <sstream> |
16 | #include <string> |
17 | #include <unordered_map> |
18 | #include <vector> |
19 | |
20 | int THPUtils_getCallable(PyObject* arg, PyObject** result) { |
21 | if (!PyCallable_Check(arg)) |
22 | return 0; |
23 | *result = arg; |
24 | return 1; |
25 | } |
26 | |
27 | std::vector<int64_t> THPUtils_unpackLongs(PyObject* arg) { |
28 | bool tuple = PyTuple_Check(arg); |
29 | bool list = PyList_Check(arg); |
30 | if (tuple || list) { |
31 | // NOLINTNEXTLINE(bugprone-branch-clone) |
32 | const auto nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); |
33 | std::vector<int64_t> sizes(nDim); |
34 | for (int i = 0; i != nDim; ++i) { |
35 | PyObject* item = |
36 | tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i); |
37 | if (!THPUtils_checkLong(item)) { |
38 | std::ostringstream oss; |
39 | oss << "expected int at position " << i |
40 | << ", but got: " << THPUtils_typename(item); |
41 | throw std::runtime_error(oss.str()); |
42 | } |
43 | sizes[i] = THPUtils_unpackLong(item); |
44 | } |
45 | return sizes; |
46 | } |
47 | throw std::runtime_error("Expected tuple or list" ); |
48 | } |
49 | |
50 | bool THPUtils_checkIntTuple(PyObject* arg) { |
51 | if (!PyTuple_Check(arg)) { |
52 | return false; |
53 | } |
54 | for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) { |
55 | if (!THPUtils_checkLong(PyTuple_GET_ITEM(arg, i))) { |
56 | return false; |
57 | } |
58 | } |
59 | return true; |
60 | } |
61 | |
62 | std::vector<int> THPUtils_unpackIntTuple(PyObject* arg) { |
63 | if (!THPUtils_checkIntTuple(arg)) { |
64 | throw std::runtime_error("Couldn't unpack int tuple" ); |
65 | } |
66 | std::vector<int> values(PyTuple_GET_SIZE(arg)); |
67 | for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) { |
68 | values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i)); |
69 | } |
70 | return values; |
71 | } |
72 | |
73 | void THPUtils_setError(const char* format, ...) { |
74 | static const size_t ERROR_BUFFER_SIZE = 1000; |
75 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
76 | char buffer[ERROR_BUFFER_SIZE]; |
77 | va_list fmt_args; |
78 | |
79 | va_start(fmt_args, format); |
80 | vsnprintf(buffer, ERROR_BUFFER_SIZE, format, fmt_args); |
81 | va_end(fmt_args); |
82 | PyErr_SetString(PyExc_RuntimeError, buffer); |
83 | } |
84 | |
85 | void THPUtils_addPyMethodDefs( |
86 | std::vector<PyMethodDef>& vector, |
87 | PyMethodDef* methods) { |
88 | if (!vector.empty()) { |
89 | // remove nullptr terminator |
90 | vector.pop_back(); |
91 | } |
92 | while (true) { |
93 | vector.push_back(*methods); |
94 | if (!methods->ml_name) { |
95 | break; |
96 | } |
97 | methods++; |
98 | } |
99 | } |
100 | |
101 | static const char* classOrTypename(PyObject* obj) { |
102 | if (PyType_Check(obj)) { |
103 | return ((PyTypeObject*)obj)->tp_name; |
104 | } |
105 | return Py_TYPE(obj)->tp_name; |
106 | } |
107 | |
108 | PyObject* THPUtils_dispatchStateless( |
109 | PyObject* tensor, |
110 | const char* name, |
111 | PyObject* args, |
112 | PyObject* kwargs) { |
113 | THPObjectPtr methods( |
114 | PyObject_GetAttrString(tensor, THP_STATELESS_ATTRIBUTE_NAME)); |
115 | if (!methods) { |
116 | return PyErr_Format( |
117 | PyExc_TypeError, |
118 | "Type %s doesn't implement stateless methods" , |
119 | classOrTypename(tensor)); |
120 | } |
121 | THPObjectPtr method(PyObject_GetAttrString(methods, name)); |
122 | if (!method) { |
123 | return PyErr_Format( |
124 | PyExc_TypeError, |
125 | "Type %s doesn't implement stateless method %s" , |
126 | classOrTypename(tensor), |
127 | name); |
128 | } |
129 | return PyObject_Call(method.get(), args, kwargs); |
130 | } |
131 | |
132 | void THPUtils_invalidArguments( |
133 | PyObject* given_args, |
134 | PyObject* given_kwargs, |
135 | const char* function_name, |
136 | size_t num_options, |
137 | ...) { |
138 | std::vector<std::string> option_strings; |
139 | va_list option_list; |
140 | va_start(option_list, num_options); |
141 | std::generate_n( |
142 | std::back_inserter(option_strings), num_options, [&option_list] { |
143 | return va_arg(option_list, const char*); |
144 | }); |
145 | va_end(option_list); |
146 | |
147 | PyErr_SetString( |
148 | PyExc_TypeError, |
149 | torch::format_invalid_args( |
150 | given_args, given_kwargs, function_name, option_strings) |
151 | .c_str()); |
152 | } |
153 | |
154 | template <> |
155 | void THPPointer<THPGenerator>::free() { |
156 | if (ptr) |
157 | Py_DECREF(ptr); |
158 | } |
159 | |
160 | template class THPPointer<THPGenerator>; |
161 | |
// File-local flag behind the broadcast back-compat warning switch;
// initially off.
static bool backCompatBroadcastWarn{false};

// Stores `warn` into the broadcast back-compat warning flag.
void setBackCompatBroadcastWarn(bool warn) {
  backCompatBroadcastWarn = warn;
}

// Returns the current value of the broadcast back-compat warning flag.
bool getBackCompatBroadcastWarn() {
  return backCompatBroadcastWarn;
}
171 | |
// File-local flag behind the keepdim back-compat warning switch;
// initially off.
static bool backCompatKeepdimWarn{false};

// Stores `warn` into the keepdim back-compat warning flag.
void setBackCompatKeepdimWarn(bool warn) {
  backCompatKeepdimWarn = warn;
}

// Returns the current value of the keepdim back-compat warning flag.
bool getBackCompatKeepdimWarn() {
  return backCompatKeepdimWarn;
}
181 | |
182 | bool maybeThrowBackCompatKeepdimWarn(char* func) { |
183 | if (getBackCompatKeepdimWarn()) { |
184 | std::ostringstream ss; |
185 | ss << "backwards compatibility: call to \"" << func |
186 | << "\" uses default value for keepdim which has changed default to False. Consider passing as kwarg." , |
187 | PyErr_WarnEx(PyExc_UserWarning, ss.str().c_str(), 1); |
188 | } |
189 | return true; |
190 | } |
191 | |
192 | template <> |
193 | void THPPointer<THPStorage>::free() { |
194 | if (ptr) |
195 | Py_DECREF(ptr); |
196 | } |
197 | |
198 | void storage_copy(at::Storage dst, at::Storage src, bool non_blocking) { |
199 | auto dst_options = c10::TensorOptions().device(dst.device()).dtype(at::kByte); |
200 | auto dst_t = at::empty({0}, {}, dst_options).set_(dst); |
201 | |
202 | auto src_options = c10::TensorOptions().device(src.device()).dtype(at::kByte); |
203 | auto src_t = at::empty({0}, {}, src_options).set_(src); |
204 | dst_t.copy_(src_t, non_blocking); |
205 | } |
206 | |
207 | void storage_fill(at::Storage self, uint8_t value) { |
208 | auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte); |
209 | auto self_t = at::empty({0}, {}, options).set_(self); |
210 | self_t.fill_(value); |
211 | } |
212 | |
213 | void storage_set(at::Storage self, ptrdiff_t idx, uint8_t value) { |
214 | TORCH_CHECK( |
215 | (idx >= 0) && (idx < static_cast<ptrdiff_t>(self.nbytes())), |
216 | "out of bounds" ); |
217 | auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte); |
218 | auto self_t = at::empty({0}, {}, options).set_(self); |
219 | self_t[idx].fill_(value); |
220 | } |
221 | |
222 | uint8_t storage_get(at::Storage self, ptrdiff_t idx) { |
223 | TORCH_CHECK( |
224 | (idx >= 0) && (idx < static_cast<ptrdiff_t>(self.nbytes())), |
225 | "out of bounds" ); |
226 | auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte); |
227 | auto self_t = at::empty({0}, {}, options).set_(self); |
228 | return self_t[idx].item<uint8_t>(); |
229 | } |
230 | |
231 | template class THPPointer<THPStorage>; |
232 | |
233 | namespace torch { |
234 | namespace gdb { |
235 | /* ~~~ misc debugging utilities ~~~ |
236 | * |
237 | * torch::gdb::* functions are NOT meant to be called by general pytorch code, |
238 | * but only from within a gdb session. As such, utils.h does not contain any |
239 | * declaration for those. |
240 | */ |
241 | |
// This is a helper needed by the torch-tensor-repr gdb command.
// Return a human-readable representation of the given Tensor. The resulting
// string is stored in a malloc()ed buffer. The caller is responsible for
// free()ing it. We use malloc() instead of new[] because it's much easier to
// call free than delete[] from within gdb.
// Currently the code for computing the repr of a tensor is written in Python,
// so we need to wrap the Tensor into a Python object first.
249 | char* tensor_repr(at::Tensor tensor) { |
250 | PyGILState_STATE gil = PyGILState_Ensure(); |
251 | PyObject* pytensor = nullptr; |
252 | PyObject* repr = nullptr; |
253 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
254 | Py_ssize_t bufsize; |
255 | const char* buf = nullptr; |
256 | char* result = nullptr; |
257 | |
258 | pytensor = THPVariable_Wrap(at::Tensor(tensor)); |
259 | if (!pytensor) |
260 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
261 | goto error; |
262 | repr = PyObject_Repr(pytensor); |
263 | if (!repr) |
264 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
265 | goto error; |
266 | buf = PyUnicode_AsUTF8AndSize(repr, &bufsize); |
267 | if (!buf) |
268 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
269 | goto error; |
270 | // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) |
271 | result = |
272 | static_cast<char*>(malloc(bufsize + 1)); // account for the trailing \0 |
273 | if (!result) { |
274 | fprintf(stderr, "cannot allocate memory for the result\n" ); |
275 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
276 | goto error; |
277 | } |
278 | // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.strcpy) |
279 | strcpy(result, buf); |
280 | Py_XDECREF(pytensor); |
281 | Py_XDECREF(repr); |
282 | PyGILState_Release(gil); |
283 | return result; |
284 | |
285 | error: |
286 | fprintf(stderr, "torch::gdb::tensor_repr: unexpected error\n" ); |
287 | if (PyErr_Occurred()) |
288 | PyErr_Print(); |
289 | Py_XDECREF(pytensor); |
290 | Py_XDECREF(repr); |
291 | // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) |
292 | free(result); |
293 | PyGILState_Release(gil); |
294 | return nullptr; |
295 | } |
296 | |
297 | } // namespace gdb |
298 | } // namespace torch |
299 | |
300 | namespace pybind11 { |
301 | namespace detail { |
302 | |
303 | bool type_caster<at::Tensor>::load(handle src, bool) { |
304 | PyObject* obj = src.ptr(); |
305 | if (THPVariable_Check(obj)) { |
306 | value = THPVariable_Unpack(obj); |
307 | return true; |
308 | } |
309 | return false; |
310 | } |
311 | |
312 | handle type_caster<at::Tensor>::cast( |
313 | const at::Tensor& src, |
314 | return_value_policy /* policy */, |
315 | handle /* parent */) { |
316 | return handle(THPVariable_Wrap(src)); |
317 | } |
318 | |
319 | bool type_caster<at::IntArrayRef>::load(handle src, bool) { |
320 | PyObject* source = src.ptr(); |
321 | auto tuple = PyTuple_Check(source); |
322 | if (tuple || PyList_Check(source)) { |
323 | // NOLINTNEXTLINE(bugprone-branch-clone) |
324 | const auto size = |
325 | tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); |
326 | v_value.resize(size); |
327 | for (const auto idx : c10::irange(size)) { |
328 | PyObject* obj = |
329 | tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx); |
330 | if (THPVariable_Check(obj)) { |
331 | v_value[idx] = THPVariable_Unpack(obj).item<int64_t>(); |
332 | } else if (PyLong_Check(obj)) { |
333 | // use THPUtils_unpackLong after it is safe to include |
334 | // python_numbers.h |
335 | v_value[idx] = THPUtils_unpackLong(obj); |
336 | } else { |
337 | return false; |
338 | } |
339 | } |
340 | value = v_value; |
341 | return true; |
342 | } |
343 | return false; |
344 | } |
345 | handle type_caster<at::IntArrayRef>::cast( |
346 | at::IntArrayRef src, |
347 | return_value_policy /* policy */, |
348 | handle /* parent */) { |
349 | return handle(THPUtils_packInt64Array(src.size(), src.data())); |
350 | } |
351 | |
352 | bool type_caster<at::SymIntArrayRef>::load(handle src, bool) { |
353 | PyObject* source = src.ptr(); |
354 | |
355 | auto tuple = PyTuple_Check(source); |
356 | if (tuple || PyList_Check(source)) { |
357 | // NOLINTNEXTLINE(bugprone-branch-clone) |
358 | const auto size = |
359 | tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); |
360 | v_value.resize(size); |
361 | for (const auto idx : c10::irange(size)) { |
362 | PyObject* obj = |
363 | tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx); |
364 | |
365 | if (THPVariable_Check(obj)) { |
366 | // TODO: this is for consistency with IntArrayRef but arguably |
367 | // we shouldn't really allow this on pybind11 casters |
368 | v_value[idx] = THPVariable_Unpack(obj).item<int64_t>(); |
369 | } else if (torch::is_symint(py::handle(obj))) { |
370 | v_value[idx] = py::handle(obj).cast<c10::SymInt>(); |
371 | } else if (PyLong_Check(obj)) { |
372 | v_value[idx] = c10::SymInt(THPUtils_unpackIndex(obj)); |
373 | } else { |
374 | return false; |
375 | } |
376 | } |
377 | value = v_value; |
378 | return true; |
379 | } |
380 | return false; |
381 | } |
382 | handle type_caster<at::SymIntArrayRef>::cast( |
383 | at::SymIntArrayRef src, |
384 | return_value_policy /* policy */, |
385 | handle /* parent */) { |
386 | py::list t(src.size()); |
387 | for (const auto i : c10::irange(src.size())) { |
388 | t[i] = py::cast(src[i]); |
389 | } |
390 | return t.release(); |
391 | } |
392 | |
393 | bool type_caster<at::ArrayRef<c10::SymNode>>::load(handle src, bool) { |
394 | TORCH_INTERNAL_ASSERT(0, "NYI" ); |
395 | } |
396 | handle type_caster<at::ArrayRef<c10::SymNode>>::cast( |
397 | at::ArrayRef<c10::SymNode> src, |
398 | return_value_policy /* policy */, |
399 | handle /* parent */) { |
400 | py::list t(src.size()); |
401 | for (const auto i : c10::irange(src.size())) { |
402 | // TODO: this is terrible but I don't know how to override when |
403 | // the SymNode is also explicitly cast by py::cast |
404 | auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(src[i].get()); |
405 | if (py_node) { |
406 | // Return the Python directly (unwrap) |
407 | t[i] = py_node->getPyObj(); |
408 | } else { |
409 | t[i] = py::cast(src[i]); |
410 | } |
411 | } |
412 | return t.release(); |
413 | } |
414 | |
415 | } // namespace detail |
416 | } // namespace pybind11 |
417 | |