1 | #pragma once |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <ATen/core/ATen_fwd.h> |
5 | #include <torch/csrc/api/include/torch/detail/TensorDataContainer.h> |
6 | #include <algorithm> |
7 | |
8 | namespace torch { |
9 | namespace nested { |
10 | |
11 | /// Nested tensor |
12 | /// |
13 | /// See |
14 | /// https://pytorch.org/docs/master/nested.html#torch.nested.nested_tensor |
///
17 | // implemented on python object to allow torch.nested.nested_tensor to be |
18 | // constructed with arbitrarily nested python objects - for now, only arbitrary |
19 | // python lists and lists of Tensors |
20 | // See torch/csrc/autograd/python_nested_functions_manual.cpp for Python |
21 | // implementation |
22 | // See here for C++ implementation |
23 | inline at::Tensor nested_tensor( |
24 | at::TensorList nested_tensor_data, |
25 | const at::TensorOptions& options = {}) { |
26 | auto out = at::_nested_tensor_from_tensor_list( |
27 | nested_tensor_data, |
28 | c10::typeMetaToScalarType(options.dtype()), |
29 | c10::nullopt, |
30 | options.device(), |
31 | options.pinned_memory()); |
32 | if (options.has_requires_grad() && options.requires_grad()) { |
33 | out.requires_grad_(true); |
34 | } |
35 | return out; |
36 | } |
37 | |
38 | inline at::Tensor nested_tensor( |
39 | at::ArrayRef<detail::TensorDataContainer> nested_tensor_data, |
40 | const at::TensorOptions& options = {}) { |
41 | for (const auto& tdc : nested_tensor_data) { |
42 | TORCH_CHECK( |
43 | tdc.is_init_list(), |
44 | "nested_tensor() not implemented for these parameters" ); |
45 | } |
46 | // Construct a TensorList using nested_tensor_data |
47 | std::vector<at::Tensor> tensor_list(nested_tensor_data.size()); |
48 | std::transform( |
49 | nested_tensor_data.begin(), |
50 | nested_tensor_data.end(), |
51 | tensor_list.begin(), |
52 | [&](const detail::TensorDataContainer& tdc) { |
53 | return tdc.convert_to_tensor(options); |
54 | }); |
55 | auto out = at::_nested_tensor_from_tensor_list( |
56 | tensor_list, |
57 | c10::typeMetaToScalarType(options.dtype()), |
58 | c10::nullopt, |
59 | options.device(), |
60 | options.pinned_memory()); |
61 | if (options.has_requires_grad() && options.requires_grad()) { |
62 | out.requires_grad_(true); |
63 | } |
64 | return out; |
65 | } |
66 | |
67 | /// As Nested Tensor |
68 | /// |
69 | /// See |
70 | /// https://pytorch.org/docs/master/nested.html#torch.nested.as_nested_tensor |
///
73 | inline at::Tensor as_nested_tensor( |
74 | at::TensorList list, |
75 | c10::optional<at::ScalarType> dtype = c10::nullopt, |
76 | c10::optional<at::Device> device = c10::nullopt) { |
77 | return at::_nested_tensor_from_tensor_list( |
78 | list, dtype, c10::nullopt, device, c10::nullopt); |
79 | } |
80 | |
81 | /// Nested to padded tensor |
82 | /// |
83 | /// See |
84 | /// https://pytorch.org/docs/master/nested.html#torch.nested.to_padded_tensor |
///
87 | inline at::Tensor to_padded_tensor( |
88 | const at::Tensor& self, |
89 | double padding, |
90 | at::OptionalIntArrayRef output_size = c10::nullopt) { |
91 | return at::nested_to_padded_tensor(self, padding, output_size); |
92 | } |
93 | |
94 | } // namespace nested |
95 | } // namespace torch |
96 | |