1 | #pragma once |
2 | |
3 | // Parse arguments to Python functions implemented in C++ |
4 | // This is similar to PyArg_ParseTupleAndKeywords(), but specifically handles |
5 | // the types relevant to PyTorch and distinguishes between overloaded function |
6 | // signatures. |
7 | // |
8 | // Example: |
9 | // |
10 | // static PythonArgParser parser({ |
11 | // "norm(Scalar p, int64_t dim, bool keepdim=False)", |
12 | // "norm(Scalar p=2)", |
13 | // }); |
14 | // ParsedArgs<3> parsed_args; |
15 | // auto r = parser.parse(args, kwargs, parsed_args); |
16 | // if (r.idx == 0) { |
17 | // norm(r.scalar(0), r.int64(1), r.bool(0)); |
18 | // } else { |
19 | // norm(r.scalar(0)); |
20 | // } |
21 | // |
22 | // We auto-generate most uses of PythonArgParser; the generated files |
23 | // are torch/csrc/autograd/generated/python_*.cpp |
24 | // |
25 | // Some gotchas that you should watch out for: |
26 | // |
27 | // - Note [Order of overloads matters] |
28 | // Order of overloads matters. A set of input arguments may |
29 | // bind to multiple argument specs; we will always pick the |
30 | // first one in PythonArgParser. However, when you are writing |
31 | // overloads in, e.g., native_functions.yaml, you don't have to |
32 | // worry about what order you write them, because the code |
33 | // generation logic always gives the overloads a canonical |
34 | // order, where Tensor overloads come first, before Scalar overloads. |
35 | // This logic is in sort_declarations in |
36 | // tools/autograd/gen_python_functions.py |
37 | // |
38 | // - Zero-dim tensors (e.g., torch.tensor(2)) bind to both |
39 | // Scalar and Tensor, UNLESS they require grad (in which case |
40 | // they only bind to Tensor). |
41 | |
42 | #include <pybind11/pytypes.h> |
43 | #include <torch/csrc/python_headers.h> |
44 | |
45 | #include <torch/csrc/Device.h> |
46 | #include <torch/csrc/Dtype.h> |
47 | #include <torch/csrc/DynamicTypes.h> |
48 | #include <torch/csrc/Exceptions.h> |
49 | #include <torch/csrc/Generator.h> |
50 | #include <torch/csrc/Layout.h> |
51 | #include <torch/csrc/MemoryFormat.h> |
52 | #include <torch/csrc/QScheme.h> |
53 | #include <torch/csrc/Stream.h> |
54 | #include <torch/csrc/autograd/python_variable.h> |
55 | #include <torch/csrc/autograd/variable.h> |
56 | #include <torch/csrc/jit/frontend/tracer.h> |
57 | #include <torch/csrc/python_dimname.h> |
58 | #include <torch/csrc/tensor/python_tensor.h> |
59 | #include <torch/csrc/utils/disable_torch_function.h> |
60 | #include <torch/csrc/utils/object_ptr.h> |
61 | #include <torch/csrc/utils/pybind.h> |
62 | #include <torch/csrc/utils/python_numbers.h> |
63 | #include <torch/csrc/utils/python_strings.h> |
64 | #include <torch/csrc/utils/python_symnode.h> |
65 | #include <torch/csrc/utils/six.h> |
66 | |
67 | #include <ATen/PythonTorchFunctionTLS.h> |
68 | #include <ATen/core/Tensor.h> |
69 | #include <c10/util/Exception.h> |
70 | #include <c10/util/irange.h> |
71 | |
72 | #include <c10/core/SymFloat.h> |
73 | #include <c10/core/SymNodeImpl.h> |
74 | |
75 | #include <array> |
76 | #include <cstddef> |
77 | #include <memory> |
78 | #include <sstream> |
79 | #include <string> |
80 | #include <vector> |
81 | |
// Returns true if `obj` can bind to a Scalar parameter: a Python float, int,
// or complex, a torch SymInt/SymFloat, or (when built with numpy) a numpy
// scalar. Does not consider zero-dim tensors (see Note at top of file).
inline bool THPUtils_checkScalar(PyObject* obj) {
#ifdef USE_NUMPY
  if (torch::utils::is_numpy_scalar(obj)) {
    return true;
  }
#endif
  return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
      torch::is_symint(py::handle(obj)) || torch::is_symfloat(py::handle(obj));
}
91 | |
92 | namespace torch { |
93 | |
94 | bool should_allow_numbers_as_tensors(const std::string& name); |
95 | |
// Tag describing the declared C++ type of a single parameter in a signature
// string (e.g. "Scalar p" -> SCALAR, "int64_t dim" -> INT64).
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  SYM_INT,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STREAM,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST,
  SCALAR_LIST,
  SYM_INT_LIST
};
122 | |
123 | struct FunctionParameter; |
124 | struct FunctionSignature; |
125 | struct PythonArgs; |
126 | |
// Contains bound Python arguments in declaration order
//
// N must be at least the maximum number of parameters of any signature given
// to the parser; PythonArgParser::parse() checks this at runtime. Slots hold
// borrowed PyObject* references (null = argument omitted).
template <int N>
struct ParsedArgs {
  ParsedArgs() : args() {}
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  PyObject* args[N];
};
134 | |
// Parses Python (args, kwargs) against a set of overloaded signature strings
// and reports which overload matched. Instances are typically function-local
// statics, constructed once per function (see file-header example).
struct PythonArgParser {
  explicit PythonArgParser(
      std::vector<std::string> fmts,
      bool traceable = false);

  // meant only for `torch` functions.
  template <int N>
  inline PythonArgs parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      ParsedArgs<N>& dst);

  // Overload for functions without a bound `self`.
  template <int N>
  inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst);

  // Overload for zero-argument functions.
  inline PythonArgs parse(PyObject* self, ParsedArgs<0>& dst);

  // Formatted strings of non-hidden signatures
  std::vector<std::string> get_signatures() const;

 private:
  // Raises a TypeError describing why no signature matched; never returns.
  [[noreturn]]
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  void
  print_error(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* parsed_args[]);
  void check_deprecated(const FunctionSignature& signature);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  PythonArgs raw_parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* parsed_args[]);

  std::vector<FunctionSignature> signatures_;
  std::string function_name;
  // Largest parameter count across all signatures; used to validate that the
  // caller's ParsedArgs<N> buffer is big enough.
  size_t max_args;
  bool traceable;
};
178 | |
// One parsed signature string ("name(Type arg, ...)"): holds the parameter
// specs and performs the actual binding of Python args/kwargs against them.
struct PYBIND11_EXPORT FunctionSignature {
  explicit FunctionSignature(const std::string& fmt, int index);

  // Binds args/kwargs into dst (one slot per parameter, null for omitted
  // arguments). Returns true on success; when raise_exception is true,
  // failures throw instead of returning false.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  bool parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      PyObject* dst[],
      bool raise_exception);

  std::string toString() const;

  std::string name;
  std::vector<FunctionParameter> params;
  // Arguments that requested __torch_function__ handling during parse.
  std::vector<py::handle> overloaded_args;
  size_t min_args;
  size_t max_args;
  size_t max_pos_args;
  // Position of this overload within the parser's signature list.
  int index;
  bool hidden;
  bool deprecated;
  bool disable_torch_function;
};
203 | |
// The result of a successful parse: which overload matched (`idx`) plus the
// bound argument slots, with typed accessors that convert each raw PyObject*
// into the corresponding C++ value on demand. Accessors taking an index `i`
// refer to the i-th declared parameter of the matched signature; most return
// the parameter's default when the slot is null (argument omitted).
struct PythonArgs {
  PythonArgs(
      bool traceable,
      const FunctionSignature& signature,
      PyObject** args)
      : idx(signature.index),
        traceable(traceable),
        signature(signature),
        args(args) {}

  // Index of the overload that matched (see file-header example).
  int idx;
  bool traceable;
  // Reference into the parser's signature list; outlives this object since
  // parsers are static.
  const FunctionSignature& signature;
  // Borrowed argument slots, one per declared parameter.
  PyObject** args;

  inline bool has_torch_function();
  inline std::string get_func_name();
  inline at::Tensor tensor(int i);
  inline c10::optional<at::Tensor> optionalTensor(int i);
  inline at::Scalar scalar(int i);
  inline at::Scalar scalarWithDefault(int i, const at::Scalar& default_scalar);
  inline std::vector<at::Scalar> scalarlist(int i);
  inline std::vector<at::Tensor> tensorlist(int i);
  inline torch::List<c10::optional<at::Tensor>> list_of_optional_tensors(int i);
  template <int N>
  inline std::array<at::Tensor, N> tensorlist_n(int i);
  inline std::vector<int64_t> intlist(int i);
  inline std::vector<c10::SymInt> symintlist(int i);
  inline c10::OptionalArray<int64_t> intlistOptional(int i);
  inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
  inline std::vector<int64_t> intlistWithDefault(
      int i,
      std::vector<int64_t> default_intlist);
  inline c10::optional<at::Generator> generator(int i);
  inline at::Storage storage(int i);
  inline at::Storage storage(
      int i,
      at::ScalarType& storage_scalar_type,
      bool& is_typed_storage);
  inline c10::Stream stream(int i);
  inline at::ScalarType scalartype(int i);
  inline at::ScalarType scalartypeWithDefault(
      int i,
      at::ScalarType default_scalartype);
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
  inline c10::optional<at::Scalar> scalarOptional(int i);
  inline c10::optional<int64_t> toInt64Optional(int i);
  inline c10::optional<c10::SymInt> toSymIntOptional(int i);
  inline c10::optional<bool> toBoolOptional(int i);
  inline c10::optional<double> toDoubleOptional(int i);
  inline c10::OptionalArray<double> doublelistOptional(int i);
  inline std::vector<double> doublelist(int i);
  inline std::vector<double> getDoublelist(int i);
  inline at::Layout layout(int i);
  inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
  inline c10::optional<at::Layout> layoutOptional(int i);
  inline at::Device device(int i);
  inline at::Device deviceWithDefault(int i, const at::Device& default_device);
  inline c10::optional<at::Device> deviceOptional(int i);
  inline at::Dimname dimname(int i);
  inline std::vector<at::Dimname> dimnamelist(int i);
  inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
  inline at::MemoryFormat memoryformat(int i);
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
  inline at::QScheme toQScheme(int i);
  inline std::string string(int i);
  inline std::string stringWithDefault(int i, const std::string& default_str);
  inline c10::optional<std::string> stringOptional(int i);
  inline c10::string_view stringView(int i);
  inline c10::string_view stringViewWithDefault(
      int i,
      const c10::string_view default_str);
  inline c10::optional<c10::string_view> stringViewOptional(int i);
  inline PyObject* pyobject(int i);
  inline int64_t toInt64(int i);
  inline c10::SymInt toSymInt(int i);
  inline int64_t toInt64WithDefault(int i, int64_t default_int);
  inline double toDouble(int i);
  inline double toDoubleWithDefault(int i, double default_double);
  inline c10::complex<double> toComplex(int i);
  inline c10::complex<double> toComplexWithDefault(
      int i,
      c10::complex<double> default_complex);
  inline bool toBool(int i);
  inline bool toBoolWithDefault(int i, bool default_bool);
  inline bool isNone(int i);

 private:
  // Out-of-line slow paths for conversions that need more than an unpack
  // (e.g. subclasses, numbers-as-tensors); defined in the .cpp file.
  at::Tensor tensor_slow(int i);
  at::Scalar scalar_slow(int i);
  at::Scalar scalar_slow(PyObject* arg);
};
296 | |
// Spec for a single parameter within a FunctionSignature: its declared type,
// name, default value, and matching rules.
struct FunctionParameter {
  FunctionParameter(const std::string& fmt, bool keyword_only);

  // Returns true if `obj` can bind to this parameter. Arguments requesting
  // __torch_function__ handling are appended to overloaded_args; on a size
  // mismatch, failed_idx (if given) receives the offending element index.
  bool check(
      PyObject* obj,
      std::vector<py::handle>& overloaded_args,
      int argnum,
      int64_t* failed_idx = nullptr);

  void set_default_str(const std::string& str);
  std::string type_name() const;

  ParameterType type_;
  bool optional;
  bool allow_none;
  bool keyword_only;
  bool allow_numbers_as_tensors = false;
  // Declared list size (e.g. 2 for "int[2]"); 0 when unsized.
  int size;
  std::string name;
  // having this as a raw PyObject * will presumably leak it, but these are only
  // held by static objects anyway, and Py_Finalize can already be called when
  // this is destructed.
  PyObject* python_name;
  // Alternative (numpy-compatible) keyword names that also bind here.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  at::SmallVector<PyObject*, 5> numpy_python_names;
  at::Scalar default_scalar;
  std::vector<int64_t> default_intlist;
  std::string default_string;
  // Only the member matching `type_` is meaningful.
  union {
    bool default_bool;
    int64_t default_int;
    double default_double;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    double default_complex[2]; // see Scalar
    at::ScalarType default_scalartype;
    at::Layout default_layout;
  };
};
335 | |
// Main entry point: validates that the caller's ParsedArgs<N> buffer can hold
// the largest signature's arguments, then dispatches to raw_parse().
template <int N>
inline PythonArgs PythonArgParser::parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    ParsedArgs<N>& dst) {
  if (N < max_args) {
    throw ValueError(
        "PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)",
        (int)max_args,
        N);
  }
  return raw_parse(self, args, kwargs, dst.args);
}
350 | |
351 | template <int N> |
352 | inline PythonArgs PythonArgParser::parse( |
353 | PyObject* args, |
354 | PyObject* kwargs, |
355 | ParsedArgs<N>& dst) { |
356 | return parse(nullptr, args, kwargs, dst); |
357 | } |
358 | |
// Convenience overload for zero-argument functions (no args/kwargs tuples).
inline PythonArgs PythonArgParser::parse(PyObject* self, ParsedArgs<0>& dst) {
  return parse(self, nullptr, nullptr, dst);
}
362 | |
363 | inline bool PythonArgs::has_torch_function() { |
364 | return !this->signature.overloaded_args.empty() || |
365 | at::impl::torch_function_mode_enabled(); |
366 | } |
367 | |
// Name of the function whose signature matched (e.g. "norm").
inline std::string PythonArgs::get_func_name() {
  return signature.name;
}
371 | |
// TODO: this can return MaybeOwned
// Returns the Tensor at position i. Fast path: an exact THPVariable (not a
// subclass) unpacks directly; everything else goes through tensor_slow().
inline at::Tensor PythonArgs::tensor(int i) {
  if (args[i] && THPVariable_CheckExact(args[i])) {
    return THPVariable_Unpack(args[i]);
  }
  return tensor_slow(i);
}
379 | |
380 | inline c10::optional<at::Tensor> PythonArgs::optionalTensor(int i) { |
381 | at::Tensor t = tensor(i); |
382 | // NOLINTNEXTLINE(bugprone-branch-clone) |
383 | if (t.defined()) { |
384 | return t; |
385 | } else { |
386 | return c10::nullopt; |
387 | } |
388 | } |
389 | |
390 | inline at::Scalar PythonArgs::scalar(int i) { |
391 | if (!args[i]) |
392 | return signature.params[i].default_scalar; |
393 | return scalar_slow(i); |
394 | } |
395 | |
// Converts the tuple/list at position i into a vector of Scalars; returns an
// empty vector when the argument was omitted.
inline std::vector<at::Scalar> PythonArgs::scalarlist(int i) {
  if (!args[i])
    return std::vector<at::Scalar>();
  auto tuple = six::isTuple(args[i]);
  // maybeAsTuple keeps ownership so the GET_ITEM borrowed refs stay valid.
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Scalar> res(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    res[idx] = scalar_slow(obj);
  }
  return res;
}
412 | |
413 | inline at::Scalar PythonArgs::scalarWithDefault( |
414 | int i, |
415 | const at::Scalar& default_scalar) { |
416 | if (!args[i]) |
417 | return default_scalar; |
418 | return scalar_slow(i); |
419 | } |
420 | |
421 | inline c10::optional<at::Scalar> PythonArgs::scalarOptional(int i) { |
422 | if (!args[i]) |
423 | return c10::nullopt; |
424 | return scalar_slow(i); |
425 | } |
426 | |
// Converts the tuple/list at position i into a vector of Tensors; returns an
// empty vector when the argument was omitted.
inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
  if (!args[i])
    return std::vector<at::Tensor>();
  auto tuple = six::isTuple(args[i]);
  // maybeAsTuple keeps ownership so the GET_ITEM borrowed refs stay valid.
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Tensor> res(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    res[idx] = THPVariable_Unpack(obj);
  }
  return res;
}
445 | |
// Converts the tuple/list at position i into a List of optional Tensors;
// returns an empty list when the argument was omitted.
inline torch::List<c10::optional<at::Tensor>> PythonArgs::
    list_of_optional_tensors(int i) {
  if (!args[i])
    return torch::List<c10::optional<at::Tensor>>();
  auto tuple = six::isTuple(args[i]);
  // maybeAsTuple keeps ownership so the GET_ITEM borrowed refs stay valid.
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  torch::List<c10::optional<at::Tensor>> res;
  res.reserve(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    res.push_back(THPVariable_Unpack(obj));
  }
  return res;
}
466 | |
// Converts the tuple/list at position i into a fixed-size array of N Tensors.
// Returns an array of undefined Tensors when the argument was omitted; throws
// TypeError if the sequence length is not exactly N.
template <int N>
inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
  auto res = std::array<at::Tensor, N>();
  if (!args[i])
    return res;
  auto tuple = six::isTuple(args[i]);
  // maybeAsTuple keeps ownership so the GET_ITEM borrowed refs stay valid.
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  if (size != N) {
    throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
  }
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    res[idx] = THPVariable_Unpack(obj);
  }
  return res;
}
488 | |
489 | inline std::vector<int64_t> PythonArgs::intlist(int i) { |
490 | return intlistWithDefault(i, signature.params[i].default_intlist); |
491 | } |
492 | |
// Converts a SymInt to a Python object: a wrapped symbolic SymInt when
// symbolic, otherwise a plain Python int. Returns a new reference.
inline PyObject* toPyObject(c10::SymInt symint) {
  if (symint.is_symbolic()) {
    auto r = py::cast(symint).release().ptr();
    TORCH_INTERNAL_ASSERT(r);
    return r;
  } else {
    return THPUtils_packInt64(symint.as_int_unchecked());
  }
}
502 | |
// Unconditionally throws a TypeError describing the element at position `idx`
// (0-based; reported 1-based) that could not be converted to an int.
inline void throw_intlist_exception(
    const torch::PythonArgs* args,
    size_t i,
    PyObject* obj,
    size_t idx) {
  throw TypeError(
      "%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
      args->signature.name.c_str(),
      args->signature.params[i].name.c_str(),
      args->signature.params[i].type_name().c_str(),
      Py_TYPE(obj)->tp_name,
      idx + 1);
}
516 | |
517 | inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) { |
518 | if (!args[i]) { |
519 | return c10::fmap(signature.params[i].default_intlist, [](int64_t di) { |
520 | return c10::SymInt(di); |
521 | }); |
522 | } |
523 | |
524 | const auto size1 = signature.params[i].size; |
525 | if (size1 > 0 && THPUtils_checkLong(args[i])) { |
526 | return std::vector<c10::SymInt>( |
527 | size1, c10::SymInt(THPUtils_unpackIndex(args[i]))); |
528 | } |
529 | |
530 | if (size1 > 0 && torch::is_symint(py::handle(args[i]))) { |
531 | auto si = py::handle(args[i]).cast<c10::SymInt>(); |
532 | return std::vector<c10::SymInt>(size1, si); |
533 | } |
534 | |
535 | PyObject* arg = args[i]; |
536 | auto tuple = PyTuple_Check(arg); |
537 | // NOLINTNEXTLINE(bugprone-branch-clone) |
538 | const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); |
539 | std::vector<c10::SymInt> res; |
540 | res.reserve(size2); |
541 | for (const auto idx : c10::irange(size2)) { |
542 | PyObject* obj = |
543 | tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); |
544 | |
545 | // Elements of torch.Size are tensors during tracing, and we need to |
546 | // record extra information before they are turned into an IntArrayRef |
547 | if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) { |
548 | auto& var = THPVariable_Unpack(obj); |
549 | jit::tracer::ArgumentStash::stashIntArrayRefElem( |
550 | signature.params[i].name, size2, idx, var); |
551 | try { |
552 | res.emplace_back(var.item<int64_t>()); |
553 | continue; |
554 | } catch (std::exception& e) { |
555 | throw_intlist_exception(this, i, obj, idx); |
556 | } |
557 | continue; |
558 | } else { |
559 | // convert tensor to scalar outside of try / catch, |
560 | // so that Tensor subclass exceptions will not be caught. |
561 | if (THPVariable_Check(obj)) { |
562 | auto& var = THPVariable_Unpack(obj); |
563 | if (var.numel() != 1 || |
564 | !at::isIntegralType( |
565 | var.dtype().toScalarType(), /*include_bool*/ true)) { |
566 | throw_intlist_exception(this, i, obj, idx); |
567 | } |
568 | auto scalar = var.item(); |
569 | TORCH_CHECK(scalar.isIntegral(/*include bool*/ false)); |
570 | res.push_back(scalar.toSymInt()); |
571 | } else { |
572 | try { |
573 | if (is_symint(py::handle(obj))) { |
574 | res.push_back(py::handle(obj).cast<c10::SymInt>()); |
575 | } else { |
576 | res.emplace_back(THPUtils_unpackIndex(obj)); |
577 | } |
578 | } catch (std::exception& e) { |
579 | throw_intlist_exception(this, i, obj, idx); |
580 | } |
581 | } |
582 | } |
583 | } |
584 | |
585 | return res; |
586 | } |
587 | |
// Converts the argument at position i into a vector of int64_t, using the
// caller-supplied default when omitted. Accepts a bare int broadcast to the
// declared list size, or a tuple/list of ints / 1-element integral tensors.
// Throws TypeError (via throw_intlist_exception) on non-convertible elements.
inline std::vector<int64_t> PythonArgs::intlistWithDefault(
    int i,
    std::vector<int64_t> default_intlist) {
  if (!args[i])
    return default_intlist;
  PyObject* arg = args[i];
  const auto size1 = signature.params[i].size;
  // A bare int is broadcast to a list of the declared size (e.g. "int[2]").
  if (size1 > 0 && THPUtils_checkLong(arg)) {
    return std::vector<int64_t>(size1, THPUtils_unpackIndex(arg));
  }
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<int64_t> res(size2);
  for (const auto idx : c10::irange(size2)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    // Elements of torch.Size are tensors during tracing, and we need to
    // record extra information before they are turned into an IntArrayRef
    if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
      auto& var = THPVariable_Unpack(obj);
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          signature.params[i].name, size2, idx, var);
      try {
        res[idx] = var.item<int64_t>();
        continue;
      } catch (std::exception& e) {
        throw_intlist_exception(this, i, obj, idx);
      }
    } else {
      // convert tensor to scalar outside of try / catch,
      // so that Tensor subclass exceptions will not be caught.
      if (THPVariable_Check(obj)) {
        auto& var = THPVariable_Unpack(obj);
        if (var.numel() != 1 ||
            !at::isIntegralType(
                var.dtype().toScalarType(), /*include_bool*/ true)) {
          throw_intlist_exception(this, i, obj, idx);
        }
        res[idx] = var.item<int64_t>();
      } else {
        try {
          res[idx] = THPUtils_unpackIndex(obj);
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx);
        }
      }
    }
  }
  return res;
}
640 | |
641 | inline c10::OptionalArray<int64_t> PythonArgs::intlistOptional(int i) { |
642 | if (!args[i]) { |
643 | return {}; |
644 | } |
645 | return intlist(i); |
646 | } |
647 | |
648 | inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) { |
649 | if (!args[i]) { |
650 | return {}; |
651 | } |
652 | return symintlist(i); |
653 | } |
654 | |
// Converts the tuple/list at position i into a vector of doubles. Unlike the
// *Optional/doublelist wrappers, this assumes args[i] is non-null. Throws
// TypeError on a non-numeric element.
inline std::vector<double> PythonArgs::getDoublelist(int i) {
  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<double> res(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    try {
      res[idx] = THPUtils_unpackDouble(obj);
    } catch (const std::exception& e) {
      throw TypeError(
          "%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
          signature.name.c_str(),
          signature.params[i].name.c_str(),
          signature.params[i].type_name().c_str(),
          Py_TYPE(obj)->tp_name,
          idx + 1);
    }
  }
  return res;
}
679 | |
680 | inline c10::OptionalArray<double> PythonArgs::doublelistOptional(int i) { |
681 | if (!args[i]) { |
682 | return {}; |
683 | } |
684 | return this->getDoublelist(i); |
685 | } |
686 | |
687 | inline std::vector<double> PythonArgs::doublelist(int i) { |
688 | if (!args[i]) { |
689 | return {}; |
690 | } |
691 | return this->getDoublelist(i); |
692 | } |
693 | |
694 | inline at::ScalarType PythonArgs::scalartypeWithDefault( |
695 | int i, |
696 | at::ScalarType default_scalartype) { |
697 | if (!args[i]) |
698 | return default_scalartype; |
699 | return scalartype(i); |
700 | } |
701 | |
// Returns the dtype at position i. When omitted, uses the signature's default
// (or the global default dtype if the signature's default is Undefined).
// Python builtin types map to canonical dtypes: float -> Double,
// bool -> Bool, int -> Long; anything else is assumed to be a torch.dtype
// (validated by the argument parser).
inline at::ScalarType PythonArgs::scalartype(int i) {
  if (!args[i]) {
    auto scalartype = signature.params[i].default_scalartype;
    return (scalartype == at::ScalarType::Undefined)
        ? torch::tensors::get_default_scalar_type()
        : scalartype;
  }
  PyObject* obj = args[i];
  if (obj == (PyObject*)&PyFloat_Type) {
    return at::ScalarType::Double;
  }
  if (obj == (PyObject*)&PyBool_Type) {
    return at::ScalarType::Bool;
  }
  if (obj == (PyObject*)&PyLong_Type) {
    return at::ScalarType::Long;
  }
  return reinterpret_cast<THPDtype*>(obj)->scalar_type;
}
721 | |
722 | inline c10::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) { |
723 | if (!args[i]) |
724 | return c10::nullopt; |
725 | return scalartype(i); |
726 | } |
727 | |
728 | inline at::Layout toLayout(PyObject* obj) { |
729 | const auto layout = reinterpret_cast<THPLayout*>(obj); |
730 | return layout->layout; |
731 | } |
732 | |
733 | inline at::Layout PythonArgs::layout(int i) { |
734 | if (!args[i]) |
735 | return signature.params[i].default_layout; |
736 | return toLayout(args[i]); |
737 | } |
738 | |
739 | inline at::Layout PythonArgs::layoutWithDefault( |
740 | int i, |
741 | at::Layout default_layout) { |
742 | if (!args[i]) |
743 | return default_layout; |
744 | return layout(i); |
745 | } |
746 | |
747 | inline c10::optional<at::Layout> PythonArgs::layoutOptional(int i) { |
748 | if (!args[i]) |
749 | return c10::nullopt; |
750 | return layout(i); |
751 | } |
752 | |
// Converts a Python object to an at::Device. Accepts a torch.device, a
// non-negative integer (interpreted as a CUDA device index), or a device
// string such as "cpu" or "cuda:1".
inline at::Device toDevice(PyObject* obj) {
  if (THPDevice_Check(obj)) {
    const auto device = reinterpret_cast<THPDevice*>(obj);
    return device->device;
  }
  if (THPUtils_checkLong(obj)) {
    // NOTE: a bare int has historically meant a CUDA device index.
    const auto device_index = THPUtils_unpackLong(obj);
    TORCH_CHECK(device_index >= 0, "Device index must not be negative");
    return at::Device(DeviceType::CUDA, device_index);
  }
  const std::string& device_str = THPUtils_unpackString(obj);
  return at::Device(device_str);
}
766 | |
767 | inline at::Device PythonArgs::device(int i) { |
768 | if (!args[i]) { |
769 | return torch::tensors::get_default_device(); |
770 | } |
771 | return toDevice(args[i]); |
772 | } |
773 | |
774 | inline at::Device PythonArgs::deviceWithDefault( |
775 | int i, |
776 | const at::Device& default_device) { |
777 | if (!args[i]) |
778 | return default_device; |
779 | return device(i); |
780 | } |
781 | |
782 | inline c10::optional<at::Device> PythonArgs::deviceOptional(int i) { |
783 | if (!args[i]) |
784 | return c10::nullopt; |
785 | return device(i); |
786 | } |
787 | |
// Returns the Dimname at position i. The argument is required (no default),
// hence the internal assert.
inline at::Dimname PythonArgs::dimname(int i) {
  TORCH_INTERNAL_ASSERT(args[i] != nullptr);
  return THPDimname_parse(args[i]);
}
792 | |
// Converts a tuple/list of dim names into a vector of Dimnames.
inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Dimname> res;
  res.reserve(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    res.push_back(THPDimname_parse(obj));
  }
  return res;
}
807 | |
808 | inline c10::optional<std::vector<at::Dimname>> PythonArgs:: |
809 | toDimnameListOptional(int i) { |
810 | if (!args[i]) |
811 | return c10::nullopt; |
812 | return parseDimnameList(args[i]); |
813 | } |
814 | |
// Returns the Dimname list at position i. A single dimname is accepted and
// wrapped into a one-element list when the signature declares size 1.
inline std::vector<at::Dimname> PythonArgs::dimnamelist(int i) {
  TORCH_INTERNAL_ASSERT(args[i]);
  PyObject* arg = args[i];
  auto size = signature.params[i].size;
  TORCH_INTERNAL_ASSERT(size == 0 || size == 1);
  if (size == 1 && THPUtils_checkDimname(arg)) {
    return {THPDimname_parse(arg)};
  }
  return parseDimnameList(arg);
}
825 | |
// Returns the memory format at position i; defaults to Contiguous when the
// argument was omitted. Throws if the argument is not a torch.memory_format.
inline at::MemoryFormat PythonArgs::memoryformat(int i) {
  if (!args[i])
    return at::MemoryFormat::Contiguous;
  TORCH_CHECK(
      THPMemoryFormat_Check(args[i]),
      "memory_format arg must be an instance of the torch.memory_format");
  const auto memory_format = reinterpret_cast<THPMemoryFormat*>(args[i]);
  return memory_format->memory_format;
}
835 | |
836 | inline c10::optional<at::MemoryFormat> PythonArgs::memoryformatOptional(int i) { |
837 | if (!args[i]) |
838 | return c10::nullopt; |
839 | return memoryformat(i); |
840 | } |
841 | |
// Returns the quantization scheme at position i; defaults to kPerTensorAffine
// when the argument was omitted. Throws if the argument is not a
// torch.qscheme.
inline at::QScheme PythonArgs::toQScheme(int i) {
  if (!args[i])
    return at::kPerTensorAffine;
  TORCH_CHECK(
      THPQScheme_Check(args[i]),
      "qscheme arg must be an instance of the torch.qscheme");
  const auto qscheme = reinterpret_cast<THPQScheme*>(args[i]);
  return qscheme->qscheme;
}
851 | |
852 | inline std::string PythonArgs::string(int i) { |
853 | return stringWithDefault(i, signature.params[i].default_string); |
854 | } |
855 | |
856 | inline std::string PythonArgs::stringWithDefault( |
857 | int i, |
858 | const std::string& default_str) { |
859 | if (!args[i]) |
860 | return default_str; |
861 | return THPUtils_unpackString(args[i]); |
862 | } |
863 | |
864 | inline c10::optional<std::string> PythonArgs::stringOptional(int i) { |
865 | if (!args[i]) |
866 | return c10::nullopt; |
867 | return THPUtils_unpackString(args[i]); |
868 | } |
869 | |
870 | inline c10::string_view PythonArgs::stringView(int i) { |
871 | return stringViewWithDefault(i, signature.params[i].default_string); |
872 | } |
873 | |
874 | inline c10::string_view PythonArgs::stringViewWithDefault( |
875 | int i, |
876 | const c10::string_view default_str) { |
877 | if (!args[i]) |
878 | return default_str; |
879 | return THPUtils_unpackStringView(args[i]); |
880 | } |
881 | |
882 | inline c10::optional<c10::string_view> PythonArgs::stringViewOptional(int i) { |
883 | if (!args[i]) |
884 | return c10::nullopt; |
885 | return THPUtils_unpackStringView(args[i]); |
886 | } |
887 | |
// Returns the int64 at position i, or the signature's default when omitted.
// While JIT-tracing, a tensor-valued argument is also stashed so the trace
// records its provenance before it is unpacked to a plain int.
inline int64_t PythonArgs::toInt64(int i) {
  if (!args[i])
    return signature.params[i].default_int;
  if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
    auto& var = THPVariable_Unpack(args[i]);
    jit::tracer::ArgumentStash::stashValue(
        signature.params[i].name, idx, var, c10::IntType::get());
  }
  return THPUtils_unpackLong(args[i]);
}
898 | |
// Unpack argument i as a SymInt. A missing argument falls back to the
// signature's integer default, wrapped in a SymInt.
inline c10::SymInt PythonArgs::toSymInt(int i) {
  if (!args[i]) {
    return c10::SymInt(signature.params[i].default_int);
  }
  // Same tracer stash as toInt64(): Tensor-valued ints are recorded under
  // this parameter's name while tracing.
  if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
    auto& var = THPVariable_Unpack(args[i]);
    jit::tracer::ArgumentStash::stashValue(
        signature.params[i].name, idx, var, c10::IntType::get());
  }

  // Delegate the actual conversion (int or symbolic int) to pybind11.
  return py::cast<c10::SymInt>(py::handle(args[i]));
}
911 | |
912 | inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) { |
913 | if (!args[i]) |
914 | return default_int; |
915 | return toInt64(i); |
916 | } |
917 | |
918 | inline c10::optional<int64_t> PythonArgs::toInt64Optional(int i) { |
919 | if (!args[i]) |
920 | return c10::nullopt; |
921 | return toInt64(i); |
922 | } |
923 | |
924 | inline c10::optional<c10::SymInt> PythonArgs::toSymIntOptional(int i) { |
925 | if (!args[i]) |
926 | return c10::nullopt; |
927 | return toSymInt(i); |
928 | } |
929 | |
930 | inline c10::optional<bool> PythonArgs::toBoolOptional(int i) { |
931 | if (!args[i]) { |
932 | return c10::nullopt; |
933 | } |
934 | return toBool(i); |
935 | } |
936 | |
937 | inline c10::optional<double> PythonArgs::toDoubleOptional(int i) { |
938 | if (!args[i]) { |
939 | return c10::nullopt; |
940 | } |
941 | return toDouble(i); |
942 | } |
943 | |
944 | inline double PythonArgs::toDouble(int i) { |
945 | if (!args[i]) |
946 | return signature.params[i].default_double; |
947 | return THPUtils_unpackDouble(args[i]); |
948 | } |
949 | |
950 | inline double PythonArgs::toDoubleWithDefault(int i, double default_double) { |
951 | if (!args[i]) |
952 | return default_double; |
953 | return toDouble(i); |
954 | } |
955 | |
956 | inline c10::complex<double> PythonArgs::toComplex(int i) { |
957 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
958 | c10::complex<double> default_value = *const_cast<c10::complex<double>*>( |
959 | reinterpret_cast<const c10::complex<double>*>( |
960 | signature.params[i].default_complex)); |
961 | if (!args[i]) |
962 | return default_value; |
963 | return THPUtils_unpackComplexDouble(args[i]); |
964 | } |
965 | |
966 | inline c10::complex<double> PythonArgs::toComplexWithDefault( |
967 | int i, |
968 | c10::complex<double> default_value) { |
969 | if (!args[i]) |
970 | return default_value; |
971 | return toComplex(i); |
972 | } |
973 | |
// Unpack argument i as a bool, falling back to the signature's declared
// default when the argument was not supplied.
inline bool PythonArgs::toBool(int i) {
  if (!args[i])
    return signature.params[i].default_bool;
  // Identity comparison against Py_True: any other object (including truthy
  // non-bool values) maps to false. Presumably the arg parser has already
  // restricted this slot to Py_True/Py_False — TODO confirm.
  return args[i] == Py_True;
}
979 | |
980 | inline bool PythonArgs::toBoolWithDefault(int i, bool default_bool) { |
981 | if (!args[i]) |
982 | return default_bool; |
983 | return toBool(i); |
984 | } |
985 | |
986 | inline bool PythonArgs::isNone(int i) { |
987 | return args[i] == nullptr; |
988 | } |
989 | |
990 | inline c10::optional<at::Generator> PythonArgs::generator(int i) { |
991 | if (!args[i]) |
992 | return c10::nullopt; |
993 | return reinterpret_cast<THPGenerator*>(args[i])->cdata; |
994 | } |
995 | |
996 | inline at::Storage PythonArgs::storage(int i) { |
997 | if (!args[i]) |
998 | return at::Storage(); |
999 | return createStorage(args[i]); |
1000 | } |
1001 | |
1002 | inline at::Storage PythonArgs::storage( |
1003 | int i, |
1004 | at::ScalarType& storage_scalar_type, |
1005 | bool& is_typed_storage) { |
1006 | at::Storage storage; |
1007 | if (!args[i]) { |
1008 | storage = at::Storage(); |
1009 | is_typed_storage = false; |
1010 | storage_scalar_type = at::ScalarType::Undefined; |
1011 | } else { |
1012 | storage = |
1013 | createStorageGetType(args[i], storage_scalar_type, is_typed_storage); |
1014 | } |
1015 | return storage; |
1016 | } |
1017 | |
1018 | inline c10::Stream PythonArgs::stream(int i) { |
1019 | if (!args[i]) |
1020 | return c10::Stream( |
1021 | c10::Stream::Default::DEFAULT, c10::Device(DeviceType::CPU, -1)); |
1022 | if (!THPStream_Check(args[i])) { |
1023 | throw TypeError( |
1024 | "expected Stream object. Got '%s'" , Py_TYPE(args[i])->tp_name); |
1025 | } |
1026 | return c10::Stream::unpack3( |
1027 | ((THPStream*)args[i])->stream_id, |
1028 | ((THPStream*)args[i])->device_index, |
1029 | static_cast<DeviceType>(((THPStream*)args[i])->device_type)); |
1030 | } |
1031 | |
1032 | inline PyObject* PythonArgs::pyobject(int i) { |
1033 | if (!args[i]) |
1034 | return Py_None; |
1035 | return args[i]; |
1036 | } |
1037 | |
1038 | /* |
1039 | * |
1040 | * Handle __torch_function__ overrides if we know that there are overloaded |
1041 | * arguments. All objects stored in r.overloaded_args must have a |
1042 | * __torch_function__ implementation and the arguments must be ordered in order |
1043 | * of precedence. Precedence goes from left to right in the order of the |
1044 | * signature of the function the overloaded arguments were passed to, except |
1045 | * subclasses are always considered before superclasses. |
1046 | * |
1047 | * If the result of calling __torch_function__ is NotImplemented, the |
1048 | * next implementation in the precedence order is called. If all |
1049 | * arguments return NotImplemented from their __torch_function__ |
1050 | * implementation, a TypeError is raised in Python. |
1051 | * |
1052 | * Assumes overloaded_args has at least one entry. All entries must have |
1053 | * a __torch_function__ attribute that resolves to a callable that |
1054 | * accepts a torch API function, a tuple of arguments, and a dict of |
1055 | * keyword arguments for the torch API function. |
1056 | * |
1057 | * It is sufficient to call PythonArgs::has_torch_function before |
1058 | * calling this function to verify that there are valid arguments |
1059 | * present. If that is not done then special care must be taken to |
1060 | * ensure there are arguments that are overloaded with |
1061 | * __torch_function__. |
1062 | * |
1063 | * See torch._overrides.handle_torch_function for the equivalent |
1064 | * code in the pure-python implementation. |
1065 | * |
1066 | * 'r' is a parsed PythonArgs instance, returned from |
1067 | * PythonArgParser::parse. |
1068 | * |
1069 | * 'args' is a reference to the python tuple of arguments to the torch |
1070 | * API function. |
1071 | * |
1072 | * 'kwargs' is a reference to the python dict of keyword arguments to |
1073 | * the torch API function. |
1074 | * |
1075 | * 'torch_api' is a reference to a python torch API namespace. |
1076 | * |
1077 | * 'torch_api_function' is the reference to the original torch method, usually, |
1078 | * we can use torch_api and func_name to get torch_api_function. In some cases, |
1079 | * e.g., torch custom op, we create the function in C++, if we still use |
1080 | * torch_api and func_name to fetch original api, a cyclic call will happen. |
1081 | * |
1082 | * 'overloaded_args' is the args which have overloaded __torch_function__. |
1083 | * |
1084 | * 'func_name' is the named of the original torch method. |
1085 | * |
1086 | * TODO: we could use different names for the following 'handle_torch_function' |
1087 | * instead of overloading. |
1088 | * |
1089 | */ |
1090 | // Used for Tensor methods with arguments. |
1091 | auto handle_torch_function( |
1092 | PythonArgs& r, |
1093 | PyObject* self, |
1094 | PyObject* args, |
1095 | PyObject* kwargs, |
1096 | PyObject* torch_api, |
1097 | const char* module_name, |
1098 | const char* func_name_override = nullptr) -> PyObject*; |
1099 | |
1100 | // Used for functions which needs to parse python args. |
1101 | auto handle_torch_function( |
1102 | PythonArgs& r, |
1103 | PyObject* args, |
1104 | PyObject* kwargs, |
1105 | PyObject* torch_api, |
1106 | const char* module_name, |
1107 | const char* func_name_override = nullptr) -> PyObject*; |
1108 | |
1109 | // Used for functions that have no argument parsing. |
1110 | auto handle_torch_function( |
1111 | PyObject* self, |
1112 | const std::string& func_name, |
1113 | PyObject* args = nullptr, |
1114 | PyObject* kwargs = nullptr, |
1115 | PyObject* torch_api = THPVariableClass, |
1116 | const std::string& module_name = "torch.Tensor" ) -> PyObject*; |
1117 | |
1118 | // Used for functions created in C++, e.g., C++ custom op, which doesn't use |
1119 | // PythonArgParser to get overloaded_args. |
1120 | enum class TorchFunctionName { TorchFunction, TorchDispatch }; |
1121 | |
1122 | auto TORCH_API handle_torch_function_no_python_arg_parser( |
1123 | at::ArrayRef<py::handle> overloaded_args, |
1124 | PyObject* args, |
1125 | PyObject* kwargs, |
1126 | const char* func_name, |
1127 | PyObject* torch_api_function, |
1128 | const char* module_name, |
1129 | TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction) |
1130 | -> PyObject*; |
1131 | |
1132 | // Used for getters of Tensor properties |
1133 | auto handle_torch_function_getter( |
1134 | THPVariable* self, |
1135 | const std::string& property_name) -> PyObject*; |
1136 | |
1137 | // Used for setters of Tensor properties. |
1138 | auto handle_torch_function_setter( |
1139 | THPVariable* self, |
1140 | const std::string& property_name, |
1141 | PyObject* value) -> int; |
1142 | |
1143 | // Used for __getitem__ and __setitem__ |
1144 | auto handle_torch_function_indexing( |
1145 | PyObject* self, |
1146 | PyObject* index, |
1147 | PyObject* val = nullptr) -> PyObject*; |
1148 | |
1149 | /* |
1150 | * Check if the input obj is Tensor type, including its subclass, or overloaded |
1151 | * type. If the type defines __torch_function__, it also returns true. |
1152 | * Otherwise returns flase. If the class is not torch.Tensor, and it defines |
1153 | * __torch_function__, we append obj to overloaded_args. |
1154 | * |
1155 | * 'obj': the input argument to be checked |
1156 | * 'overloaded_args': the vector to append the overloaded args. |
1157 | */ |
1158 | bool is_tensor_and_append_overloaded( |
1159 | PyObject* obj, |
1160 | std::vector<py::handle>* overloaded_args); |
1161 | |
1162 | /* |
1163 | * Check if the input obj is Tensor List or Tensor Tuple type. First check |
1164 | * whether obj is Tuple or List type, if true, iterate over each element and |
1165 | * check whether it is Tensor type, including its subclass or overloaded type. |
1166 | * At the same time, the overloaded arg is appended to the overloaded_args. |
1167 | * |
1168 | * 'obj': the input argument to be checked |
1169 | * 'overloaded_args': the vector to append the overloaded args. |
1170 | * 'argnum': the number of total arguments of the function being checked. |
1171 | * 'throw_error': whether throw error if any element in the list or tuple is |
1172 | * not tensor type or overloaded. |
1173 | */ |
1174 | bool is_tensor_list_and_append_overloaded( |
1175 | PyObject* obj, |
1176 | std::vector<py::handle>* overloaded_args, |
1177 | int argnum, |
1178 | bool throw_error); |
1179 | |
/* Given an argument that is definitely a tensor and is definitely overloaded,
 * append it to the overloaded arguments list. Use this instead of
 * is_tensor_and_append_overloaded in situations where you have a PyObject
 * and you know it definitely is a Tensor and it is definitely overloaded.
 * Note: the precondition is not re-checked here; callers must guarantee it.
 *
 * 'overloaded_args': the vector to append the overloaded args
 * 'obj': the input tensor that is overloaded
 */
void append_overloaded_tensor(
    std::vector<py::handle>* overloaded_args,
    PyObject* obj);
1191 | |
/* Given an argument that is definitely a type and is definitely overloaded,
 * append it to the overloaded arguments list. Use this only with
 * __torch_dispatch__, where we operate on classes (not instances) that have
 * a __torch_dispatch__ classmethod.
 *
 * 'overloaded_args': the vector to append the overloaded type
 * 'obj': the input class that has a __torch_dispatch__ classmethod.
 */
void append_overloaded_type(
    std::vector<py::handle>* overloaded_args,
    PyObject* obj);
1203 | |
1204 | } // namespace torch |
1205 | |