#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Exception.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Logging.h>

#include <numeric>
#include <functional>

namespace {
inline void validate_nested_tensor_metadata(
    const at::Tensor& nested_sizes,
    const at::Tensor& nested_strides,
    const std::vector<int64_t>& offsets) {
  TORCH_INTERNAL_ASSERT(nested_sizes.is_contiguous());
  int64_t size_dim = nested_sizes.dim();
  TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
  TORCH_INTERNAL_ASSERT(nested_strides.is_contiguous());
  TORCH_INTERNAL_ASSERT(nested_strides.dim() == size_dim);
  TORCH_INTERNAL_ASSERT(nested_sizes.sizes() == nested_strides.sizes());
  TORCH_INTERNAL_ASSERT(
      (size_dim == 0 && offsets.empty()) ||
      (size_dim == 2 && nested_sizes.size(0) == (int64_t)offsets.size()));
}

/**
 * Generates a nested key_set from a non-nested tensor.
 *
 * When creating a nested tensor from a non-nested tensor,
 * we want to maintain the same key set as the buffer but
 * swap the non-nested keys for their nested counterparts.
 *
 * @return Appropriate key set for a nested tensor
 */
inline c10::DispatchKeySet generate_nested_key_set_from_buffer(
    const at::Tensor& buffer) {
  auto nested_key_set = buffer.key_set();
  const bool has_autograd = nested_key_set.has_any(c10::autograd_dispatch_keyset);
  // Remove non-nested tensor specific keys
  nested_key_set = nested_key_set -
      c10::DispatchKeySet{c10::DispatchKey::Dense, c10::DispatchKey::Autograd};

  // Add nested tensor specific keys
  nested_key_set =
      nested_key_set | c10::DispatchKeySet{c10::DispatchKey::NestedTensor};
  nested_key_set =
      has_autograd ? nested_key_set | c10::autograd_nested : nested_key_set;
  return nested_key_set;
}
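
// Illustrative sketch (comment only, not part of the implementation): for a
// plain dense CPU buffer, the transformation above is expected to drop the
// dense/autograd functionality keys and add the nested ones, roughly:
//
//   at::Tensor buffer = at::randn({6});
//   c10::DispatchKeySet ks = generate_nested_key_set_from_buffer(buffer);
//   // ks should now contain NestedTensor (and AutogradNestedTensor when the
//   // buffer carried autograd keys) instead of Dense/Autograd:
//   TORCH_INTERNAL_ASSERT(ks.has(c10::DispatchKey::NestedTensor));
//   TORCH_INTERNAL_ASSERT(!ks.has(c10::DispatchKey::Dense));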

/**
 * Generates the correct view key set.
 *
 * When creating a nested tensor view of a base tensor,
 * the appropriate key set depends on whether the base
 * is itself nested.
 *
 * @return Appropriate key set for the nested tensor view
 */
c10::DispatchKeySet get_view_key_set(const at::Tensor& base) {
  return base.is_nested() ? base.key_set()
                          : generate_nested_key_set_from_buffer(base);
}

} // namespace

namespace at {
namespace native {

inline std::vector<int64_t> construct_opt_sizes(const at::Tensor& sizes) {
  // torch.tensor([]) is considered to have `dim() = 1` and `size(0) = 0`
  // torch.nested_tensor([]) should also have `dim() = 1` and `size(0) = 0`
  if (sizes.dim() == 0) {
    return std::vector<int64_t>({0});
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  std::vector<int64_t> result(1, sizes.sizes()[0]);
  if (sizes.dim() > 0) {
    size_t nested_dim = result.size();
    int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
    result.resize(nested_dim + sizes.sizes()[1]);
    int64_t sizes_size_0 = sizes.sizes()[0];
    int64_t sizes_size_1 = sizes.sizes()[1];
    for (const auto i : c10::irange(sizes_size_1)) {
      result[nested_dim + i] = sizes_ptr[i];
    }
    for (const auto j : c10::irange(sizes_size_1)) {
      for (const auto i : c10::irange(sizes_size_0)) {
        if (result[nested_dim + j] &&
            (result[nested_dim + j] != sizes_ptr[i * sizes.size(1) + j])) {
          result[nested_dim + j] = -1;
        }
      }
    }
  }
  return result;
}
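
// Illustrative example (comment only): for a nested size tensor
// [[2, 3], [4, 3]] -- i.e. two constituent tensors of shapes (2, 3) and
// (4, 3) -- construct_opt_sizes returns {2, -1, 3}: the leading entry is
// the number of constituents, a ragged dimension collapses to -1, and a
// dimension that agrees across all constituents keeps its size.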

// assume contiguous, we can construct stride from size
inline at::Tensor construct_nested_stride_tensor(const at::Tensor& sizes) {
  // empty `sizes` means empty nested tensor, so return empty strides
  if (sizes.dim() == 0) {
    return sizes;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  int64_t orig_dim = sizes.size(1);
  // `sizes`.sizes() = ntensors x 0 means empty but shaped `sizes`
  // in this case strides is also empty but shaped
  if (orig_dim == 0) {
    return sizes;
  }
  at::Tensor strides = sizes.new_empty(sizes.sizes());
  const int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
  int64_t* strides_ptr = strides.data_ptr<int64_t>();
  for (int64_t i = 0; i < sizes.size(0); i++) {
    strides_ptr[orig_dim - 1] = 1;
    int64_t product = sizes_ptr[orig_dim - 1];
    for (int64_t j = orig_dim - 2; j >= 0; j--) {
      strides_ptr[j] = product;
      product *= sizes_ptr[j];
    }
    sizes_ptr += orig_dim;
    strides_ptr += orig_dim;
  }
  return strides;
}
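
// Illustrative example (comment only): for a nested size tensor
// [[2, 3], [4, 5]], the contiguous per-constituent strides produced above
// are [[3, 1], [5, 1]] -- each row is the row-major stride vector of the
// corresponding constituent shape.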

/**
 * Create a vector of offsets assuming the nested tensor is contiguous
 *
 * This function iterates over the implicit ntensor outer dimension,
 * populating a vector with the element offset of each implicit tensor
 * within the contiguous buffer. The first element is always 0 and the
 * length of the returned vector is ntensors.
 *
 * @return A vector of offsets
 */
inline std::vector<int64_t> construct_offsets(const at::Tensor& sizes) {
  // empty `sizes` means empty nested tensor, so return empty offsets
  if (sizes.dim() == 0) {
    return std::vector<int64_t>();
  }
  int64_t ntensors = sizes.size(0), orig_dim = sizes.size(1);
  std::vector<int64_t> offsets(ntensors);
  // nested scalars have trivial offsets: 0, 1, 2, ...
  if (orig_dim == 0) {
    std::iota(offsets.begin(), offsets.end(), 0);
    return offsets;
  }
  const int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
  offsets[0] = 0;
  for (const auto i : c10::irange(ntensors - 1)) {
    const int64_t row_product = std::accumulate(
        sizes_ptr, sizes_ptr + orig_dim, 1, std::multiplies<int64_t>());
    offsets[i + 1] = offsets[i] + row_product;
    sizes_ptr += orig_dim;
  }
  return offsets;
}
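
// Illustrative example (comment only): for a nested size tensor
// [[2, 3], [4, 5]], the constituents occupy 6 and 20 elements respectively,
// so the returned offsets are {0, 6}; the second constituent starts 6
// elements into the contiguous buffer.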

NestedTensorImpl::NestedTensorImpl(
    Storage storage,
    c10::DispatchKeySet key_set,
    const caffe2::TypeMeta data_type,
    at::Tensor nested_size_tensor,
    at::Tensor nested_stride_tensor,
    std::vector<int64_t>&& offsets)
    : TensorImpl(std::move(storage), key_set, data_type),
      nested_size_tensor_(std::move(nested_size_tensor)),
      nested_stride_tensor_(std::move(nested_stride_tensor)),
      storage_offsets_(std::move(offsets)),
      opt_sizes_(construct_opt_sizes(nested_size_tensor_)) {
  C10_LOG_API_USAGE_ONCE("torch.NestedTensor");
  TORCH_WARN_ONCE(
      "The PyTorch API of nested tensors is in prototype stage and will change "
      "in the near future.");
  auto storage_device = storage_.device();
  TORCH_INTERNAL_ASSERT(
      storage_device.is_cpu() || storage_device.is_cuda(),
      "NestedTensorImpl storage must be either CUDA or CPU but got ",
      storage_device);
  validate_nested_tensor_metadata(
      nested_size_tensor_, nested_stride_tensor_, storage_offsets_);
  refresh_dim();
  set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}

NestedTensorImpl::NestedTensorImpl(
    at::Tensor buffer,
    at::Tensor nested_size_tensor,
    at::Tensor nested_stride_tensor,
    std::vector<int64_t>&& offsets)
    : NestedTensorImpl(
          buffer.storage(),
          generate_nested_key_set_from_buffer(buffer),
          buffer.dtype(),
          nested_size_tensor,
          nested_stride_tensor,
          std::move(offsets)) {
  TORCH_INTERNAL_ASSERT(
      buffer.dim() == 1,
      "NestedTensorImpl buffer is required to be 1 dimensional but got a buffer with ",
      buffer.dim(),
      " dimensions.");
}

// assume contiguous, `nested_stride_tensor` and `offsets`
// can be inferred from `nested_size_tensor`
NestedTensorImpl::NestedTensorImpl(
    at::Tensor buffer,
    at::Tensor nested_size_tensor)
    : NestedTensorImpl(
          buffer,
          nested_size_tensor,
          construct_nested_stride_tensor(nested_size_tensor),
          construct_offsets(nested_size_tensor)) {}
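
// Illustrative usage sketch (comment only; the concrete numbers are
// hypothetical): given a 1-d `buffer` holding 6 + 20 = 26 elements and a
// nested size tensor [[2, 3], [4, 5]], this constructor derives the strides
// [[3, 1], [5, 1]] and the offsets {0, 6} shown in the examples above, so
// only the buffer and the per-constituent sizes need to be supplied.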

NestedTensorImpl::NestedTensorImpl(
    c10::TensorImpl::ImplType impl_type,
    const at::Tensor& base_tensor,
    at::Tensor nested_size_tensor,
    at::Tensor nested_stride_tensor,
    std::vector<int64_t>&& offsets)
    : TensorImpl(
          impl_type,
          Storage(base_tensor.storage()),
          get_view_key_set(base_tensor),
          base_tensor.dtype()),
      nested_size_tensor_(std::move(nested_size_tensor)),
      nested_stride_tensor_(std::move(nested_stride_tensor)),
      storage_offsets_(std::move(offsets)),
      opt_sizes_(construct_opt_sizes(nested_size_tensor_)) {
  validate_nested_tensor_metadata(
      nested_size_tensor_, nested_stride_tensor_, storage_offsets_);
  refresh_dim();
  set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}

void NestedTensorImpl::refresh_dim() {
  const auto my_dim =
      nested_size_tensor_.dim() ? nested_size_tensor_.sizes()[1] + 1 : 1;
  sizes_and_strides_.resize(my_dim);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim() == my_dim);
}
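
// Illustrative example (comment only): with a nested size tensor of shape
// [ntensors, 2] (each constituent is 2-d), the nested tensor reports
// dim() == 3: one implicit outer dimension over the constituents plus the
// two constituent dimensions. An empty nested size tensor yields dim() == 1.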

int64_t NestedTensorImpl::dim_custom() const {
  return dim_default();
}

// Currently sizes and strides assume contiguous
int64_t NestedTensorImpl::numel_custom() const {
  if (nested_size_tensor_.dim() == 0) {
    return 0;
  }
  constexpr auto numel_max = std::min(
      static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
      static_cast<uint64_t>(std::numeric_limits<size_t>::max()));

  const auto nt_dim = nested_size_tensor_.size(1);
  const int64_t* sizes_ptr = nested_size_tensor_.data_ptr<int64_t>();
  uint64_t num_elements{0};

  for (const auto i : c10::irange(nested_size_tensor_.size(0))) {
    uint64_t n = 1;
    const auto start{sizes_ptr + i * nt_dim};
    const auto end{start + nt_dim};
    bool overflows = c10::safe_multiplies_u64(start, end, &n);
    num_elements += n;
    overflows |= (num_elements > numel_max);
    TORCH_CHECK(!overflows, "numel: integer multiplication overflow");
  }
  return static_cast<int64_t>(num_elements);
}
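
// Illustrative example (comment only): for the nested size tensor
// [[2, 3], [4, 5]], numel_custom() returns 2*3 + 4*5 = 26. Each row product
// is computed with c10::safe_multiplies_u64, and the TORCH_CHECK fires if
// any product or the running total would overflow.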

c10::SymInt NestedTensorImpl::sym_numel_custom() const {
  return NestedTensorImpl::numel_custom();
}

bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const {
  return nested_tensor_impl_is_contiguous(this);
}

IntArrayRef NestedTensorImpl::sizes_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}

c10::SymIntArrayRef NestedTensorImpl::sym_sizes_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}

c10::SymIntArrayRef NestedTensorImpl::sym_strides_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
}

IntArrayRef NestedTensorImpl::strides_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
}

const char* NestedTensorImpl::tensorimpl_type_name() const {
  return "NestedTensorImpl";
}

template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach_core(
    VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_slot_.load_pyobj_interpreter()->detach(this);
    if (r) {
      r->set_version_counter(std::forward<VariableVersion>(version_counter));
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
    // otherwise just copy the TensorImpl and not the PyObject. Since
    // the interpreter is dead no one can call us out on it
  }
  auto impl = c10::make_intrusive<NestedTensorImpl>(
      storage_,
      key_set_,
      data_type_,
      nested_size_tensor_,
      nested_stride_tensor_,
      std::vector<int64_t>(storage_offsets_));

  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      version_counter, allow_tensor_metadata_change);
}

c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      std::move(version_counter), allow_tensor_metadata_change);
}

} // namespace native
} // namespace at