1 | #include <torch/csrc/Dtype.h> |
2 | #include <torch/csrc/DynamicTypes.h> |
3 | #include <torch/csrc/Exceptions.h> |
4 | #include <torch/csrc/autograd/function.h> |
5 | #include <torch/csrc/autograd/functions/basic_ops.h> |
6 | #include <torch/csrc/autograd/functions/utils.h> |
7 | #include <torch/csrc/autograd/generated/variable_factories.h> |
8 | #include <torch/csrc/autograd/python_torch_functions.h> |
9 | #include <torch/csrc/autograd/python_variable.h> |
10 | #include <torch/csrc/autograd/utils/wrap_outputs.h> |
11 | #include <torch/csrc/jit/frontend/tracer.h> |
12 | #include <torch/csrc/utils/cuda_lazy_init.h> |
13 | #include <torch/csrc/utils/out_types.h> |
14 | #include <torch/csrc/utils/pybind.h> |
15 | #include <torch/csrc/utils/pycfunction_helpers.h> |
16 | #include <torch/csrc/utils/python_arg_parser.h> |
17 | #include <torch/csrc/utils/structseq.h> |
18 | #include <torch/csrc/utils/tensor_layouts.h> |
19 | #include <torch/csrc/utils/tensor_new.h> |
20 | #include <torch/csrc/utils/tensor_numpy.h> |
21 | |
22 | #include <ATen/ATen.h> |
23 | #include <ATen/FunctionalTensorWrapper.h> |
24 | |
25 | #include <Python.h> |
26 | #include <fmt/format.h> |
27 | #include <pybind11/pybind11.h> |
28 | #include <utility> |
29 | #include <vector> |
30 | |
31 | using at::DeviceGuard; |
32 | using at::DimnameList; |
33 | using at::IntArrayRef; |
34 | using at::OptionalDeviceGuard; |
35 | using at::Scalar; |
36 | using at::Tensor; |
37 | using at::TensorList; |
38 | using at::TensorOptions; |
39 | |
40 | using torch::utils::check_out_type_matches; |
41 | using namespace torch::autograd::utils; |
42 | |
43 | namespace torch { |
44 | namespace autograd { |
45 | |
46 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
47 | PyObject* THPVariableFunctionsModule = nullptr; |
48 | |
49 | inline Tensor dispatch_range( |
50 | const Scalar& start, |
51 | const Scalar& end, |
52 | const Scalar& step, |
53 | Tensor result) { |
54 | pybind11::gil_scoped_release no_gil; |
55 | OptionalDeviceGuard device_guard(device_of(result)); |
56 | return at::range_out(result, start, end, step); |
57 | } |
58 | |
59 | inline Tensor dispatch_range( |
60 | const Scalar& start, |
61 | const Scalar& end, |
62 | const Scalar& step, |
63 | const TensorOptions& options) { |
64 | torch::utils::maybe_initialize_cuda(options); |
65 | pybind11::gil_scoped_release no_gil; |
66 | DeviceGuard device_guard(options.device()); |
67 | return torch::range(start, end, step, options); |
68 | } |
69 | |
70 | static PyObject* THPVariable_range( |
71 | PyObject* self, |
72 | PyObject* args, |
73 | PyObject* kwargs) { |
74 | HANDLE_TH_ERRORS |
75 | static PythonArgParser parser({ |
76 | "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)" , |
77 | }); |
78 | |
79 | ParsedArgs<8> parsed_args; |
80 | auto r = parser.parse(args, kwargs, parsed_args); |
81 | |
82 | if (r.idx == 0) { |
83 | auto ret = PyErr_WarnEx( |
84 | PyExc_UserWarning, |
85 | "torch.range is deprecated and will be removed in a future release " |
86 | "because its behavior is inconsistent with Python's range builtin. " |
87 | "Instead, use torch.arange, which produces values in [start, end)." , |
88 | 1); |
89 | if (ret != 0) |
90 | throw python_error(); |
91 | if (r.isNone(3)) { |
92 | const auto options = TensorOptions() |
93 | .dtype(r.scalartype(4)) |
94 | .device(r.device(6)) |
95 | .layout(r.layout(5)) |
96 | .requires_grad(r.toBool(7)); |
97 | return wrap( |
98 | dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options)); |
99 | } else { |
100 | check_out_type_matches( |
101 | r.tensor(3), |
102 | r.scalartype(4), |
103 | r.isNone(4), |
104 | r.layout(5), |
105 | r.device(6), |
106 | r.isNone(6)); |
107 | return wrap( |
108 | dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)) |
109 | .set_requires_grad(r.toBool(7))); |
110 | } |
111 | } |
112 | Py_RETURN_NONE; |
113 | END_HANDLE_TH_ERRORS |
114 | } |
115 | |
116 | // implemented on python object to allow torch.as_tensor to be constructed with |
117 | // arbitrarily nested python objects - list, tuple, np array, scalar, etc. |
118 | static PyObject* THPVariable_as_tensor( |
119 | PyObject* self, |
120 | PyObject* args, |
121 | PyObject* kwargs) { |
122 | HANDLE_TH_ERRORS |
123 | static PythonArgParser parser({ |
124 | "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)" , |
125 | }); |
126 | |
127 | ParsedArgs<3> parsed_args; |
128 | auto r = parser.parse(args, kwargs, parsed_args); |
129 | if (r.has_torch_function()) { |
130 | return handle_torch_function( |
131 | r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch" ); |
132 | } |
133 | jit::tracer::warn("torch.as_tensor" , jit::tracer::WARN_CONSTRUCTOR); |
134 | return THPVariable_Wrap(torch::utils::as_tensor( |
135 | torch::tensors::get_default_dispatch_key(), |
136 | torch::tensors::get_default_scalar_type(), |
137 | r)); |
138 | END_HANDLE_TH_ERRORS |
139 | } |
140 | |
141 | // implemented on python object here because PyObject currently not natively |
142 | // declarable See: ATen/native/README.md for more context |
143 | static PyObject* THPVariable_from_numpy(PyObject* module, PyObject* arg) { |
144 | HANDLE_TH_ERRORS |
145 | jit::tracer::warn("torch.from_numpy" , jit::tracer::WARN_CONSTRUCTOR); |
146 | return THPVariable_Wrap(torch::utils::tensor_from_numpy(arg)); |
147 | END_HANDLE_TH_ERRORS |
148 | } |
149 | |
150 | static Tensor dispatch_nonzero(const Tensor& self) { |
151 | pybind11::gil_scoped_release no_gil; |
152 | OptionalDeviceGuard device_guard(device_of(self)); |
153 | return self.nonzero(); |
154 | } |
155 | |
156 | static Tensor dispatch_nonzero(const Tensor& self, Tensor out) { |
157 | pybind11::gil_scoped_release no_gil; |
158 | OptionalDeviceGuard device_guard(device_of(self)); |
159 | return at::nonzero_out(out, self); |
160 | } |
161 | |
162 | static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor& self) { |
163 | pybind11::gil_scoped_release no_gil; |
164 | OptionalDeviceGuard device_guard(device_of(self)); |
165 | return self.nonzero_numpy(); |
166 | } |
167 | |
168 | static PyObject* THPVariable_nonzero( |
169 | PyObject* self, |
170 | PyObject* args, |
171 | PyObject* kwargs); |
172 | |
173 | #define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES) \ |
174 | static PyObject* THPVariable_##NAME( \ |
175 | PyObject* self, PyObject* args, PyObject* kwargs) { \ |
176 | HANDLE_TH_ERRORS \ |
177 | static PythonArgParser parser SIGNATURES; \ |
178 | ParsedArgs<NARGS> parsed_args; \ |
179 | auto r = parser.parse(args, kwargs, parsed_args); \ |
180 | if (r.has_torch_function()) { \ |
181 | return handle_torch_function( \ |
182 | r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \ |
183 | } \ |
184 | jit::tracer::warn("torch." #NAME, jit::tracer::WARN_CONSTRUCTOR); \ |
185 | return THPVariable_Wrap(torch::utils::NAME##_ctor( \ |
186 | torch::tensors::get_default_dispatch_key(), \ |
187 | torch::tensors::get_default_scalar_type(), \ |
188 | r)); \ |
189 | END_HANDLE_TH_ERRORS \ |
190 | } |
191 | |
192 | THPVARIABLE_SPARSE_COMPRESSED_CTOR( |
193 | sparse_compressed_tensor, |
194 | 10, |
195 | ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" , |
196 | "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" })) |
197 | THPVARIABLE_SPARSE_COMPRESSED_CTOR( |
198 | sparse_csr_tensor, |
199 | 10, |
200 | ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" , |
201 | "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" })) |
202 | THPVARIABLE_SPARSE_COMPRESSED_CTOR( |
203 | sparse_csc_tensor, |
204 | 10, |
205 | ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" , |
206 | "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" })) |
207 | THPVARIABLE_SPARSE_COMPRESSED_CTOR( |
208 | sparse_bsr_tensor, |
209 | 10, |
210 | ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" , |
211 | "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" })) |
212 | THPVARIABLE_SPARSE_COMPRESSED_CTOR( |
213 | sparse_bsc_tensor, |
214 | 10, |
215 | ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" , |
216 | "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)" })) |
217 | |
218 | static PyObject* THPVariable_sparse_coo_tensor( |
219 | PyObject* self, |
220 | PyObject* args, |
221 | PyObject* kwargs) { |
222 | HANDLE_TH_ERRORS |
223 | static PythonArgParser parser({ |
224 | "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)" , |
225 | "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)" , |
226 | "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)" , |
227 | }); |
228 | |
229 | ParsedArgs<7> parsed_args; |
230 | auto r = parser.parse(args, kwargs, parsed_args); |
231 | if (r.has_torch_function()) { |
232 | return handle_torch_function( |
233 | r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch" ); |
234 | } |
235 | jit::tracer::warn("torch.sparse_coo_tensor" , jit::tracer::WARN_CONSTRUCTOR); |
236 | return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor( |
237 | torch::tensors::get_default_dispatch_key(), |
238 | torch::tensors::get_default_scalar_type(), |
239 | r)); |
240 | END_HANDLE_TH_ERRORS |
241 | } |
242 | |
243 | // implemented on python object to allow torch.tensor to be constructed with |
244 | // arbitrarily nested python objects - list, tuple, np array, scalar, etc. |
245 | static PyObject* THPVariable_tensor( |
246 | PyObject* self, |
247 | PyObject* args, |
248 | PyObject* kwargs) { |
249 | HANDLE_TH_ERRORS |
250 | static PythonArgParser parser({ |
251 | "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)" , |
252 | }); |
253 | |
254 | constexpr int ctor_num_args = 6; |
255 | ParsedArgs<ctor_num_args> parsed_args; |
256 | auto r = parser.parse(args, kwargs, parsed_args); |
257 | if (r.has_torch_function()) { |
258 | return handle_torch_function( |
259 | r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch" ); |
260 | } |
261 | jit::tracer::warn("torch.tensor" , jit::tracer::WARN_CONSTRUCTOR); |
262 | return THPVariable_Wrap(torch::utils::tensor_ctor( |
263 | torch::tensors::get_default_dispatch_key(), |
264 | torch::tensors::get_default_scalar_type(), |
265 | r)); |
266 | END_HANDLE_TH_ERRORS |
267 | } |
268 | |
269 | static PyObject* THPVariable_get_device( |
270 | PyObject* self_, |
271 | PyObject* args, |
272 | PyObject* kwargs) { |
273 | HANDLE_TH_ERRORS |
274 | static PythonArgParser parser( |
275 | { |
276 | "get_device(Tensor input)" , |
277 | }, |
278 | /*traceable=*/false); |
279 | |
280 | ParsedArgs<1> parsed_args; |
281 | auto r = parser.parse(args, kwargs, parsed_args); |
282 | |
283 | if (r.idx == 0) { |
284 | return wrap(r.tensor(0).get_device()); |
285 | } |
286 | Py_RETURN_NONE; |
287 | END_HANDLE_TH_ERRORS |
288 | } |
289 | |
290 | static PyObject* THPVariable_frombuffer( |
291 | PyObject* self_, |
292 | PyObject* args, |
293 | PyObject* kwargs) { |
294 | HANDLE_TH_ERRORS |
295 | static PythonArgParser parser( |
296 | { |
297 | "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)" , |
298 | }, |
299 | /*traceable=*/false); |
300 | |
301 | ParsedArgs<5> parsed_args; |
302 | auto r = parser.parse(args, kwargs, parsed_args); |
303 | |
304 | if (r.idx == 0) { |
305 | auto buffer = r.pyobject(0); |
306 | auto dtype = r.scalartype(1); |
307 | auto count = r.toInt64(2); |
308 | auto offset = r.toInt64(3); |
309 | auto requires_grad = r.toBool(4); |
310 | |
311 | TORCH_CHECK_VALUE( |
312 | PyObject_CheckBuffer(buffer) != 0, |
313 | "object does not implement Python buffer protocol." ); |
314 | return wrap(torch::utils::tensor_frombuffer( |
315 | buffer, dtype, count, offset, requires_grad)); |
316 | } |
317 | |
318 | Py_RETURN_NONE; |
319 | END_HANDLE_TH_ERRORS |
320 | } |
321 | |
322 | static PyObject* THPVariable_asarray( |
323 | PyObject* self_, |
324 | PyObject* args, |
325 | PyObject* kwargs) { |
326 | HANDLE_TH_ERRORS |
327 | static PythonArgParser parser( |
328 | { |
329 | "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)" , |
330 | }, |
331 | /*traceable=*/false); |
332 | |
333 | ParsedArgs<5> parsed_args; |
334 | auto r = parser.parse(args, kwargs, parsed_args); |
335 | |
336 | if (r.idx == 0) { |
337 | auto obj = r.pyobject(0); |
338 | auto dtype = r.scalartypeOptional(1); |
339 | auto device = r.deviceOptional(2); |
340 | auto copy = r.toBoolOptional(3); |
341 | auto requires_grad = r.toBool(4); |
342 | return wrap(torch::utils::asarray(obj, dtype, device, copy, requires_grad)); |
343 | } |
344 | |
345 | Py_RETURN_NONE; |
346 | END_HANDLE_TH_ERRORS |
347 | } |
348 | |
349 | static PyObject* THPVariable_numel( |
350 | PyObject* self_, |
351 | PyObject* args, |
352 | PyObject* kwargs); |
353 | |
354 | static PyObject* THPVariable__to_functional_tensor( |
355 | PyObject* self, |
356 | PyObject* args, |
357 | PyObject* kwargs) { |
358 | HANDLE_TH_ERRORS |
359 | static PythonArgParser parser( |
360 | {"_to_functional_tensor(Tensor t, *, bool mirror_autograd_meta=False)" }, |
361 | /*traceable=*/true); |
362 | |
363 | ParsedArgs<2> parsed_args; |
364 | auto r = parser.parse(args, kwargs, parsed_args); |
365 | auto self_ = r.tensor(0); |
366 | auto mirror_autograd_meta = r.toBool(1); |
367 | auto wrapped = at::functionalization::impl::to_functional_tensor(self_); |
368 | if (mirror_autograd_meta) { |
369 | // Here, we unsafely set the grad function on the wrapper to be the same as |
370 | // the inner. We expect this grad_fn to NEVER be used. It's needed so that |
371 | // .is_leaf metadata is accurate on the wrapper |
372 | auto inner_autograd_meta = impl::get_autograd_meta(self_); |
373 | if (inner_autograd_meta) { |
374 | wrapped.set_requires_grad(self_.requires_grad()); |
375 | if (wrapped.requires_grad()) { |
376 | auto new_grad_fn = std::shared_ptr<torch::autograd::Error>( |
377 | new torch::autograd::Error( |
378 | "Cannot backprop through mirrored meta, file a bug in PyTorch" ), |
379 | torch::autograd::deleteNode); |
380 | torch::autograd::set_history(wrapped, new_grad_fn); |
381 | } |
382 | } |
383 | } |
384 | return wrap(std::move(wrapped)); |
385 | END_HANDLE_TH_ERRORS |
386 | } |
387 | |
388 | static PyObject* THPVariable__from_functional_tensor( |
389 | PyObject* self, |
390 | PyObject* args, |
391 | PyObject* kwargs) { |
392 | HANDLE_TH_ERRORS |
393 | static PythonArgParser parser( |
394 | {"_from_functional_tensor(Tensor t)" }, /*traceable=*/true); |
395 | |
396 | ParsedArgs<1> parsed_args; |
397 | auto r = parser.parse(args, kwargs, parsed_args); |
398 | auto self_ = r.tensor(0); |
399 | auto unwrapped = at::functionalization::impl::from_functional_tensor(self_); |
400 | return wrap(std::move(unwrapped)); |
401 | END_HANDLE_TH_ERRORS |
402 | } |
403 | |
404 | static PyObject* THPVariable__freeze_functional_tensor( |
405 | PyObject* self, |
406 | PyObject* args, |
407 | PyObject* kwargs) { |
408 | HANDLE_TH_ERRORS |
409 | static PythonArgParser parser( |
410 | {"_freeze_functional_tensor(Tensor t)" }, /*traceable=*/true); |
411 | |
412 | ParsedArgs<1> parsed_args; |
413 | auto r = parser.parse(args, kwargs, parsed_args); |
414 | auto self_ = r.tensor(0); |
415 | at::functionalization::impl::freeze_functional_tensor(self_); |
416 | Py_RETURN_NONE; |
417 | END_HANDLE_TH_ERRORS |
418 | } |
419 | |
420 | static PyObject* THPVariable__is_functional_tensor( |
421 | PyObject* self, |
422 | PyObject* args, |
423 | PyObject* kwargs) { |
424 | HANDLE_TH_ERRORS |
425 | static PythonArgParser parser( |
426 | {"_is_functional_tensor(Tensor t)" }, /*traceable=*/true); |
427 | |
428 | ParsedArgs<1> parsed_args; |
429 | auto r = parser.parse(args, kwargs, parsed_args); |
430 | auto self_ = r.tensor(0); |
431 | if (at::functionalization::impl::isFunctionalTensor(self_)) { |
432 | Py_RETURN_TRUE; |
433 | } else { |
434 | Py_RETURN_FALSE; |
435 | } |
436 | END_HANDLE_TH_ERRORS |
437 | } |
438 | |
439 | static PyObject* THPVariable__sync( |
440 | PyObject* self, |
441 | PyObject* args, |
442 | PyObject* kwargs) { |
443 | HANDLE_TH_ERRORS |
444 | static PythonArgParser parser({"_sync(Tensor t)" }, /*traceable=*/true); |
445 | |
446 | ParsedArgs<1> parsed_args; |
447 | auto r = parser.parse(args, kwargs, parsed_args); |
448 | auto self_ = r.tensor(0); |
449 | TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_)); |
450 | at::functionalization::impl::sync(self_); |
451 | Py_RETURN_NONE; |
452 | END_HANDLE_TH_ERRORS |
453 | } |
454 | |
455 | static PyObject* THPVariable__enable_functionalization( |
456 | PyObject* self, |
457 | PyObject* args, |
458 | PyObject* kwargs) { |
459 | HANDLE_TH_ERRORS |
460 | static PythonArgParser parser( |
461 | {"_enable_functionalization(*, bool reapply_views=False)" }, |
462 | /*traceable=*/true); |
463 | ParsedArgs<1> parsed_args; |
464 | auto r = parser.parse(args, kwargs, parsed_args); |
465 | const auto reapply_views = r.toBool(0); |
466 | |
467 | if (c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Functionalize)) { |
468 | TORCH_INTERNAL_ASSERT( |
469 | false, |
470 | "multiple layers of mode-style functionalization nesting is not" |
471 | " currently supported, outside of the functionalize() transform" ); |
472 | } |
473 | c10::impl::tls_set_dispatch_key_included( |
474 | at::DispatchKey::Functionalize, true); |
475 | if (reapply_views) { |
476 | at::functionalization::impl::setFunctionalizationReapplyViewsTLS(true); |
477 | } |
478 | Py_RETURN_NONE; |
479 | END_HANDLE_TH_ERRORS |
480 | } |
481 | |
482 | static PyObject* THPVariable__disable_functionalization( |
483 | PyObject* self, |
484 | PyObject* args, |
485 | PyObject* kwargs) { |
486 | HANDLE_TH_ERRORS |
487 | c10::impl::tls_set_dispatch_key_included( |
488 | at::DispatchKey::Functionalize, false); |
489 | at::functionalization::impl::setFunctionalizationReapplyViewsTLS(false); |
490 | Py_RETURN_NONE; |
491 | END_HANDLE_TH_ERRORS |
492 | } |
493 | |
494 | // XXX: ops that are bound here are not exposed to the C++ api nor the JIT. |
495 | // Any new ops added here should be accompanied with a comment why they are not |
496 | // being registered through native_functions.yaml, and be tagged cpp / JIT |
497 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
498 | static PyMethodDef torch_functions_manual[] = { |
499 | {"asarray" , |
500 | castPyCFunctionWithKeywords(THPVariable_asarray), |
501 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
502 | nullptr}, |
503 | {"as_tensor" , |
504 | castPyCFunctionWithKeywords(THPVariable_as_tensor), |
505 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
506 | nullptr}, |
507 | {"from_numpy" , THPVariable_from_numpy, METH_STATIC | METH_O, nullptr}, |
508 | {"frombuffer" , |
509 | castPyCFunctionWithKeywords(THPVariable_frombuffer), |
510 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
511 | nullptr}, |
512 | {"_is_functional_tensor" , |
513 | castPyCFunctionWithKeywords(THPVariable__is_functional_tensor), |
514 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
515 | nullptr}, |
516 | {"_to_functional_tensor" , |
517 | castPyCFunctionWithKeywords(THPVariable__to_functional_tensor), |
518 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
519 | nullptr}, |
520 | {"_from_functional_tensor" , |
521 | castPyCFunctionWithKeywords(THPVariable__from_functional_tensor), |
522 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
523 | nullptr}, |
524 | {"_freeze_functional_tensor" , |
525 | castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor), |
526 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
527 | nullptr}, |
528 | {"_sync" , |
529 | castPyCFunctionWithKeywords(THPVariable__sync), |
530 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
531 | nullptr}, |
532 | {"_enable_functionalization" , |
533 | castPyCFunctionWithKeywords(THPVariable__enable_functionalization), |
534 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
535 | nullptr}, |
536 | {"_disable_functionalization" , |
537 | castPyCFunctionWithKeywords(THPVariable__disable_functionalization), |
538 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
539 | nullptr}, |
540 | {"nonzero" , |
541 | castPyCFunctionWithKeywords(THPVariable_nonzero), |
542 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
543 | nullptr}, |
544 | {"range" , |
545 | castPyCFunctionWithKeywords(THPVariable_range), |
546 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
547 | nullptr}, |
548 | {"sparse_coo_tensor" , |
549 | castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), |
550 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
551 | nullptr}, |
552 | {"sparse_compressed_tensor" , |
553 | castPyCFunctionWithKeywords(THPVariable_sparse_compressed_tensor), |
554 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
555 | nullptr}, |
556 | {"sparse_csr_tensor" , |
557 | castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), |
558 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
559 | nullptr}, |
560 | {"sparse_csc_tensor" , |
561 | castPyCFunctionWithKeywords(THPVariable_sparse_csc_tensor), |
562 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
563 | nullptr}, |
564 | {"sparse_bsr_tensor" , |
565 | castPyCFunctionWithKeywords(THPVariable_sparse_bsr_tensor), |
566 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
567 | nullptr}, |
568 | {"sparse_bsc_tensor" , |
569 | castPyCFunctionWithKeywords(THPVariable_sparse_bsc_tensor), |
570 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
571 | nullptr}, |
572 | {"tensor" , |
573 | castPyCFunctionWithKeywords(THPVariable_tensor), |
574 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
575 | nullptr}, |
576 | {"get_device" , |
577 | castPyCFunctionWithKeywords(THPVariable_get_device), |
578 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
579 | nullptr}, |
580 | {"numel" , |
581 | castPyCFunctionWithKeywords(THPVariable_numel), |
582 | METH_VARARGS | METH_KEYWORDS | METH_STATIC, |
583 | nullptr}, |
584 | }; |
585 | |
586 | static PyObject* THPVariable_nonzero( |
587 | PyObject* self, |
588 | PyObject* args, |
589 | PyObject* kwargs) { |
590 | HANDLE_TH_ERRORS |
591 | static PythonArgParser parser({ |
592 | "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)" , |
593 | }); |
594 | ParsedArgs<3> parsed_args; |
595 | auto r = parser.parse(args, kwargs, parsed_args); |
596 | |
597 | if (r.has_torch_function()) { |
598 | return handle_torch_function( |
599 | r, args, kwargs, THPVariableFunctionsModule, "torch" ); |
600 | } |
601 | |
602 | const auto as_tuple = r.toBool(1); |
603 | const auto has_out = !r.isNone(2); |
604 | |
605 | if (as_tuple) { |
606 | TORCH_CHECK( |
607 | !has_out, |
608 | "nonzero does not support the out kwarg when as_tuple is True" ); |
609 | return wrap(dispatch_nonzero_numpy(r.tensor(0))); |
610 | } |
611 | |
612 | if (has_out) { |
613 | return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2))); |
614 | } |
615 | |
616 | return wrap(dispatch_nonzero(r.tensor(0))); |
617 | |
618 | END_HANDLE_TH_ERRORS |
619 | } |
620 | |
621 | static PyObject* THPVariable_numel( |
622 | PyObject* self_, |
623 | PyObject* args, |
624 | PyObject* kwargs) { |
625 | HANDLE_TH_ERRORS |
626 | static PythonArgParser parser( |
627 | { |
628 | "numel(Tensor input)" , |
629 | }, |
630 | /*traceable=*/false); |
631 | |
632 | ParsedArgs<1> parsed_args; |
633 | auto r = parser.parse(args, kwargs, parsed_args); |
634 | |
635 | if (r.has_torch_function()) { |
636 | return handle_torch_function( |
637 | r, args, kwargs, THPVariableFunctionsModule, "torch" ); |
638 | } |
639 | |
640 | if (r.idx == 0) { |
641 | return py::cast(r.tensor(0).sym_numel()).release().ptr(); |
642 | } |
643 | Py_RETURN_NONE; |
644 | END_HANDLE_TH_ERRORS |
645 | } |
646 | |
647 | // Sharded function definitions |
648 | void gatherTorchFunctions_0(std::vector<PyMethodDef>& torch_functions); |
649 | void gatherTorchFunctions_1(std::vector<PyMethodDef>& torch_functions); |
650 | void gatherTorchFunctions_2(std::vector<PyMethodDef>& torch_functions); |
651 | |
652 | void gatherTorchFunctions(std::vector<PyMethodDef>& torch_functions) { |
653 | constexpr size_t num_functions = |
654 | sizeof(torch_functions_manual) / sizeof(torch_functions_manual[0]); |
655 | torch_functions.assign( |
656 | torch_functions_manual, torch_functions_manual + num_functions); |
657 | // NOTE: Must be synced with num_shards in |
658 | // tools/autograd/gen_python_functions.py |
659 | gatherTorchFunctions_0(torch_functions); |
660 | gatherTorchFunctions_1(torch_functions); |
661 | gatherTorchFunctions_2(torch_functions); |
662 | |
663 | static std::array<std::pair<const char*, const char*>, 4> aliases{ |
664 | {// Canonical function, alias name |
665 | {"sspaddmm" , "saddmm" }, |
666 | {"mm" , "spmm" }, |
667 | {"mm" , "dsmm" }, |
668 | {"hspmm" , "hsmm" }}}; |
669 | |
670 | for (const auto& alias : aliases) { |
671 | auto it = std::find_if( |
672 | torch_functions.begin(), |
673 | torch_functions.end(), |
674 | [&](const PyMethodDef& def) { |
675 | return strcmp(def.ml_name, alias.first) == 0; |
676 | }); |
677 | TORCH_INTERNAL_ASSERT( |
678 | it != torch_functions.end(), |
679 | "Failed to create function alias from " , |
680 | alias.first, |
681 | " to " , |
682 | alias.second); |
683 | PyMethodDef alias_def = *it; |
684 | alias_def.ml_name = alias.second; |
685 | |
686 | torch_functions.push_back(alias_def); |
687 | } |
688 | |
689 | torch_functions.push_back({nullptr}); |
690 | torch_functions.shrink_to_fit(); |
691 | } |
692 | |
693 | static PyTypeObject THPVariableFunctions = { |
694 | PyVarObject_HEAD_INIT( |
695 | nullptr, |
696 | 0) "torch._C._VariableFunctionsClass" , /* tp_name */ |
697 | 0, /* tp_basicsize */ |
698 | 0, /* tp_itemsize */ |
699 | nullptr, /* tp_dealloc */ |
700 | 0, /* tp_vectorcall_offset */ |
701 | nullptr, /* tp_getattr */ |
702 | nullptr, /* tp_setattr */ |
703 | nullptr, /* tp_reserved */ |
704 | nullptr, /* tp_repr */ |
705 | nullptr, /* tp_as_number */ |
706 | nullptr, /* tp_as_sequence */ |
707 | nullptr, /* tp_as_mapping */ |
708 | nullptr, /* tp_hash */ |
709 | nullptr, /* tp_call */ |
710 | nullptr, /* tp_str */ |
711 | nullptr, /* tp_getattro */ |
712 | nullptr, /* tp_setattro */ |
713 | nullptr, /* tp_as_buffer */ |
714 | Py_TPFLAGS_DEFAULT, /* tp_flags */ |
715 | nullptr, /* tp_doc */ |
716 | nullptr, /* tp_traverse */ |
717 | nullptr, /* tp_clear */ |
718 | nullptr, /* tp_richcompare */ |
719 | 0, /* tp_weaklistoffset */ |
720 | nullptr, /* tp_iter */ |
721 | nullptr, /* tp_iternext */ |
722 | nullptr, /* tp_methods */ |
723 | nullptr, /* tp_members */ |
724 | nullptr, /* tp_getset */ |
725 | nullptr, /* tp_base */ |
726 | nullptr, /* tp_dict */ |
727 | nullptr, /* tp_descr_get */ |
728 | nullptr, /* tp_descr_set */ |
729 | 0, /* tp_dictoffset */ |
730 | nullptr, /* tp_init */ |
731 | nullptr, /* tp_alloc */ |
732 | nullptr /* tp_new */ |
733 | }; |
734 | |
735 | void initTorchFunctions(PyObject* module) { |
736 | static std::vector<PyMethodDef> torch_functions; |
737 | gatherTorchFunctions(torch_functions); |
738 | THPVariableFunctions.tp_methods = torch_functions.data(); |
739 | |
740 | if (PyType_Ready(&THPVariableFunctions) < 0) { |
741 | throw python_error(); |
742 | } |
743 | Py_INCREF(&THPVariableFunctions); |
744 | |
745 | // Steals |
746 | Py_INCREF(&THPVariableFunctions); |
747 | if (PyModule_AddObject( |
748 | module, |
749 | "_VariableFunctionsClass" , |
750 | reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) { |
751 | throw python_error(); |
752 | } |
753 | // PyType_GenericNew returns a new reference |
754 | THPVariableFunctionsModule = |
755 | PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None); |
756 | // PyModule_AddObject steals a reference |
757 | if (PyModule_AddObject( |
758 | module, "_VariableFunctions" , THPVariableFunctionsModule) < 0) { |
759 | throw python_error(); |
760 | } |
761 | } |
762 | |
763 | } // namespace autograd |
764 | } // namespace torch |
765 | |