1#pragma once
2
3// Parse arguments to Python functions implemented in C++
4// This is similar to PyArg_ParseTupleAndKeywords(), but specifically handles
5// the types relevant to PyTorch and distinguishes between overloaded function
6// signatures.
7//
8// Example:
9//
10// static PythonArgParser parser({
11// "norm(Scalar p, int64_t dim, bool keepdim=False)",
12// "norm(Scalar p=2)",
13// });
14// ParsedArgs<3> parsed_args;
15// auto r = parser.parse(args, kwargs, parsed_args);
16// if (r.idx == 0) {
//   norm(r.scalar(0), r.toInt64(1), r.toBool(2));
18// } else {
19// norm(r.scalar(0));
20// }
21//
22// We auto-generate most uses of PythonArgParser; the generated files
23// are torch/csrc/autograd/generated/python_*.cpp
24//
25// Some gotchas that you should watch out for:
26//
27// - Note [Order of overloads matters]
28// Order of overloads matters. A set of input arguments may
29// bind to multiple argument specs; we will always pick the
30// first one in PythonArgParser. However, when you are writing
31// overloads in, e.g., native_functions.yaml, you don't have to
32// worry about what order you write them, because the code
33// generation logic always gives the overloads a canonical
34// order, where Tensor overloads come first, before Scalar overloads.
35// This logic is in sort_declarations in
36// tools/autograd/gen_python_functions.py
37//
38// - Zero-dim tensors (e.g., torch.tensor(2)) bind to both
39// Scalar and Tensor, UNLESS they require grad (in which case
40// they only bind to Tensor).
41
42#include <pybind11/pytypes.h>
43#include <torch/csrc/python_headers.h>
44
45#include <torch/csrc/Device.h>
46#include <torch/csrc/Dtype.h>
47#include <torch/csrc/DynamicTypes.h>
48#include <torch/csrc/Exceptions.h>
49#include <torch/csrc/Generator.h>
50#include <torch/csrc/Layout.h>
51#include <torch/csrc/MemoryFormat.h>
52#include <torch/csrc/QScheme.h>
53#include <torch/csrc/Stream.h>
54#include <torch/csrc/autograd/python_variable.h>
55#include <torch/csrc/autograd/variable.h>
56#include <torch/csrc/jit/frontend/tracer.h>
57#include <torch/csrc/python_dimname.h>
58#include <torch/csrc/tensor/python_tensor.h>
59#include <torch/csrc/utils/disable_torch_function.h>
60#include <torch/csrc/utils/object_ptr.h>
61#include <torch/csrc/utils/pybind.h>
62#include <torch/csrc/utils/python_numbers.h>
63#include <torch/csrc/utils/python_strings.h>
64#include <torch/csrc/utils/python_symnode.h>
65#include <torch/csrc/utils/six.h>
66
67#include <ATen/PythonTorchFunctionTLS.h>
68#include <ATen/core/Tensor.h>
69#include <c10/util/Exception.h>
70#include <c10/util/irange.h>
71
72#include <c10/core/SymFloat.h>
73#include <c10/core/SymNodeImpl.h>
74
75#include <array>
76#include <cstddef>
77#include <memory>
78#include <sstream>
79#include <string>
80#include <vector>
81
82inline bool THPUtils_checkScalar(PyObject* obj) {
83#ifdef USE_NUMPY
84 if (torch::utils::is_numpy_scalar(obj)) {
85 return true;
86 }
87#endif
88 return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
89 torch::is_symint(py::handle(obj)) || torch::is_symfloat(py::handle(obj));
90}
91
92namespace torch {
93
// Defined out of line. Judging by its name (and by the
// FunctionParameter::allow_numbers_as_tensors flag), it presumably reports
// whether the function called `name` accepts plain Python numbers for Tensor
// parameters — confirm at the definition.
bool should_allow_numbers_as_tensors(const std::string& name);
95
// Kind of a single parameter in a signature string (stored in
// FunctionParameter::type_). Each kind is read back through the matching
// PythonArgs accessor, e.g. INT_LIST -> intlist(), SYM_INT_LIST ->
// symintlist(), DIMNAME_LIST -> dimnamelist().
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  SYM_INT,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STREAM,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST,
  SCALAR_LIST,
  SYM_INT_LIST
};
122
// Forward declarations: the parser, signature, and argument types below
// refer to one another before their full definitions appear.
struct FunctionParameter;
struct FunctionSignature;
struct PythonArgs;
126
127// Contains bound Python arguments in declaration order
128template <int N>
129struct ParsedArgs {
130 ParsedArgs() : args() {}
131 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
132 PyObject* args[N];
133};
134
// Matches Python (args, kwargs) against a list of overloaded signature
// strings (see the example at the top of this file). When several overloads
// match, the first one wins (see Note [Order of overloads matters]).
struct PythonArgParser {
  // `fmts` holds one signature string per overload. `traceable` is handed to
  // every PythonArgs produced, where it gates JIT-tracer bookkeeping (see
  // PythonArgs::toInt64 / intlistWithDefault).
  explicit PythonArgParser(
      std::vector<std::string> fmts,
      bool traceable = false);

  // meant only for `torch` functions.
  // Throws ValueError if `dst` has fewer than max_args slots.
  template <int N>
  inline PythonArgs parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      ParsedArgs<N>& dst);

  // Same as above, with no bound `self`.
  template <int N>
  inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst);

  // For calls that pass neither positional nor keyword arguments.
  inline PythonArgs parse(PyObject* self, ParsedArgs<0>& dst);

  // Formatted strings of non-hidden signatures
  std::vector<std::string> get_signatures() const;

 private:
  // Reports a failed match to the caller; never returns.
  [[noreturn]]
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  void
  print_error(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* parsed_args[]);
  void check_deprecated(const FunctionSignature& signature);
  // Does the actual matching; `parsed_args` must hold at least max_args slots
  // (enforced by the parse() templates).
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  PythonArgs raw_parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* parsed_args[]);

  std::vector<FunctionSignature> signatures_;
  std::string function_name;
  // Minimum capacity required of the ParsedArgs buffer passed to parse();
  // presumably the largest parameter count among signatures_.
  size_t max_args;
  bool traceable;
};
178
// One overload parsed from a signature string such as "norm(Scalar p=2)".
struct PYBIND11_EXPORT FunctionSignature {
  // `index` is the overload's position; it surfaces as PythonArgs::idx.
  explicit FunctionSignature(const std::string& fmt, int index);

  // Attempts to bind (self, args, kwargs) to this signature, writing bound
  // objects into `dst`, and returns whether it succeeded. Presumably raises
  // instead of returning false when `raise_exception` is set — confirm at the
  // definition.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  bool parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* dst[],
      bool raise_exception);

  std::string toString() const;

  std::string name;
  std::vector<FunctionParameter> params;
  // Arguments that override __torch_function__; a non-empty list makes
  // PythonArgs::has_torch_function() return true.
  std::vector<py::handle> overloaded_args;
  size_t min_args;
  size_t max_args;
  size_t max_pos_args;
  int index;
  // Hidden signatures are excluded from PythonArgParser::get_signatures().
  bool hidden;
  bool deprecated;
  bool disable_torch_function;
};
203
// Result of a successful PythonArgParser::parse.
//
// args[k] is the object bound to parameter k of `signature`, or nullptr when
// the caller omitted it; the accessors then fall back to the parameter's
// declared default (or an empty value / nullopt). `args` presumably points
// into the ParsedArgs buffer passed to parse(), and `signature` is held by
// reference, so both must outlive this object.
struct PythonArgs {
  PythonArgs(
      bool traceable,
      const FunctionSignature& signature,
      PyObject** args)
      : idx(signature.index),
        traceable(traceable),
        signature(signature),
        args(args) {}

  // Index of the matched overload (FunctionSignature::index).
  int idx;
  // When true, some accessors record argument values for the JIT tracer.
  bool traceable;
  const FunctionSignature& signature;
  PyObject** args;

  // All accessors take `i`, the parameter position within `signature`.
  inline bool has_torch_function();
  inline std::string get_func_name();
  inline at::Tensor tensor(int i);
  inline c10::optional<at::Tensor> optionalTensor(int i);
  inline at::Scalar scalar(int i);
  inline at::Scalar scalarWithDefault(int i, const at::Scalar& default_scalar);
  inline std::vector<at::Scalar> scalarlist(int i);
  inline std::vector<at::Tensor> tensorlist(int i);
  inline torch::List<c10::optional<at::Tensor>> list_of_optional_tensors(int i);
  template <int N>
  inline std::array<at::Tensor, N> tensorlist_n(int i);
  inline std::vector<int64_t> intlist(int i);
  inline std::vector<c10::SymInt> symintlist(int i);
  inline c10::OptionalArray<int64_t> intlistOptional(int i);
  inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
  inline std::vector<int64_t> intlistWithDefault(
      int i,
      std::vector<int64_t> default_intlist);
  inline c10::optional<at::Generator> generator(int i);
  inline at::Storage storage(int i);
  // Also reports the storage's scalar type and whether it was a typed storage.
  inline at::Storage storage(
      int i,
      at::ScalarType& storage_scalar_type,
      bool& is_typed_storage);
  inline c10::Stream stream(int i);
  inline at::ScalarType scalartype(int i);
  inline at::ScalarType scalartypeWithDefault(
      int i,
      at::ScalarType default_scalartype);
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
  inline c10::optional<at::Scalar> scalarOptional(int i);
  inline c10::optional<int64_t> toInt64Optional(int i);
  inline c10::optional<c10::SymInt> toSymIntOptional(int i);
  inline c10::optional<bool> toBoolOptional(int i);
  inline c10::optional<double> toDoubleOptional(int i);
  inline c10::OptionalArray<double> doublelistOptional(int i);
  inline std::vector<double> doublelist(int i);
  // Requires args[i] != nullptr; use doublelist/doublelistOptional otherwise.
  inline std::vector<double> getDoublelist(int i);
  inline at::Layout layout(int i);
  inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
  inline c10::optional<at::Layout> layoutOptional(int i);
  inline at::Device device(int i);
  inline at::Device deviceWithDefault(int i, const at::Device& default_device);
  inline c10::optional<at::Device> deviceOptional(int i);
  inline at::Dimname dimname(int i);
  inline std::vector<at::Dimname> dimnamelist(int i);
  inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
  inline at::MemoryFormat memoryformat(int i);
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
  inline at::QScheme toQScheme(int i);
  inline std::string string(int i);
  inline std::string stringWithDefault(int i, const std::string& default_str);
  inline c10::optional<std::string> stringOptional(int i);
  inline c10::string_view stringView(int i);
  inline c10::string_view stringViewWithDefault(
      int i,
      const c10::string_view default_str);
  inline c10::optional<c10::string_view> stringViewOptional(int i);
  inline PyObject* pyobject(int i);
  inline int64_t toInt64(int i);
  inline c10::SymInt toSymInt(int i);
  inline int64_t toInt64WithDefault(int i, int64_t default_int);
  inline double toDouble(int i);
  inline double toDoubleWithDefault(int i, double default_double);
  inline c10::complex<double> toComplex(int i);
  inline c10::complex<double> toComplexWithDefault(
      int i,
      c10::complex<double> default_complex);
  inline bool toBool(int i);
  inline bool toBoolWithDefault(int i, bool default_bool);
  // True when parameter i was omitted (args[i] == nullptr).
  inline bool isNone(int i);

 private:
  // Out-of-line fallbacks for the inline fast paths above.
  at::Tensor tensor_slow(int i);
  at::Scalar scalar_slow(int i);
  at::Scalar scalar_slow(PyObject* arg);
};
296
// One parameter of a FunctionSignature, parsed from a fragment such as
// "int64_t dim" or "Scalar p=2".
struct FunctionParameter {
  FunctionParameter(const std::string& fmt, bool keyword_only);

  // Presumably tests whether `obj` is acceptable for this parameter,
  // appending any __torch_function__ overrides it finds to `overloaded_args`
  // (and, for lists, reporting the offending element via `failed_idx`) —
  // confirm at the definition.
  bool check(
      PyObject* obj,
      std::vector<py::handle>& overloaded_args,
      int argnum,
      int64_t* failed_idx = nullptr);

  void set_default_str(const std::string& str);
  std::string type_name() const;

  ParameterType type_;
  bool optional;
  bool allow_none;
  bool keyword_only;
  bool allow_numbers_as_tensors = false;
  // Fixed list length (e.g. for "int[2]"); when > 0, intlist/symintlist
  // broadcast a single int to `size` copies. dimnamelist requires 0 or 1.
  int size;
  std::string name;
  // having this as a raw PyObject * will presumably leak it, but these are only
  // held by static objects anyway, and Py_Finalize can already be called when
  // this is destructed.
  PyObject* python_name;
  // Presumably alternative keyword names accepted for NumPy compatibility.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  at::SmallVector<PyObject*, 5> numpy_python_names;
  at::Scalar default_scalar;
  std::vector<int64_t> default_intlist;
  std::string default_string;
  // Default value for scalar-like parameters; only the member matching
  // `type_` is meaningful.
  union {
    bool default_bool;
    int64_t default_int;
    double default_double;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    double default_complex[2]; // see Scalar
    at::ScalarType default_scalartype;
    at::Layout default_layout;
  };
};
335
336template <int N>
337inline PythonArgs PythonArgParser::parse(
338 PyObject* self,
339 PyObject* args,
340 PyObject* kwargs,
341 ParsedArgs<N>& dst) {
342 if (N < max_args) {
343 throw ValueError(
344 "PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)",
345 (int)max_args,
346 N);
347 }
348 return raw_parse(self, args, kwargs, dst.args);
349}
350
351template <int N>
352inline PythonArgs PythonArgParser::parse(
353 PyObject* args,
354 PyObject* kwargs,
355 ParsedArgs<N>& dst) {
356 return parse(nullptr, args, kwargs, dst);
357}
358
359inline PythonArgs PythonArgParser::parse(PyObject* self, ParsedArgs<0>& dst) {
360 return parse(self, nullptr, nullptr, dst);
361}
362
363inline bool PythonArgs::has_torch_function() {
364 return !this->signature.overloaded_args.empty() ||
365 at::impl::torch_function_mode_enabled();
366}
367
368inline std::string PythonArgs::get_func_name() {
369 return signature.name;
370}
371
372// TODO: this can return MaybeOwned
373inline at::Tensor PythonArgs::tensor(int i) {
374 if (args[i] && THPVariable_CheckExact(args[i])) {
375 return THPVariable_Unpack(args[i]);
376 }
377 return tensor_slow(i);
378}
379
380inline c10::optional<at::Tensor> PythonArgs::optionalTensor(int i) {
381 at::Tensor t = tensor(i);
382 // NOLINTNEXTLINE(bugprone-branch-clone)
383 if (t.defined()) {
384 return t;
385 } else {
386 return c10::nullopt;
387 }
388}
389
390inline at::Scalar PythonArgs::scalar(int i) {
391 if (!args[i])
392 return signature.params[i].default_scalar;
393 return scalar_slow(i);
394}
395
396inline std::vector<at::Scalar> PythonArgs::scalarlist(int i) {
397 if (!args[i])
398 return std::vector<at::Scalar>();
399 auto tuple = six::isTuple(args[i]);
400 THPObjectPtr arg = six::maybeAsTuple(args[i]);
401 // NOLINTNEXTLINE(bugprone-branch-clone)
402 auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
403 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
404 std::vector<at::Scalar> res(size);
405 for (const auto idx : c10::irange(size)) {
406 PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
407 : PyList_GET_ITEM(arg.get(), idx);
408 res[idx] = scalar_slow(obj);
409 }
410 return res;
411}
412
413inline at::Scalar PythonArgs::scalarWithDefault(
414 int i,
415 const at::Scalar& default_scalar) {
416 if (!args[i])
417 return default_scalar;
418 return scalar_slow(i);
419}
420
421inline c10::optional<at::Scalar> PythonArgs::scalarOptional(int i) {
422 if (!args[i])
423 return c10::nullopt;
424 return scalar_slow(i);
425}
426
427inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
428 if (!args[i])
429 return std::vector<at::Tensor>();
430 auto tuple = six::isTuple(args[i]);
431 THPObjectPtr arg = six::maybeAsTuple(args[i]);
432 // NOLINTNEXTLINE(bugprone-branch-clone)
433 auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
434 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
435 std::vector<at::Tensor> res(size);
436 for (const auto idx : c10::irange(size)) {
437 PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
438 : PyList_GET_ITEM(arg.get(), idx);
439 // This is checked by the argument parser so it's safe to cast without
440 // checking if this is a tensor first
441 res[idx] = THPVariable_Unpack(obj);
442 }
443 return res;
444}
445
446inline torch::List<c10::optional<at::Tensor>> PythonArgs::
447 list_of_optional_tensors(int i) {
448 if (!args[i])
449 return torch::List<c10::optional<at::Tensor>>();
450 auto tuple = six::isTuple(args[i]);
451 THPObjectPtr arg = six::maybeAsTuple(args[i]);
452 // NOLINTNEXTLINE(bugprone-branch-clone)
453 auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
454 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
455 torch::List<c10::optional<at::Tensor>> res;
456 res.reserve(size);
457 for (const auto idx : c10::irange(size)) {
458 PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
459 : PyList_GET_ITEM(arg.get(), idx);
460 // This is checked by the argument parser so it's safe to cast without
461 // checking if this is a tensor first
462 res.push_back(THPVariable_Unpack(obj));
463 }
464 return res;
465}
466
467template <int N>
468inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
469 auto res = std::array<at::Tensor, N>();
470 if (!args[i])
471 return res;
472 auto tuple = six::isTuple(args[i]);
473 THPObjectPtr arg = six::maybeAsTuple(args[i]);
474 // NOLINTNEXTLINE(bugprone-branch-clone)
475 auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
476 if (size != N) {
477 throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
478 }
479 for (const auto idx : c10::irange(size)) {
480 PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
481 : PyList_GET_ITEM(arg.get(), idx);
482 // This is checked by the argument parser so it's safe to cast without
483 // checking if this is a tensor first
484 res[idx] = THPVariable_Unpack(obj);
485 }
486 return res;
487}
488
489inline std::vector<int64_t> PythonArgs::intlist(int i) {
490 return intlistWithDefault(i, signature.params[i].default_intlist);
491}
492
493inline PyObject* toPyObject(c10::SymInt symint) {
494 if (symint.is_symbolic()) {
495 auto r = py::cast(symint).release().ptr();
496 TORCH_INTERNAL_ASSERT(r);
497 return r;
498 } else {
499 return THPUtils_packInt64(symint.as_int_unchecked());
500 }
501}
502
503inline void throw_intlist_exception(
504 const torch::PythonArgs* args,
505 size_t i,
506 PyObject* obj,
507 size_t idx) {
508 throw TypeError(
509 "%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
510 args->signature.name.c_str(),
511 args->signature.params[i].name.c_str(),
512 args->signature.params[i].type_name().c_str(),
513 Py_TYPE(obj)->tp_name,
514 idx + 1);
515}
516
// Converts SymInt[] argument `i` into a vector of SymInts.
//  - Omitted: the parameter's default_intlist, with each value wrapped as a
//    SymInt.
//  - Fixed-size parameter (size > 0) given a single int or SymInt: that value
//    is repeated `size` times.
//  - Otherwise args[i] is read as a tuple or list (PyTuple_Check, else the
//    PyList_* macros). Each element may be a one-element integral tensor, a
//    SymInt, or anything THPUtils_unpackIndex accepts.
// Unconvertible elements are reported via throw_intlist_exception.
inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
  if (!args[i]) {
    return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
      return c10::SymInt(di);
    });
  }

  // Broadcast a lone int to a fixed-size list (e.g. `3` for an int[2]).
  const auto size1 = signature.params[i].size;
  if (size1 > 0 && THPUtils_checkLong(args[i])) {
    return std::vector<c10::SymInt>(
        size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
  }

  // Same broadcast for a lone SymInt.
  if (size1 > 0 && torch::is_symint(py::handle(args[i]))) {
    auto si = py::handle(args[i]).cast<c10::SymInt>();
    return std::vector<c10::SymInt>(size1, si);
  }

  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<c10::SymInt> res;
  res.reserve(size2);
  for (const auto idx : c10::irange(size2)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);

    // Elements of torch.Size are tensors during tracing, and we need to
    // record extra information before they are turned into an IntArrayRef
    if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
      auto& var = THPVariable_Unpack(obj);
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          signature.params[i].name, size2, idx, var);
      try {
        res.emplace_back(var.item<int64_t>());
        continue;
      } catch (std::exception& e) {
        throw_intlist_exception(this, i, obj, idx);
      }
      // NOTE(review): unreachable — the try block either `continue`s or the
      // catch throws (throw_intlist_exception always throws).
      continue;
    } else {
      // convert tensor to scalar outside of try / catch,
      // so that Tensor subclass exceptions will not be caught.
      if (THPVariable_Check(obj)) {
        auto& var = THPVariable_Unpack(obj);
        if (var.numel() != 1 ||
            !at::isIntegralType(
                var.dtype().toScalarType(), /*include_bool*/ true)) {
          throw_intlist_exception(this, i, obj, idx);
        }
        auto scalar = var.item();
        // NOTE(review): the dtype check above admits bool (include_bool=true),
        // but this check rejects bool scalars. A one-element bool tensor
        // therefore fails here with a generic TORCH_CHECK error, whereas
        // intlistWithDefault accepts it via item<int64_t>(). Confirm intended.
        TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
        res.push_back(scalar.toSymInt());
      } else {
        try {
          if (is_symint(py::handle(obj))) {
            res.push_back(py::handle(obj).cast<c10::SymInt>());
          } else {
            res.emplace_back(THPUtils_unpackIndex(obj));
          }
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx);
        }
      }
    }
  }

  return res;
}
587
// Converts int[] argument `i` into a vector of int64_t.
//  - Omitted: returns `default_intlist`.
//  - Fixed-size parameter (size > 0) given a single int: that value is
//    repeated `size` times.
//  - Otherwise args[i] is read as a tuple or list (PyTuple_Check, else the
//    PyList_* macros). Each element may be a one-element integral (or bool)
//    tensor, or anything THPUtils_unpackIndex accepts.
// Unconvertible elements are reported via throw_intlist_exception.
inline std::vector<int64_t> PythonArgs::intlistWithDefault(
    int i,
    std::vector<int64_t> default_intlist) {
  if (!args[i])
    return default_intlist;
  PyObject* arg = args[i];
  // Broadcast a lone int to a fixed-size list (e.g. `3` for an int[2]).
  const auto size1 = signature.params[i].size;
  if (size1 > 0 && THPUtils_checkLong(arg)) {
    return std::vector<int64_t>(size1, THPUtils_unpackIndex(arg));
  }
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> res(size2);
  for (const auto idx : c10::irange(size2)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    // Elements of torch.Size are tensors during tracing, and we need to
    // record extra information before they are turned into an IntArrayRef
    if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
      auto& var = THPVariable_Unpack(obj);
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          signature.params[i].name, size2, idx, var);
      try {
        res[idx] = var.item<int64_t>();
        continue;
      } catch (std::exception& e) {
        throw_intlist_exception(this, i, obj, idx);
      }
    } else {
      // convert tensor to scalar outside of try / catch,
      // so that Tensor subclass exceptions will not be caught.
      if (THPVariable_Check(obj)) {
        auto& var = THPVariable_Unpack(obj);
        // Only one-element integral (bool allowed) tensors are accepted.
        if (var.numel() != 1 ||
            !at::isIntegralType(
                var.dtype().toScalarType(), /*include_bool*/ true)) {
          throw_intlist_exception(this, i, obj, idx);
        }
        res[idx] = var.item<int64_t>();
      } else {
        try {
          res[idx] = THPUtils_unpackIndex(obj);
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx);
        }
      }
    }
  }
  return res;
}
640
641inline c10::OptionalArray<int64_t> PythonArgs::intlistOptional(int i) {
642 if (!args[i]) {
643 return {};
644 }
645 return intlist(i);
646}
647
648inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) {
649 if (!args[i]) {
650 return {};
651 }
652 return symintlist(i);
653}
654
655inline std::vector<double> PythonArgs::getDoublelist(int i) {
656 PyObject* arg = args[i];
657 auto tuple = PyTuple_Check(arg);
658 // NOLINTNEXTLINE(bugprone-branch-clone)
659 auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
660 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
661 std::vector<double> res(size);
662 for (const auto idx : c10::irange(size)) {
663 PyObject* obj =
664 tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
665 try {
666 res[idx] = THPUtils_unpackDouble(obj);
667 } catch (const std::exception& e) {
668 throw TypeError(
669 "%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
670 signature.name.c_str(),
671 signature.params[i].name.c_str(),
672 signature.params[i].type_name().c_str(),
673 Py_TYPE(obj)->tp_name,
674 idx + 1);
675 }
676 }
677 return res;
678}
679
680inline c10::OptionalArray<double> PythonArgs::doublelistOptional(int i) {
681 if (!args[i]) {
682 return {};
683 }
684 return this->getDoublelist(i);
685}
686
687inline std::vector<double> PythonArgs::doublelist(int i) {
688 if (!args[i]) {
689 return {};
690 }
691 return this->getDoublelist(i);
692}
693
694inline at::ScalarType PythonArgs::scalartypeWithDefault(
695 int i,
696 at::ScalarType default_scalartype) {
697 if (!args[i])
698 return default_scalartype;
699 return scalartype(i);
700}
701
702inline at::ScalarType PythonArgs::scalartype(int i) {
703 if (!args[i]) {
704 auto scalartype = signature.params[i].default_scalartype;
705 return (scalartype == at::ScalarType::Undefined)
706 ? torch::tensors::get_default_scalar_type()
707 : scalartype;
708 }
709 PyObject* obj = args[i];
710 if (obj == (PyObject*)&PyFloat_Type) {
711 return at::ScalarType::Double;
712 }
713 if (obj == (PyObject*)&PyBool_Type) {
714 return at::ScalarType::Bool;
715 }
716 if (obj == (PyObject*)&PyLong_Type) {
717 return at::ScalarType::Long;
718 }
719 return reinterpret_cast<THPDtype*>(obj)->scalar_type;
720}
721
722inline c10::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) {
723 if (!args[i])
724 return c10::nullopt;
725 return scalartype(i);
726}
727
728inline at::Layout toLayout(PyObject* obj) {
729 const auto layout = reinterpret_cast<THPLayout*>(obj);
730 return layout->layout;
731}
732
733inline at::Layout PythonArgs::layout(int i) {
734 if (!args[i])
735 return signature.params[i].default_layout;
736 return toLayout(args[i]);
737}
738
739inline at::Layout PythonArgs::layoutWithDefault(
740 int i,
741 at::Layout default_layout) {
742 if (!args[i])
743 return default_layout;
744 return layout(i);
745}
746
747inline c10::optional<at::Layout> PythonArgs::layoutOptional(int i) {
748 if (!args[i])
749 return c10::nullopt;
750 return layout(i);
751}
752
753inline at::Device toDevice(PyObject* obj) {
754 if (THPDevice_Check(obj)) {
755 const auto device = reinterpret_cast<THPDevice*>(obj);
756 return device->device;
757 }
758 if (THPUtils_checkLong(obj)) {
759 const auto device_index = THPUtils_unpackLong(obj);
760 TORCH_CHECK(device_index >= 0, "Device index must not be negative");
761 return at::Device(DeviceType::CUDA, device_index);
762 }
763 const std::string& device_str = THPUtils_unpackString(obj);
764 return at::Device(device_str);
765}
766
767inline at::Device PythonArgs::device(int i) {
768 if (!args[i]) {
769 return torch::tensors::get_default_device();
770 }
771 return toDevice(args[i]);
772}
773
774inline at::Device PythonArgs::deviceWithDefault(
775 int i,
776 const at::Device& default_device) {
777 if (!args[i])
778 return default_device;
779 return device(i);
780}
781
782inline c10::optional<at::Device> PythonArgs::deviceOptional(int i) {
783 if (!args[i])
784 return c10::nullopt;
785 return device(i);
786}
787
788inline at::Dimname PythonArgs::dimname(int i) {
789 TORCH_INTERNAL_ASSERT(args[i] != nullptr);
790 return THPDimname_parse(args[i]);
791}
792
793inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
794 auto tuple = PyTuple_Check(arg);
795 // NOLINTNEXTLINE(bugprone-branch-clone)
796 auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
797 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
798 std::vector<at::Dimname> res;
799 res.reserve(size);
800 for (const auto idx : c10::irange(size)) {
801 PyObject* obj =
802 tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
803 res.push_back(THPDimname_parse(obj));
804 }
805 return res;
806}
807
808inline c10::optional<std::vector<at::Dimname>> PythonArgs::
809 toDimnameListOptional(int i) {
810 if (!args[i])
811 return c10::nullopt;
812 return parseDimnameList(args[i]);
813}
814
815inline std::vector<at::Dimname> PythonArgs::dimnamelist(int i) {
816 TORCH_INTERNAL_ASSERT(args[i]);
817 PyObject* arg = args[i];
818 auto size = signature.params[i].size;
819 TORCH_INTERNAL_ASSERT(size == 0 || size == 1);
820 if (size == 1 && THPUtils_checkDimname(arg)) {
821 return {THPDimname_parse(arg)};
822 }
823 return parseDimnameList(arg);
824}
825
826inline at::MemoryFormat PythonArgs::memoryformat(int i) {
827 if (!args[i])
828 return at::MemoryFormat::Contiguous;
829 TORCH_CHECK(
830 THPMemoryFormat_Check(args[i]),
831 "memory_format arg must be an instance of the torch.memory_format");
832 const auto memory_format = reinterpret_cast<THPMemoryFormat*>(args[i]);
833 return memory_format->memory_format;
834}
835
836inline c10::optional<at::MemoryFormat> PythonArgs::memoryformatOptional(int i) {
837 if (!args[i])
838 return c10::nullopt;
839 return memoryformat(i);
840}
841
842inline at::QScheme PythonArgs::toQScheme(int i) {
843 if (!args[i])
844 return at::kPerTensorAffine;
845 TORCH_CHECK(
846 THPQScheme_Check(args[i]),
847 "qscheme arg must be an instance of the torch.qscheme");
848 const auto qscheme = reinterpret_cast<THPQScheme*>(args[i]);
849 return qscheme->qscheme;
850}
851
852inline std::string PythonArgs::string(int i) {
853 return stringWithDefault(i, signature.params[i].default_string);
854}
855
856inline std::string PythonArgs::stringWithDefault(
857 int i,
858 const std::string& default_str) {
859 if (!args[i])
860 return default_str;
861 return THPUtils_unpackString(args[i]);
862}
863
864inline c10::optional<std::string> PythonArgs::stringOptional(int i) {
865 if (!args[i])
866 return c10::nullopt;
867 return THPUtils_unpackString(args[i]);
868}
869
870inline c10::string_view PythonArgs::stringView(int i) {
871 return stringViewWithDefault(i, signature.params[i].default_string);
872}
873
874inline c10::string_view PythonArgs::stringViewWithDefault(
875 int i,
876 const c10::string_view default_str) {
877 if (!args[i])
878 return default_str;
879 return THPUtils_unpackStringView(args[i]);
880}
881
882inline c10::optional<c10::string_view> PythonArgs::stringViewOptional(int i) {
883 if (!args[i])
884 return c10::nullopt;
885 return THPUtils_unpackStringView(args[i]);
886}
887
888inline int64_t PythonArgs::toInt64(int i) {
889 if (!args[i])
890 return signature.params[i].default_int;
891 if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
892 auto& var = THPVariable_Unpack(args[i]);
893 jit::tracer::ArgumentStash::stashValue(
894 signature.params[i].name, idx, var, c10::IntType::get());
895 }
896 return THPUtils_unpackLong(args[i]);
897}
898
899inline c10::SymInt PythonArgs::toSymInt(int i) {
900 if (!args[i]) {
901 return c10::SymInt(signature.params[i].default_int);
902 }
903 if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
904 auto& var = THPVariable_Unpack(args[i]);
905 jit::tracer::ArgumentStash::stashValue(
906 signature.params[i].name, idx, var, c10::IntType::get());
907 }
908
909 return py::cast<c10::SymInt>(py::handle(args[i]));
910}
911
912inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) {
913 if (!args[i])
914 return default_int;
915 return toInt64(i);
916}
917
918inline c10::optional<int64_t> PythonArgs::toInt64Optional(int i) {
919 if (!args[i])
920 return c10::nullopt;
921 return toInt64(i);
922}
923
924inline c10::optional<c10::SymInt> PythonArgs::toSymIntOptional(int i) {
925 if (!args[i])
926 return c10::nullopt;
927 return toSymInt(i);
928}
929
930inline c10::optional<bool> PythonArgs::toBoolOptional(int i) {
931 if (!args[i]) {
932 return c10::nullopt;
933 }
934 return toBool(i);
935}
936
937inline c10::optional<double> PythonArgs::toDoubleOptional(int i) {
938 if (!args[i]) {
939 return c10::nullopt;
940 }
941 return toDouble(i);
942}
943
944inline double PythonArgs::toDouble(int i) {
945 if (!args[i])
946 return signature.params[i].default_double;
947 return THPUtils_unpackDouble(args[i]);
948}
949
950inline double PythonArgs::toDoubleWithDefault(int i, double default_double) {
951 if (!args[i])
952 return default_double;
953 return toDouble(i);
954}
955
956inline c10::complex<double> PythonArgs::toComplex(int i) {
957 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
958 c10::complex<double> default_value = *const_cast<c10::complex<double>*>(
959 reinterpret_cast<const c10::complex<double>*>(
960 signature.params[i].default_complex));
961 if (!args[i])
962 return default_value;
963 return THPUtils_unpackComplexDouble(args[i]);
964}
965
966inline c10::complex<double> PythonArgs::toComplexWithDefault(
967 int i,
968 c10::complex<double> default_value) {
969 if (!args[i])
970 return default_value;
971 return toComplex(i);
972}
973
974inline bool PythonArgs::toBool(int i) {
975 if (!args[i])
976 return signature.params[i].default_bool;
977 return args[i] == Py_True;
978}
979
980inline bool PythonArgs::toBoolWithDefault(int i, bool default_bool) {
981 if (!args[i])
982 return default_bool;
983 return toBool(i);
984}
985
986inline bool PythonArgs::isNone(int i) {
987 return args[i] == nullptr;
988}
989
990inline c10::optional<at::Generator> PythonArgs::generator(int i) {
991 if (!args[i])
992 return c10::nullopt;
993 return reinterpret_cast<THPGenerator*>(args[i])->cdata;
994}
995
996inline at::Storage PythonArgs::storage(int i) {
997 if (!args[i])
998 return at::Storage();
999 return createStorage(args[i]);
1000}
1001
1002inline at::Storage PythonArgs::storage(
1003 int i,
1004 at::ScalarType& storage_scalar_type,
1005 bool& is_typed_storage) {
1006 at::Storage storage;
1007 if (!args[i]) {
1008 storage = at::Storage();
1009 is_typed_storage = false;
1010 storage_scalar_type = at::ScalarType::Undefined;
1011 } else {
1012 storage =
1013 createStorageGetType(args[i], storage_scalar_type, is_typed_storage);
1014 }
1015 return storage;
1016}
1017
1018inline c10::Stream PythonArgs::stream(int i) {
1019 if (!args[i])
1020 return c10::Stream(
1021 c10::Stream::Default::DEFAULT, c10::Device(DeviceType::CPU, -1));
1022 if (!THPStream_Check(args[i])) {
1023 throw TypeError(
1024 "expected Stream object. Got '%s'", Py_TYPE(args[i])->tp_name);
1025 }
1026 return c10::Stream::unpack3(
1027 ((THPStream*)args[i])->stream_id,
1028 ((THPStream*)args[i])->device_index,
1029 static_cast<DeviceType>(((THPStream*)args[i])->device_type));
1030}
1031
1032inline PyObject* PythonArgs::pyobject(int i) {
1033 if (!args[i])
1034 return Py_None;
1035 return args[i];
1036}
1037
1038/*
1039 *
1040 * Handle __torch_function__ overrides if we know that there are overloaded
1041 * arguments. All objects stored in r.overloaded_args must have a
1042 * __torch_function__ implementation and the arguments must be ordered in order
1043 * of precedence. Precedence goes from left to right in the order of the
1044 * signature of the function the overloaded arguments were passed to, except
1045 * subclasses are always considered before superclasses.
1046 *
1047 * If the result of calling __torch_function__ is NotImplemented, the
1048 * next implementation in the precedence order is called. If all
1049 * arguments return NotImplemented from their __torch_function__
1050 * implementation, a TypeError is raised in Python.
1051 *
1052 * Assumes overloaded_args has at least one entry. All entries must have
1053 * a __torch_function__ attribute that resolves to a callable that
1054 * accepts a torch API function, a tuple of arguments, and a dict of
1055 * keyword arguments for the torch API function.
1056 *
1057 * It is sufficient to call PythonArgs::has_torch_function before
1058 * calling this function to verify that there are valid arguments
1059 * present. If that is not done then special care must be taken to
1060 * ensure there are arguments that are overloaded with
1061 * __torch_function__.
1062 *
1063 * See torch._overrides.handle_torch_function for the equivalent
1064 * code in the pure-python implementation.
1065 *
1066 * 'r' is a parsed PythonArgs instance, returned from
1067 * PythonArgParser::parse.
1068 *
1069 * 'args' is a reference to the python tuple of arguments to the torch
1070 * API function.
1071 *
1072 * 'kwargs' is a reference to the python dict of keyword arguments to
1073 * the torch API function.
1074 *
1075 * 'torch_api' is a reference to a python torch API namespace.
1076 *
 * 'torch_api_function' is a reference to the original torch method. Usually
 * it can be obtained from torch_api and func_name. In some cases, e.g., torch
 * custom ops, the function is created in C++; fetching the original API
 * through torch_api and func_name would then cause a cyclic call.
1081 *
 * 'overloaded_args' are the arguments that have overloaded __torch_function__.
1083 *
 * 'func_name' is the name of the original torch method.
1085 *
 * TODO: we could give the following 'handle_torch_function' variants distinct
 * names instead of overloading a single name.
1088 *
1089 */
// Used for Tensor methods with arguments.
// Dispatches __torch_function__ for the overloaded arguments recorded in the
// parsed `r`. `self` is presumably the Tensor the method was invoked on —
// NOTE(review): confirm against the definition. `func_name_override` defaults
// to nullptr; presumably a null value means the name is taken from `r`'s
// signature — TODO confirm.
auto handle_torch_function(
    PythonArgs& r,
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override = nullptr) -> PyObject*;
1099
// Used for functions that need to parse python args.
// Same as the Tensor-method overload but without a `self` argument; the
// overloaded arguments come from the parsed `r`. `func_name_override` defaults
// to nullptr (see the note on the overload with `self` — NOTE(review): confirm
// how a null override is resolved).
auto handle_torch_function(
    PythonArgs& r,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override = nullptr) -> PyObject*;
1108
// Used for functions that have no argument parsing.
// No PythonArgs is passed, so `self` is the only object available here to
// check for an override — NOTE(review): confirm in the definition.
// Defaults: `args`/`kwargs` are nullptr, `torch_api` is THPVariableClass and
// `module_name` is "torch.Tensor".
auto handle_torch_function(
    PyObject* self,
    const std::string& func_name,
    PyObject* args = nullptr,
    PyObject* kwargs = nullptr,
    PyObject* torch_api = THPVariableClass,
    const std::string& module_name = "torch.Tensor") -> PyObject*;
1117
// Used for functions created in C++, e.g., C++ custom op, which doesn't use
// PythonArgParser to get overloaded_args.
//
// TorchFunctionName presumably selects which protocol method is invoked on the
// overloaded arguments (__torch_function__ vs. __torch_dispatch__) —
// NOTE(review): confirm against the definition.
enum class TorchFunctionName { TorchFunction, TorchDispatch };

// The caller supplies `overloaded_args` directly, along with the original
// `torch_api_function`; see the overview comment on handle_torch_function for
// why the function object is passed in rather than looked up by name.
// `torch_function_name` defaults to TorchFunctionName::TorchFunction.
auto TORCH_API handle_torch_function_no_python_arg_parser(
    at::ArrayRef<py::handle> overloaded_args,
    PyObject* args,
    PyObject* kwargs,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction)
    -> PyObject*;
1131
// Used for getters of Tensor properties.
// `property_name` names the property being read on `self`.
auto handle_torch_function_getter(
    THPVariable* self,
    const std::string& property_name) -> PyObject*;
1136
// Used for setters of Tensor properties.
// `property_name` names the property being assigned and `value` is the new
// value. The int return presumably follows the CPython setter convention
// (0 on success, -1 on error) — NOTE(review): confirm in the definition.
auto handle_torch_function_setter(
    THPVariable* self,
    const std::string& property_name,
    PyObject* value) -> int;
1142
// Used for __getitem__ and __setitem__
// `val` defaults to nullptr. Presumably a null `val` means __getitem__ and a
// non-null `val` means __setitem__ — NOTE(review): confirm in the definition.
auto handle_torch_function_indexing(
    PyObject* self,
    PyObject* index,
    PyObject* val = nullptr) -> PyObject*;
1148
1149/*
1150 * Check if the input obj is Tensor type, including its subclass, or overloaded
1151 * type. If the type defines __torch_function__, it also returns true.
1152 * Otherwise returns flase. If the class is not torch.Tensor, and it defines
1153 * __torch_function__, we append obj to overloaded_args.
1154 *
1155 * 'obj': the input argument to be checked
1156 * 'overloaded_args': the vector to append the overloaded args.
1157 */
1158bool is_tensor_and_append_overloaded(
1159 PyObject* obj,
1160 std::vector<py::handle>* overloaded_args);
1161
1162/*
1163 * Check if the input obj is Tensor List or Tensor Tuple type. First check
1164 * whether obj is Tuple or List type, if true, iterate over each element and
1165 * check whether it is Tensor type, including its subclass or overloaded type.
1166 * At the same time, the overloaded arg is appended to the overloaded_args.
1167 *
1168 * 'obj': the input argument to be checked
1169 * 'overloaded_args': the vector to append the overloaded args.
1170 * 'argnum': the number of total arguments of the function being checked.
1171 * 'throw_error': whether throw error if any element in the list or tuple is
1172 * not tensor type or overloaded.
1173 */
1174bool is_tensor_list_and_append_overloaded(
1175 PyObject* obj,
1176 std::vector<py::handle>* overloaded_args,
1177 int argnum,
1178 bool throw_error);
1179
/* Given an argument that is definitely a tensor and is definitely overloaded,
 * append it to the overloaded arguments list. Use this instead of
 * is_tensor_and_append_overloaded when you have a PyObject that you know is
 * a Tensor and is overloaded; it performs no type check of its own
 * (NOTE(review): confirm in the definition).
 *
 * 'overloaded_args': the vector to append the overloaded args to
 * 'obj': the input tensor that is overloaded
 */
void append_overloaded_tensor(
    std::vector<py::handle>* overloaded_args,
    PyObject* obj);
1191
/* Given an argument that is definitely a type and is definitely overloaded,
 * append it to the overloaded arguments list. Use this only with
 * __torch_dispatch__, where we operate on classes that have a
 * __torch_dispatch__ classmethod. Note that `obj` here is a class object,
 * not an instance.
 *
 * 'overloaded_args': the vector to append the overloaded type to
 * 'obj': the input class that has a __torch_dispatch__ classmethod.
 */
void append_overloaded_type(
    std::vector<py::handle>* overloaded_args,
    PyObject* obj);
1203
1204} // namespace torch
1205