1#pragma once
2
3#include <ATen/ATen.h>
4#include <ATen/core/ATen_fwd.h>
5#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
6#include <algorithm>
7
8namespace torch {
9namespace nested {
10
11/// Nested tensor
12///
13/// See
14/// https://pytorch.org/docs/master/nested.html#torch.nested.nested_tensor
15///
16/// ```
17// implemented on python object to allow torch.nested.nested_tensor to be
18// constructed with arbitrarily nested python objects - for now, only arbitrary
19// python lists and lists of Tensors
20// See torch/csrc/autograd/python_nested_functions_manual.cpp for Python
21// implementation
22// See here for C++ implementation
23inline at::Tensor nested_tensor(
24 at::TensorList nested_tensor_data,
25 const at::TensorOptions& options = {}) {
26 auto out = at::_nested_tensor_from_tensor_list(
27 nested_tensor_data,
28 c10::typeMetaToScalarType(options.dtype()),
29 c10::nullopt,
30 options.device(),
31 options.pinned_memory());
32 if (options.has_requires_grad() && options.requires_grad()) {
33 out.requires_grad_(true);
34 }
35 return out;
36}
37
38inline at::Tensor nested_tensor(
39 at::ArrayRef<detail::TensorDataContainer> nested_tensor_data,
40 const at::TensorOptions& options = {}) {
41 for (const auto& tdc : nested_tensor_data) {
42 TORCH_CHECK(
43 tdc.is_init_list(),
44 "nested_tensor() not implemented for these parameters");
45 }
46 // Construct a TensorList using nested_tensor_data
47 std::vector<at::Tensor> tensor_list(nested_tensor_data.size());
48 std::transform(
49 nested_tensor_data.begin(),
50 nested_tensor_data.end(),
51 tensor_list.begin(),
52 [&](const detail::TensorDataContainer& tdc) {
53 return tdc.convert_to_tensor(options);
54 });
55 auto out = at::_nested_tensor_from_tensor_list(
56 tensor_list,
57 c10::typeMetaToScalarType(options.dtype()),
58 c10::nullopt,
59 options.device(),
60 options.pinned_memory());
61 if (options.has_requires_grad() && options.requires_grad()) {
62 out.requires_grad_(true);
63 }
64 return out;
65}
66
67/// As Nested Tensor
68///
69/// See
70/// https://pytorch.org/docs/master/nested.html#torch.nested.as_nested_tensor
71///
72/// ```
73inline at::Tensor as_nested_tensor(
74 at::TensorList list,
75 c10::optional<at::ScalarType> dtype = c10::nullopt,
76 c10::optional<at::Device> device = c10::nullopt) {
77 return at::_nested_tensor_from_tensor_list(
78 list, dtype, c10::nullopt, device, c10::nullopt);
79}
80
81/// Nested to padded tensor
82///
83/// See
84/// https://pytorch.org/docs/master/nested.html#torch.nested.to_padded_tensor
85///
86/// ```
87inline at::Tensor to_padded_tensor(
88 const at::Tensor& self,
89 double padding,
90 at::OptionalIntArrayRef output_size = c10::nullopt) {
91 return at::nested_to_padded_tensor(self, padding, output_size);
92}
93
94} // namespace nested
95} // namespace torch
96