1 | #include <torch/csrc/utils/nested.h> |
2 | #include <torch/csrc/utils/pycfunction_helpers.h> |
3 | #include <torch/csrc/utils/python_arg_parser.h> |
4 | #include <torch/torch.h> |
5 | |
6 | namespace torch { |
7 | namespace autograd { |
8 | |
9 | static PyObject* THPVariable_nested_tensor( |
10 | PyObject* /*self*/, |
11 | PyObject* args, |
12 | PyObject* kwargs) { |
13 | HANDLE_TH_ERRORS |
14 | static PythonArgParser parser({ |
15 | "nested_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)" , |
16 | }); |
17 | |
18 | constexpr int ctor_num_args = 5; |
19 | ParsedArgs<ctor_num_args> parsed_args; |
20 | auto r = parser.parse(args, kwargs, parsed_args); |
21 | |
22 | jit::tracer::warn( |
23 | "torch.nested.nested_tensor" , jit::tracer::WARN_CONSTRUCTOR); |
24 | return THPVariable_Wrap(torch::utils::nested_tensor_ctor( |
25 | torch::tensors::get_default_dispatch_key(), |
26 | torch::tensors::get_default_scalar_type(), |
27 | r)); |
28 | END_HANDLE_TH_ERRORS |
29 | } |
30 | |
31 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
32 | static PyMethodDef nested_functions_manual[] = { |
33 | {"nested_tensor" , |
34 | castPyCFunctionWithKeywords(THPVariable_nested_tensor), |
35 | METH_VARARGS | METH_KEYWORDS, |
36 | nullptr}, |
37 | }; |
38 | |
39 | PyMethodDef* get_nested_functions_manual() { |
40 | return nested_functions_manual; |
41 | } |
42 | |
43 | } // namespace autograd |
44 | } // namespace torch |
45 | |