#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/python_torch_functions.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/utils/cuda_lazy_init.h>
#include <torch/csrc/utils/out_types.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/structseq.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>

#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>

#include <Python.h>
#include <fmt/format.h>
#include <pybind11/pybind11.h>
#include <utility>
#include <vector>

using at::DeviceGuard;
using at::DimnameList;
using at::IntArrayRef;
using at::OptionalDeviceGuard;
using at::Scalar;
using at::Tensor;
using at::TensorList;
using at::TensorOptions;

using torch::utils::check_out_type_matches;
using namespace torch::autograd::utils;

namespace torch {
namespace autograd {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* THPVariableFunctionsModule = nullptr;

inline Tensor dispatch_range(
    const Scalar& start,
    const Scalar& end,
    const Scalar& step,
    Tensor result) {
  pybind11::gil_scoped_release no_gil;
  OptionalDeviceGuard device_guard(device_of(result));
  return at::range_out(result, start, end, step);
}

inline Tensor dispatch_range(
    const Scalar& start,
    const Scalar& end,
    const Scalar& step,
    const TensorOptions& options) {
  torch::utils::maybe_initialize_cuda(options);
  pybind11::gil_scoped_release no_gil;
  DeviceGuard device_guard(options.device());
  return torch::range(start, end, step, options);
}

static PyObject* THPVariable_range(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
  });

  ParsedArgs<8> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    auto ret = PyErr_WarnEx(
        PyExc_UserWarning,
        "torch.range is deprecated and will be removed in a future release "
        "because its behavior is inconsistent with Python's range builtin. "
        "Instead, use torch.arange, which produces values in [start, end).",
        1);
    if (ret != 0)
      throw python_error();
    if (r.isNone(3)) {
      const auto options = TensorOptions()
                               .dtype(r.scalartype(4))
                               .device(r.device(6))
                               .layout(r.layout(5))
                               .requires_grad(r.toBool(7));
      return wrap(
          dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options));
    } else {
      check_out_type_matches(
          r.tensor(3),
          r.scalartype(4),
          r.isNone(4),
          r.layout(5),
          r.device(6),
          r.isNone(6));
      return wrap(
          dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3))
              .set_requires_grad(r.toBool(7)));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
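
// Illustrative Python-level difference behind the deprecation warning above
// (not executed here): torch.range includes the endpoint while torch.arange
// does not.
//
//   torch.range(0, 4)   # tensor([0., 1., 2., 3., 4.])  -> 5 elements
//   torch.arange(0, 4)  # tensor([0, 1, 2, 3])          -> 4 elements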

// Implemented on the python object so that torch.as_tensor can accept
// arbitrarily nested python objects - list, tuple, np array, scalar, etc.
static PyObject* THPVariable_as_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
  });

  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::as_tensor(
      torch::tensors::get_default_dispatch_key(),
      torch::tensors::get_default_scalar_type(),
      r));
  END_HANDLE_TH_ERRORS
}
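
// Illustrative usage (Python, not executed here): torch.as_tensor avoids a
// copy when the input is already a tensor, or a numpy array whose dtype and
// device already match.
//
//   t = torch.as_tensor([[1, 2], [3, 4]])  # copies the nested list
//   a = np.array([1., 2., 3.])
//   s = torch.as_tensor(a)                 # shares memory with `a`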

// Implemented on the python object here because PyObject is not currently
// natively declarable. See ATen/native/README.md for more context.
static PyObject* THPVariable_from_numpy(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch.from_numpy", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::tensor_from_numpy(arg));
  END_HANDLE_TH_ERRORS
}
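
// Note (illustrative, Python): the resulting tensor shares memory with the
// numpy array, so an in-place mutation on either side is visible to the
// other:
//
//   a = np.ones(3)
//   t = torch.from_numpy(a)
//   t.add_(1)   # `a` now contains [2., 2., 2.]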

static Tensor dispatch_nonzero(const Tensor& self) {
  pybind11::gil_scoped_release no_gil;
  OptionalDeviceGuard device_guard(device_of(self));
  return self.nonzero();
}

static Tensor dispatch_nonzero(const Tensor& self, Tensor out) {
  pybind11::gil_scoped_release no_gil;
  OptionalDeviceGuard device_guard(device_of(self));
  return at::nonzero_out(out, self);
}

static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor& self) {
  pybind11::gil_scoped_release no_gil;
  OptionalDeviceGuard device_guard(device_of(self));
  return self.nonzero_numpy();
}

static PyObject* THPVariable_nonzero(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs);

#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES)       \
  static PyObject* THPVariable_##NAME(                                    \
      PyObject* self, PyObject* args, PyObject* kwargs) {                 \
    HANDLE_TH_ERRORS                                                      \
    static PythonArgParser parser SIGNATURES;                             \
    ParsedArgs<NARGS> parsed_args;                                        \
    auto r = parser.parse(args, kwargs, parsed_args);                     \
    if (r.has_torch_function()) {                                         \
      return handle_torch_function(                                       \
          r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \
    }                                                                     \
    jit::tracer::warn("torch." #NAME, jit::tracer::WARN_CONSTRUCTOR);     \
    return THPVariable_Wrap(torch::utils::NAME##_ctor(                    \
        torch::tensors::get_default_dispatch_key(),                       \
        torch::tensors::get_default_scalar_type(),                        \
        r));                                                              \
    END_HANDLE_TH_ERRORS                                                  \
  }
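
// For reference (derivable from the macro above): each invocation below
// expands to a binding named THPVariable_<NAME> that forwards to
// torch::utils::<NAME>_ctor, e.g. the sparse_csr_tensor invocation defines
// THPVariable_sparse_csr_tensor, which calls
// torch::utils::sparse_csr_tensor_ctor.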

THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_compressed_tensor,
    10,
    ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)",
      "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_csr_tensor,
    10,
    ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)",
      "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_csc_tensor,
    10,
    ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)",
      "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_bsr_tensor,
    10,
    ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)",
      "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_bsc_tensor,
    10,
    ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)",
      "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool? check_invariants=None)"}))

static PyObject* THPVariable_sparse_coo_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool? check_invariants=None)",
      "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool? check_invariants=None)",
      "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool? check_invariants=None)",
  });

  ParsedArgs<7> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(
      torch::tensors::get_default_dispatch_key(),
      torch::tensors::get_default_scalar_type(),
      r));
  END_HANDLE_TH_ERRORS
}
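
// Illustrative usage (Python, not executed here): indices have shape
// (ndim, nnz) and values have shape (nnz,):
//
//   i = torch.tensor([[0, 1, 1],
//                     [2, 0, 2]])
//   v = torch.tensor([3., 4., 5.])
//   s = torch.sparse_coo_tensor(i, v, (2, 3))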

// Implemented on the python object so that torch.tensor can accept
// arbitrarily nested python objects - list, tuple, np array, scalar, etc.
static PyObject* THPVariable_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",
  });

  constexpr int ctor_num_args = 6;
  ParsedArgs<ctor_num_args> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::tensor_ctor(
      torch::tensors::get_default_dispatch_key(),
      torch::tensors::get_default_scalar_type(),
      r));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_get_device(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "get_device(Tensor input)",
      },
      /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    return wrap(r.tensor(0).get_device());
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_frombuffer(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)",
      },
      /*traceable=*/false);

  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    auto buffer = r.pyobject(0);
    auto dtype = r.scalartype(1);
    auto count = r.toInt64(2);
    auto offset = r.toInt64(3);
    auto requires_grad = r.toBool(4);

    TORCH_CHECK_VALUE(
        PyObject_CheckBuffer(buffer) != 0,
        "object does not implement Python buffer protocol.");
    return wrap(torch::utils::tensor_frombuffer(
        buffer, dtype, count, offset, requires_grad));
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
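
// Illustrative usage (Python, not executed here): the tensor shares memory
// with the buffer, so writable buffers are mutated by in-place ops:
//
//   buf = bytearray([1, 2, 3, 4])
//   t = torch.frombuffer(buf, dtype=torch.uint8)
//   # t is tensor([1, 2, 3, 4], dtype=torch.uint8)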

static PyObject* THPVariable_asarray(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)",
      },
      /*traceable=*/false);

  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    auto obj = r.pyobject(0);
    auto dtype = r.scalartypeOptional(1);
    auto device = r.deviceOptional(2);
    auto copy = r.toBoolOptional(3);
    auto requires_grad = r.toBool(4);
    return wrap(torch::utils::asarray(obj, dtype, device, copy, requires_grad));
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
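
// Illustrative copy semantics (Python, not executed here): copy=None lets
// asarray avoid a copy when it can; copy=True always copies:
//
//   a = np.array([1., 2.])
//   t = torch.asarray(a)              # shares memory where possible
//   t2 = torch.asarray(a, copy=True)  # always a fresh copy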

static PyObject* THPVariable_numel(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs);

static PyObject* THPVariable__to_functional_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_to_functional_tensor(Tensor t, *, bool mirror_autograd_meta=False)"},
      /*traceable=*/true);

  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto self_ = r.tensor(0);
  auto mirror_autograd_meta = r.toBool(1);
  auto wrapped = at::functionalization::impl::to_functional_tensor(self_);
  if (mirror_autograd_meta) {
    // Here, we unsafely set the grad function on the wrapper to be the same
    // as the inner's. We expect this grad_fn to NEVER be used. It's needed
    // so that .is_leaf metadata is accurate on the wrapper.
    auto inner_autograd_meta = impl::get_autograd_meta(self_);
    if (inner_autograd_meta) {
      wrapped.set_requires_grad(self_.requires_grad());
      if (wrapped.requires_grad()) {
        auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
            new torch::autograd::Error(
                "Cannot backprop through mirrored meta, file a bug in PyTorch"),
            torch::autograd::deleteNode);
        torch::autograd::set_history(wrapped, new_grad_fn);
      }
    }
  }
  return wrap(std::move(wrapped));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__from_functional_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_from_functional_tensor(Tensor t)"}, /*traceable=*/true);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto self_ = r.tensor(0);
  auto unwrapped = at::functionalization::impl::from_functional_tensor(self_);
  return wrap(std::move(unwrapped));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__freeze_functional_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_freeze_functional_tensor(Tensor t)"}, /*traceable=*/true);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto self_ = r.tensor(0);
  at::functionalization::impl::freeze_functional_tensor(self_);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__is_functional_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_is_functional_tensor(Tensor t)"}, /*traceable=*/true);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto self_ = r.tensor(0);
  if (at::functionalization::impl::isFunctionalTensor(self_)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__sync(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"_sync(Tensor t)"}, /*traceable=*/true);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto self_ = r.tensor(0);
  TORCH_INTERNAL_ASSERT(
      at::functionalization::impl::isFunctionalTensor(self_));
  at::functionalization::impl::sync(self_);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__enable_functionalization(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_enable_functionalization(*, bool reapply_views=False)"},
      /*traceable=*/true);
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  const auto reapply_views = r.toBool(0);

  if (c10::impl::tls_is_dispatch_key_included(
          at::DispatchKey::Functionalize)) {
    TORCH_INTERNAL_ASSERT(
        false,
        "multiple layers of mode-style functionalization nesting is not"
        " currently supported, outside of the functionalize() transform");
  }
  c10::impl::tls_set_dispatch_key_included(
      at::DispatchKey::Functionalize, true);
  if (reapply_views) {
    at::functionalization::impl::setFunctionalizationReapplyViewsTLS(true);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__disable_functionalization(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  c10::impl::tls_set_dispatch_key_included(
      at::DispatchKey::Functionalize, false);
  at::functionalization::impl::setFunctionalizationReapplyViewsTLS(false);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
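
// Illustrative pairing (Python, not executed here; assumes these bindings
// are re-exported under the torch namespace): the two toggles bracket a
// region in which ops run under the Functionalize dispatch key:
//
//   torch._enable_functionalization(reapply_views=True)
//   try:
//       out = f(torch._to_functional_tensor(x))
//   finally:
//       torch._disable_functionalization()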

// XXX: ops that are bound here are not exposed to the C++ api nor the JIT.
// Any new ops added here should be accompanied by a comment explaining why
// they are not being registered through native_functions.yaml, and be
// tagged cpp / JIT.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef torch_functions_manual[] = {
    {"asarray",
     castPyCFunctionWithKeywords(THPVariable_asarray),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"as_tensor",
     castPyCFunctionWithKeywords(THPVariable_as_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, nullptr},
    {"frombuffer",
     castPyCFunctionWithKeywords(THPVariable_frombuffer),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"_is_functional_tensor",
     castPyCFunctionWithKeywords(THPVariable__is_functional_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"_to_functional_tensor",
     castPyCFunctionWithKeywords(THPVariable__to_functional_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"_from_functional_tensor",
     castPyCFunctionWithKeywords(THPVariable__from_functional_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"_freeze_functional_tensor",
     castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"_sync",
     castPyCFunctionWithKeywords(THPVariable__sync),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"_enable_functionalization",
     castPyCFunctionWithKeywords(THPVariable__enable_functionalization),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"_disable_functionalization",
     castPyCFunctionWithKeywords(THPVariable__disable_functionalization),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"nonzero",
     castPyCFunctionWithKeywords(THPVariable_nonzero),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"range",
     castPyCFunctionWithKeywords(THPVariable_range),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_coo_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_compressed_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_compressed_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_csr_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_csc_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_csc_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_bsr_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_bsr_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_bsc_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_bsc_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"tensor",
     castPyCFunctionWithKeywords(THPVariable_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"get_device",
     castPyCFunctionWithKeywords(THPVariable_get_device),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"numel",
     castPyCFunctionWithKeywords(THPVariable_numel),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
};

static PyObject* THPVariable_nonzero(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)",
  });
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(
        r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  const auto as_tuple = r.toBool(1);
  const auto has_out = !r.isNone(2);

  if (as_tuple) {
    TORCH_CHECK(
        !has_out,
        "nonzero does not support the out kwarg when as_tuple is True");
    return wrap(dispatch_nonzero_numpy(r.tensor(0)));
  }

  if (has_out) {
    return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2)));
  }

  return wrap(dispatch_nonzero(r.tensor(0)));

  END_HANDLE_TH_ERRORS
}
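
// Illustrative usage (Python, not executed here): as_tuple=True returns one
// index tensor per dimension, matching np.nonzero:
//
//   t = torch.tensor([[1, 0], [0, 2]])
//   torch.nonzero(t)                 # tensor([[0, 0], [1, 1]])
//   torch.nonzero(t, as_tuple=True)  # (tensor([0, 1]), tensor([0, 1]))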

static PyObject* THPVariable_numel(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "numel(Tensor input)",
      },
      /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(
        r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  if (r.idx == 0) {
    return py::cast(r.tensor(0).sym_numel()).release().ptr();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Sharded function definitions
void gatherTorchFunctions_0(std::vector<PyMethodDef>& torch_functions);
void gatherTorchFunctions_1(std::vector<PyMethodDef>& torch_functions);
void gatherTorchFunctions_2(std::vector<PyMethodDef>& torch_functions);

void gatherTorchFunctions(std::vector<PyMethodDef>& torch_functions) {
  constexpr size_t num_functions =
      sizeof(torch_functions_manual) / sizeof(torch_functions_manual[0]);
  torch_functions.assign(
      torch_functions_manual, torch_functions_manual + num_functions);
  // NOTE: Must be synced with num_shards in
  // tools/autograd/gen_python_functions.py
  gatherTorchFunctions_0(torch_functions);
  gatherTorchFunctions_1(torch_functions);
  gatherTorchFunctions_2(torch_functions);

  static std::array<std::pair<const char*, const char*>, 4> aliases{
      {// Canonical function, alias name
       {"sspaddmm", "saddmm"},
       {"mm", "spmm"},
       {"mm", "dsmm"},
       {"hspmm", "hsmm"}}};
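
  // Illustrative effect (derivable from the loop below): each alias reuses
  // the canonical binding's PyMethodDef under a new name, so e.g.
  // torch.dsmm(a, b) behaves exactly like torch.mm(a, b).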

  for (const auto& alias : aliases) {
    auto it = std::find_if(
        torch_functions.begin(),
        torch_functions.end(),
        [&](const PyMethodDef& def) {
          return strcmp(def.ml_name, alias.first) == 0;
        });
    TORCH_INTERNAL_ASSERT(
        it != torch_functions.end(),
        "Failed to create function alias from ",
        alias.first,
        " to ",
        alias.second);
    PyMethodDef alias_def = *it;
    alias_def.ml_name = alias.second;

    torch_functions.push_back(alias_def);
  }

  torch_functions.push_back({nullptr});
  torch_functions.shrink_to_fit();
}

static PyTypeObject THPVariableFunctions = {
    PyVarObject_HEAD_INIT(
        nullptr,
        0) "torch._C._VariableFunctionsClass", /* tp_name */
    0, /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr /* tp_new */
};

void initTorchFunctions(PyObject* module) {
  static std::vector<PyMethodDef> torch_functions;
  gatherTorchFunctions(torch_functions);
  THPVariableFunctions.tp_methods = torch_functions.data();

  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }
  Py_INCREF(&THPVariableFunctions);

  // PyModule_AddObject steals a reference, so INCREF first
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(
          module,
          "_VariableFunctionsClass",
          reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) {
    throw python_error();
  }
  // PyType_GenericNew returns a new reference
  THPVariableFunctionsModule =
      PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
  // PyModule_AddObject steals a reference
  if (PyModule_AddObject(
          module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
    throw python_error();
  }
}

} // namespace autograd
} // namespace torch