#pragma once
#include <ATen/MemoryOverlap.h>
#include <ATen/Tensor.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/irange.h>

namespace at {
namespace native {
struct NestedTensorImpl;
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);

struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  explicit NestedTensorImpl(
      Storage storage,
      c10::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor nested_size_tensor,
      at::Tensor nested_stride_tensor,
      std::vector<int64_t>&& offsets);

  explicit NestedTensorImpl(
      at::Tensor buffer,
      at::Tensor nested_size_tensor,
      at::Tensor nested_stride_tensor,
      std::vector<int64_t>&& offsets);
  // Assumes contiguity: `nested_stride_tensor` and `offsets`
  // can be inferred from `nested_size_tensor`.
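  // An illustrative sketch (not part of the API): for a contiguous buffer
  // packing two tensors of sizes [2, 3] and [4, 3], `nested_size_tensor`
  // is [[2, 3], [4, 3]], the inferred strides are [[3, 1], [3, 1]], and the
  // inferred offsets are [0, 6].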
  explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor);

  // This constructor is used for creating view tensors from nested tensors.
  explicit NestedTensorImpl(
      c10::TensorImpl::ImplType impl_type,
      const at::Tensor& base_tensor,
      at::Tensor nested_size_tensor,
      at::Tensor nested_stride_tensor,
      std::vector<int64_t>&& offsets);

  // TODO: don't expose private implementation details like this; in
  // particular, resizing this tensor will mess up our dim() and
  // callers cannot fix it.
  const Tensor& get_nested_size_tensor() const {
    return nested_size_tensor_;
  }
  // TODO: don't expose private implementation details like this
  const Tensor& get_nested_stride_tensor() const {
    return nested_stride_tensor_;
  }
  const std::vector<int64_t>& get_storage_offsets() const {
    return storage_offsets_;
  }
  // Returns nullopt if the ith dimension is irregular. The ith dimension
  // of a NestedTensor is regular if the unbound tensors match in
  // size at the (i-1)th dimension.
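  // An illustrative sketch (hypothetical `nt` nesting tensors of sizes
  // [2, 3] and [4, 3]): nt->opt_size(0) == 2 (two nested tensors),
  // nt->opt_size(1) == c10::nullopt (2 != 4), and nt->opt_size(2) == 3.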
  c10::optional<int64_t> opt_size(int64_t d) const {
    d = at::maybe_wrap_dim(d, dim(), false);
    if (opt_sizes_[d] == -1) {
      return c10::nullopt;
    }
    return opt_sizes_[d];
  }

  int64_t size(int64_t d) const {
    c10::optional<int64_t> optional_size = this->opt_size(d);
    TORCH_CHECK(
        optional_size.has_value(),
        "Given dimension ",
        d,
        " is irregular and does not have a size.");
    return *optional_size;
  }
  /**
   * Return a view of the nested tensor as a 1-dimensional contiguous tensor.
   *
   * The buffer tensor created by this function shares the same storage_impl
   * as the original nested tensor, and therefore can be seen as a view.
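   *
   * An illustrative usage sketch (assuming `t` is a contiguous nested
   * tensor; not a prescribed API pattern):
   *   at::Tensor buf = get_nested_tensor_impl(t)->get_buffer();
   *   // buf.dim() == 1, buf.numel() == number of elements in the buffer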
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_buffer() const {
    TORCH_CHECK(
        nested_tensor_impl_is_contiguous(this),
        "NestedTensor must be contiguous to get buffer.");
    return get_unsafe_storage_as_tensor();
  }
  /**
   * If possible use get_buffer() instead. This function returns the storage
   * as a tensor directly, which is not safe to use in general. If using this
   * function, the caller must account for nested_sizes, nested_strides and
   * storage_offsets.
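   *
   * An illustrative sketch (hypothetical names; one way a caller might
   * recover the i-th constituent from its size/stride/offset metadata):
   *   at::Tensor buf = impl->get_unsafe_storage_as_tensor();
   *   at::Tensor t_i = buf.as_strided(size_i, stride_i, offset_i);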
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_unsafe_storage_as_tensor() const {
    auto buffer_key_set_ = generate_buffer_key_set();
    const auto buffer_size = get_buffer_size();
    auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
        c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
    buffer_tensor_impl->set_sizes_contiguous(c10::makeArrayRef(buffer_size));
    return Tensor(buffer_tensor_impl);
  }

  int64_t get_buffer_size() const {
    return storage_.nbytes() / data_type_.itemsize();
  }

 protected:
  const char* tensorimpl_type_name() const override;

  // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  // with real implementations
  int64_t numel_custom() const override;
  c10::SymInt sym_numel_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;
  int64_t size_custom(int64_t d) const override {
    return this->size(d);
  }
  c10::SymInt sym_size_custom(int64_t d) const override {
    return c10::SymInt{this->size(d)};
  }
  IntArrayRef sizes_custom() const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  IntArrayRef strides_custom() const override;
  c10::SymIntArrayRef sym_strides_custom() const override;

  // this one is real
  int64_t dim_custom() const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    copy_tensor_metadata(
        /*src_impl=*/impl.get(),
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  }

 private:
  // Must be called after any changes to our dim() to sync the state
  // to TensorImpl.
  void refresh_dim();

  const at::Tensor nested_size_tensor_, nested_stride_tensor_;
  // The starting positions of the underlying tensors in the contiguous
  // buffer, i.e. the buffer memory offsets needed to get the underlying
  // tensors. The reason to keep this metadata is that, without a strong
  // enough constraint, it cannot be derived from `nested_size_tensor_`
  // and `nested_stride_tensor_`:
  // 1. when the buffer has blanks, e.g. [tensor1, blank, tensor2];
  //    this can happen e.g. after slicing a nested tensor
  // 2. when multiple tensors share the same memory
  // 3. when the nesting order is changed, e.g. [tensor1, tensor3, tensor2]
  // A strong enough constraint would be: every underlying tensor is
  // contiguous in memory && nested in ascending order.
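  // An illustrative sketch: if the buffer packs two 6-element tensors with a
  // 4-element gap between them, storage_offsets_ would be {0, 10}, which
  // cannot be recovered from the sizes and strides alone.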
  std::vector<int64_t> storage_offsets_;
  // NOTE: -1 here means the size is missing
  // TODO: maybe we can remove this metadata since
  // we can compute it from `nested_size_tensor_`
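  // An illustrative sketch: nesting tensors of sizes [2, 3] and [4, 3]
  // gives opt_sizes_ == {2, -1, 3}.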
  std::vector<int64_t> opt_sizes_;

  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  /**
   * Generates a non-nested key_set from a nested tensor.
   *
   * For many nested tensor kernel implementations a buffer tensor is
   * generated and redispatched to a non-nested kernel. This function
   * generates the key set used by that buffer tensor.
   *
   * @return Appropriate key set for a non-nested tensor
   */
  inline c10::DispatchKeySet generate_buffer_key_set() const {
    auto buffer_key_set = this->key_set();
    const bool Autograd =
        buffer_key_set.has_any(c10::autograd_dispatch_keyset);
    // Remove nested tensor specific keys
    buffer_key_set = buffer_key_set -
        c10::DispatchKeySet{
            c10::DispatchKey::NestedTensor,
            c10::DispatchKey::AutogradNestedTensor};

    // Add dense tensor specific keys
    buffer_key_set =
        buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
    buffer_key_set = Autograd
        ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
        : buffer_key_set;

    return buffer_key_set;
  }
};

inline NestedTensorImpl* get_nested_tensor_impl_or_null(
    const at::Tensor& tensor) {
  if (tensor.is_nested()) {
    return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
  }
  return nullptr;
}

inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

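// A nested tensor is considered contiguous when every constituent tensor is
// itself contiguous (innermost stride 1, row-major strides) and the
// constituents are packed back to back in the buffer starting at offset 0.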
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
  int64_t ntensors = nt->size(0);
  if (ntensors == 0) {
    return true;
  }
  const Tensor &sizemat = nt->get_nested_size_tensor(),
               &stridemat = nt->get_nested_stride_tensor();
  const auto& offsets = nt->get_storage_offsets();
  int64_t orig_dim = sizemat.size(1);
  // nesting scalars
  if (orig_dim == 0) {
    // each scalar must be contiguous:
    // return false if there is blank memory between underlying scalars
    for (int64_t i = 0; i < ntensors; i++) {
      if (offsets[i] != i) {
        return false;
      }
    }
  }
  // nesting tensors
  else {
    // return false if any underlying tensor is noncontiguous
    const int64_t *sizemat_ptr = sizemat.data_ptr<int64_t>(),
                  *stridemat_ptr = stridemat.data_ptr<int64_t>();
    for (int64_t i = 0; i < ntensors; i++) {
      if (stridemat_ptr[orig_dim - 1] != 1) {
        return false;
      }
      int64_t product = sizemat_ptr[orig_dim - 1];
      for (int64_t j = orig_dim - 2; j >= 0; j--) {
        if (stridemat_ptr[j] != product) {
          return false;
        }
        product *= sizemat_ptr[j];
      }
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
    // return false if there is blank memory between underlying tensors
    if (offsets[0] != 0) {
      return false;
    }
    sizemat_ptr = sizemat.data_ptr<int64_t>();
    stridemat_ptr = stridemat.data_ptr<int64_t>();
    for (int64_t i = 1; i < ntensors; i++) {
      if (offsets[i] != offsets[i - 1] + *sizemat_ptr * *stridemat_ptr) {
        return false;
      }
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
  }
  // everything is fine
  return true;
}

inline const at::Tensor& get_nested_size_tensor(const at::Tensor& tensor) {
  return get_nested_tensor_impl(tensor)->get_nested_size_tensor();
}

} // namespace native
} // namespace at