#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/python_tuples.h>

#include <torch/csrc/Export.h>

#include <algorithm>
#include <cstdarg>
#include <iterator>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

int THPUtils_getCallable(PyObject* arg, PyObject** result) {
  if (!PyCallable_Check(arg))
    return 0;
  *result = arg;
  return 1;
}

std::vector<int64_t> THPUtils_unpackLongs(PyObject* arg) {
  bool tuple = PyTuple_Check(arg);
  bool list = PyList_Check(arg);
  if (tuple || list) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    const auto nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
    std::vector<int64_t> sizes(nDim);
    for (int i = 0; i != nDim; ++i) {
      PyObject* item =
          tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i);
      if (!THPUtils_checkLong(item)) {
        std::ostringstream oss;
        oss << "expected int at position " << i
            << ", but got: " << THPUtils_typename(item);
        throw std::runtime_error(oss.str());
      }
      sizes[i] = THPUtils_unpackLong(item);
    }
    return sizes;
  }
  throw std::runtime_error("Expected tuple or list");
}

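// Usage sketch (hypothetical caller, not part of this file): a Python tuple
// such as (2, 3, 4) unpacks to {2, 3, 4}; any other input raises
// std::runtime_error rather than setting a Python error.
//
//   PyObject* shape = Py_BuildValue("(iii)", 2, 3, 4);
//   std::vector<int64_t> sizes = THPUtils_unpackLongs(shape); // {2, 3, 4}
//   Py_DECREF(shape);
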
bool THPUtils_checkIntTuple(PyObject* arg) {
  if (!PyTuple_Check(arg)) {
    return false;
  }
  for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
    if (!THPUtils_checkLong(PyTuple_GET_ITEM(arg, i))) {
      return false;
    }
  }
  return true;
}

std::vector<int> THPUtils_unpackIntTuple(PyObject* arg) {
  if (!THPUtils_checkIntTuple(arg)) {
    throw std::runtime_error("Couldn't unpack int tuple");
  }
  std::vector<int> values(PyTuple_GET_SIZE(arg));
  for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
    values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i));
  }
  return values;
}

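// Usage sketch (hypothetical): the check/unpack pair is meant to be used
// together, since unpacking an unchecked object throws std::runtime_error:
//
//   if (THPUtils_checkIntTuple(obj)) {
//     std::vector<int> values = THPUtils_unpackIntTuple(obj);
//     // ... use values ...
//   }
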
void THPUtils_setError(const char* format, ...) {
  static const size_t ERROR_BUFFER_SIZE = 1000;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  char buffer[ERROR_BUFFER_SIZE];
  va_list fmt_args;

  va_start(fmt_args, format);
  vsnprintf(buffer, ERROR_BUFFER_SIZE, format, fmt_args);
  va_end(fmt_args);
  PyErr_SetString(PyExc_RuntimeError, buffer);
}

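// Usage sketch (hypothetical): THPUtils_setError takes printf-style
// arguments and sets a Python RuntimeError; callers still have to return
// their own failure sentinel to the CPython API:
//
//   THPUtils_setError("expected %d arguments, got %d", 2, (int)nargs);
//   return nullptr;
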
void THPUtils_addPyMethodDefs(
    std::vector<PyMethodDef>& vector,
    PyMethodDef* methods) {
  if (!vector.empty()) {
    // remove the nullptr terminator; the loop below appends a new one
    vector.pop_back();
  }
  while (true) {
    vector.push_back(*methods);
    if (!methods->ml_name) {
      break;
    }
    methods++;
  }
}

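// Usage sketch (hypothetical method tables): concatenates several
// nullptr-terminated PyMethodDef arrays, keeping exactly one terminator at
// the end of the vector, e.g. when assembling a module's method list:
//
//   std::vector<PyMethodDef> methods;
//   THPUtils_addPyMethodDefs(methods, TorchMethods);    // hypothetical table
//   THPUtils_addPyMethodDefs(methods, StorageMethods);  // hypothetical table
//   // methods.data() is now a valid, singly-terminated PyMethodDef array.
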
static const char* classOrTypename(PyObject* obj) {
  if (PyType_Check(obj)) {
    return ((PyTypeObject*)obj)->tp_name;
  }
  return Py_TYPE(obj)->tp_name;
}

PyObject* THPUtils_dispatchStateless(
    PyObject* tensor,
    const char* name,
    PyObject* args,
    PyObject* kwargs) {
  THPObjectPtr methods(
      PyObject_GetAttrString(tensor, THP_STATELESS_ATTRIBUTE_NAME));
  if (!methods) {
    return PyErr_Format(
        PyExc_TypeError,
        "Type %s doesn't implement stateless methods",
        classOrTypename(tensor));
  }
  THPObjectPtr method(PyObject_GetAttrString(methods, name));
  if (!method) {
    return PyErr_Format(
        PyExc_TypeError,
        "Type %s doesn't implement stateless method %s",
        classOrTypename(tensor),
        name);
  }
  return PyObject_Call(method.get(), args, kwargs);
}

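// Usage sketch (hypothetical): looks up the stateless-methods attribute and
// forwards the call, so a binding can do:
//
//   PyObject* res = THPUtils_dispatchStateless(self, "add", args, kwargs);
//   if (!res)
//     return nullptr; // TypeError has already been set above
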
void THPUtils_invalidArguments(
    PyObject* given_args,
    PyObject* given_kwargs,
    const char* function_name,
    size_t num_options,
    ...) {
  std::vector<std::string> option_strings;
  va_list option_list;
  va_start(option_list, num_options);
  std::generate_n(
      std::back_inserter(option_strings), num_options, [&option_list] {
        return va_arg(option_list, const char*);
      });
  va_end(option_list);

  PyErr_SetString(
      PyExc_TypeError,
      torch::format_invalid_args(
          given_args, given_kwargs, function_name, option_strings)
          .c_str());
}

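// Usage sketch (hypothetical signatures): num_options must match the number
// of trailing C-string varargs, which describe each accepted overload in
// the resulting TypeError message:
//
//   THPUtils_invalidArguments(
//       args, kwargs, "addmm", 2,
//       "(Tensor mat1, Tensor mat2)",
//       "(Number beta, Tensor mat1, Tensor mat2)");
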
template <>
void THPPointer<THPGenerator>::free() {
  if (ptr)
    Py_DECREF(ptr);
}

template class THPPointer<THPGenerator>;

static bool backCompatBroadcastWarn = false;

void setBackCompatBroadcastWarn(bool warn) {
  backCompatBroadcastWarn = warn;
}

bool getBackCompatBroadcastWarn() {
  return backCompatBroadcastWarn;
}

static bool backCompatKeepdimWarn = false;

void setBackCompatKeepdimWarn(bool warn) {
  backCompatKeepdimWarn = warn;
}

bool getBackCompatKeepdimWarn() {
  return backCompatKeepdimWarn;
}

bool maybeThrowBackCompatKeepdimWarn(char* func) {
  if (getBackCompatKeepdimWarn()) {
    std::ostringstream ss;
    ss << "backwards compatibility: call to \"" << func
       << "\" uses the default value for keepdim, whose default has changed to False. Consider passing it as a kwarg.";
    PyErr_WarnEx(PyExc_UserWarning, ss.str().c_str(), 1);
  }
  return true;
}

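// Usage sketch (hypothetical): a reduction binding would call this before
// applying the keepdim default; the warning only fires if the back-compat
// flag was enabled via setBackCompatKeepdimWarn(true):
//
//   maybeThrowBackCompatKeepdimWarn(const_cast<char*>("sum"));
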
template <>
void THPPointer<THPStorage>::free() {
  if (ptr)
    Py_DECREF(ptr);
}

void storage_copy(at::Storage dst, at::Storage src, bool non_blocking) {
  auto dst_options = c10::TensorOptions().device(dst.device()).dtype(at::kByte);
  auto dst_t = at::empty({0}, {}, dst_options).set_(dst);

  auto src_options = c10::TensorOptions().device(src.device()).dtype(at::kByte);
  auto src_t = at::empty({0}, {}, src_options).set_(src);
  dst_t.copy_(src_t, non_blocking);
}

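// Usage sketch (hypothetical): wrapping each storage in a zero-element byte
// tensor via set_() routes the copy through Tensor::copy_, so cross-device
// copies and non_blocking semantics come for free:
//
//   storage_copy(cuda_storage, cpu_storage, /*non_blocking=*/true);
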
void storage_fill(at::Storage self, uint8_t value) {
  auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
  auto self_t = at::empty({0}, {}, options).set_(self);
  self_t.fill_(value);
}

void storage_set(at::Storage self, ptrdiff_t idx, uint8_t value) {
  TORCH_CHECK(
      (idx >= 0) && (idx < static_cast<ptrdiff_t>(self.nbytes())),
      "out of bounds");
  auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
  auto self_t = at::empty({0}, {}, options).set_(self);
  self_t[idx].fill_(value);
}

uint8_t storage_get(at::Storage self, ptrdiff_t idx) {
  TORCH_CHECK(
      (idx >= 0) && (idx < static_cast<ptrdiff_t>(self.nbytes())),
      "out of bounds");
  auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
  auto self_t = at::empty({0}, {}, options).set_(self);
  return self_t[idx].item<uint8_t>();
}

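// Usage sketch (hypothetical): element access round-trips through the same
// byte-tensor view, so it works uniformly across devices at the cost of a
// small tensor op per call:
//
//   storage_set(self, 0, 0xff);
//   uint8_t first = storage_get(self, 0); // == 0xff
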
template class THPPointer<THPStorage>;

namespace torch {
namespace gdb {
/* ~~~ misc debugging utilities ~~~
 *
 * torch::gdb::* functions are NOT meant to be called by general pytorch code,
 * but only from within a gdb session. As such, utils.h does not contain any
 * declaration for those.
 */

// This is a helper needed by the torch-tensor-repr gdb command.
// Returns a human-readable representation of the given Tensor. The resulting
// string is stored in a malloc()ed buffer, which the caller is responsible
// for free()ing. We use malloc() instead of new[] because it's much easier
// to call free than delete[] from within gdb.
// Currently the code for computing the repr of a tensor is written in Python,
// so we need to wrap the Tensor in a Python object first.
char* tensor_repr(at::Tensor tensor) {
  PyGILState_STATE gil = PyGILState_Ensure();
  PyObject* pytensor = nullptr;
  PyObject* repr = nullptr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  Py_ssize_t bufsize;
  const char* buf = nullptr;
  char* result = nullptr;

  pytensor = THPVariable_Wrap(at::Tensor(tensor));
  if (!pytensor)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    goto error;
  repr = PyObject_Repr(pytensor);
  if (!repr)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    goto error;
  buf = PyUnicode_AsUTF8AndSize(repr, &bufsize);
  if (!buf)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    goto error;
  // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
  result =
      static_cast<char*>(malloc(bufsize + 1)); // account for the trailing \0
  if (!result) {
    fprintf(stderr, "cannot allocate memory for the result\n");
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    goto error;
  }
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.strcpy)
  strcpy(result, buf);
  Py_XDECREF(pytensor);
  Py_XDECREF(repr);
  PyGILState_Release(gil);
  return result;

error:
  fprintf(stderr, "torch::gdb::tensor_repr: unexpected error\n");
  if (PyErr_Occurred())
    PyErr_Print();
  Py_XDECREF(pytensor);
  Py_XDECREF(repr);
  // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
  free(result);
  PyGILState_Release(gil);
  return nullptr;
}

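// Usage sketch, from a gdb prompt rather than from C++ (which is the only
// intended call site):
//
//   (gdb) call torch::gdb::tensor_repr(some_tensor)
//
// and free() the returned buffer once done inspecting it.
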
} // namespace gdb
} // namespace torch

namespace pybind11 {
namespace detail {

bool type_caster<at::Tensor>::load(handle src, bool) {
  PyObject* obj = src.ptr();
  if (THPVariable_Check(obj)) {
    value = THPVariable_Unpack(obj);
    return true;
  }
  return false;
}

handle type_caster<at::Tensor>::cast(
    const at::Tensor& src,
    return_value_policy /* policy */,
    handle /* parent */) {
  return handle(THPVariable_Wrap(src));
}

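// Usage sketch (hypothetical binding): with this caster visible, pybind11
// functions can take and return at::Tensor directly; load() only accepts
// genuine THPVariable objects, with no implicit conversions:
//
//   m.def("scale", [](const at::Tensor& t, double a) { return t * a; });
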
bool type_caster<at::IntArrayRef>::load(handle src, bool) {
  PyObject* source = src.ptr();
  auto tuple = PyTuple_Check(source);
  if (tuple || PyList_Check(source)) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    const auto size =
        tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
    v_value.resize(size);
    for (const auto idx : c10::irange(size)) {
      PyObject* obj =
          tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
      if (THPVariable_Check(obj)) {
        v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
      } else if (PyLong_Check(obj)) {
        v_value[idx] = THPUtils_unpackLong(obj);
      } else {
        return false;
      }
    }
    value = v_value;
    return true;
  }
  return false;
}

handle type_caster<at::IntArrayRef>::cast(
    at::IntArrayRef src,
    return_value_policy /* policy */,
    handle /* parent */) {
  return handle(THPUtils_packInt64Array(src.size(), src.data()));
}

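// Usage sketch (hypothetical binding): an at::IntArrayRef parameter accepts
// a Python tuple or list of ints (or 0-dim integer Tensors). The caster
// copies into the member vector v_value because an IntArrayRef does not own
// its data:
//
//   m.def("ndim_of", [](at::IntArrayRef sizes) { return sizes.size(); });
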
bool type_caster<at::SymIntArrayRef>::load(handle src, bool) {
  PyObject* source = src.ptr();

  auto tuple = PyTuple_Check(source);
  if (tuple || PyList_Check(source)) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    const auto size =
        tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
    v_value.resize(size);
    for (const auto idx : c10::irange(size)) {
      PyObject* obj =
          tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);

      if (THPVariable_Check(obj)) {
        // TODO: this is for consistency with IntArrayRef but arguably
        // we shouldn't really allow this on pybind11 casters
        v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
      } else if (torch::is_symint(py::handle(obj))) {
        v_value[idx] = py::handle(obj).cast<c10::SymInt>();
      } else if (PyLong_Check(obj)) {
        v_value[idx] = c10::SymInt(THPUtils_unpackIndex(obj));
      } else {
        return false;
      }
    }
    value = v_value;
    return true;
  }
  return false;
}

handle type_caster<at::SymIntArrayRef>::cast(
    at::SymIntArrayRef src,
    return_value_policy /* policy */,
    handle /* parent */) {
  py::list t(src.size());
  for (const auto i : c10::irange(src.size())) {
    t[i] = py::cast(src[i]);
  }
  return t.release();
}

bool type_caster<at::ArrayRef<c10::SymNode>>::load(handle src, bool) {
  TORCH_INTERNAL_ASSERT(0, "NYI");
}

handle type_caster<at::ArrayRef<c10::SymNode>>::cast(
    at::ArrayRef<c10::SymNode> src,
    return_value_policy /* policy */,
    handle /* parent */) {
  py::list t(src.size());
  for (const auto i : c10::irange(src.size())) {
    // TODO: this is terrible but I don't know how to override when
    // the SymNode is also explicitly cast by py::cast
    auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(src[i].get());
    if (py_node) {
      // Return the underlying Python object directly (unwrap)
      t[i] = py_node->getPyObj();
    } else {
      t[i] = py::cast(src[i]);
    }
  }
  return t.release();
}

} // namespace detail
} // namespace pybind11