1 | #pragma once |
2 | #include <ATen/MemoryOverlap.h> |
3 | #include <ATen/Tensor.h> |
4 | #include <c10/core/DispatchKey.h> |
5 | #include <c10/core/DispatchKeySet.h> |
6 | #include <c10/core/MemoryFormat.h> |
7 | #include <c10/core/TensorImpl.h> |
8 | #include <c10/util/ArrayRef.h> |
9 | #include <c10/util/Exception.h> |
10 | #include <c10/util/Metaprogramming.h> |
11 | #include <c10/util/irange.h> |
12 | |
13 | namespace at { |
14 | namespace native { |
15 | struct NestedTensorImpl; |
16 | inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt); |
17 | |
18 | struct TORCH_API NestedTensorImpl : public c10::TensorImpl { |
19 | explicit NestedTensorImpl( |
20 | Storage storage, |
21 | c10::DispatchKeySet key_set, |
22 | const caffe2::TypeMeta data_type, |
23 | at::Tensor nested_size_tensor, |
24 | at::Tensor nested_stride_tensor, |
25 | std::vector<int64_t>&& offsets); |
26 | |
27 | explicit NestedTensorImpl( |
28 | at::Tensor buffer, |
29 | at::Tensor nested_size_tensor, |
30 | at::Tensor nested_stride_tensor, |
31 | std::vector<int64_t>&& offsets); |
// Assumes contiguity: `nested_stride_tensor` and `offsets`
// can be inferred from `nested_size_tensor`.
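// Illustrative example (hypothetical values): for a nested_size_tensor of
// [[2, 3], [4, 5]], the inferred strides would be [[3, 1], [5, 1]] and the
// inferred offsets {0, 6}, i.e. the second tensor starts right after the
// 2 * 3 = 6 elements of the first.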
34 | explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor); |
35 | |
// This constructor is used when creating view tensors from nested tensors
37 | explicit NestedTensorImpl( |
38 | c10::TensorImpl::ImplType impl_type, |
39 | const at::Tensor& base_tensor, |
40 | at::Tensor nested_size_tensor, |
41 | at::Tensor nested_stride_tensor, |
42 | std::vector<int64_t>&& offsets); |
43 | |
44 | // TODO: don't expose private implementation details like this; in |
45 | // particular, resizing this tensor will mess up our dim() and |
46 | // callers cannot fix it. |
47 | const Tensor& get_nested_size_tensor() const { |
48 | return nested_size_tensor_; |
49 | } |
50 | // TODO: don't expose private implementation details like this |
51 | const Tensor& get_nested_stride_tensor() const { |
52 | return nested_stride_tensor_; |
53 | } |
54 | const std::vector<int64_t>& get_storage_offsets() const { |
55 | return storage_offsets_; |
56 | } |
57 | // Returns nullopt if the ith dimension is irregular. The ith dimension |
58 | // of a NestedTensor is regular if the unbound tensors match in |
59 | // size at the (i-1)th dimension. |
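// For example (illustrative): for a NestedTensor holding tensors of sizes
// [2, 3] and [2, 5], dim() is 3, opt_size(0) is 2 (the number of
// constituents), opt_size(1) is 2 (regular), and opt_size(2) is nullopt
// (irregular, since 3 != 5).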
60 | c10::optional<int64_t> opt_size(int64_t d) const { |
61 | d = at::maybe_wrap_dim(d, dim(), false); |
62 | if (opt_sizes_[d] == -1) { |
63 | return c10::nullopt; |
64 | } |
65 | return opt_sizes_[d]; |
66 | } |
67 | |
68 | int64_t size(int64_t d) const { |
69 | c10::optional<int64_t> optional_size = this->opt_size(d); |
TORCH_CHECK(
optional_size.has_value(),
"Given dimension ",
d,
" is irregular and does not have a size.");
75 | return *optional_size; |
76 | } |
77 | /** |
78 | * Return a view of the nested tensor as a 1 dimensional contiguous tensor. |
79 | * |
80 | * The buffer tensor created by this function shares the same storage_impl as |
81 | * the original nested tensor, and therefore can be seen as a view. |
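 *
 * For example (illustrative): for a contiguous NestedTensorImpl `nt`, the
 * tensor returned by `nt->get_buffer()` aliases the same storage as `nt`
 * and has `numel() == nt->get_buffer_size()`.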
82 | * |
83 | * @return A newly constructed view tensor |
84 | */ |
85 | at::Tensor get_buffer() const { |
TORCH_CHECK(
nested_tensor_impl_is_contiguous(this),
"NestedTensor must be contiguous to get buffer.");
89 | return get_unsafe_storage_as_tensor(); |
90 | } |
91 | /** |
 * If possible, use get_buffer() instead. This function returns the storage
 * as a tensor directly, which is not safe to use in general. When using
 * this function, the caller must account for nested_sizes, nested_strides,
 * and storage_offsets.
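 *
 * For example (illustrative, not an API guarantee): the i-th constituent
 * could be reconstructed as
 * `storage_tensor.as_strided(sizes[i], strides[i], offsets[i])`, with the
 * sizes and strides taken from the nested size/stride tensors.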
96 | * |
97 | * @return A newly constructed view tensor |
98 | */ |
99 | at::Tensor get_unsafe_storage_as_tensor() const { |
100 | auto buffer_key_set_ = generate_buffer_key_set(); |
101 | const auto buffer_size = get_buffer_size(); |
102 | auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>( |
103 | c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_); |
104 | buffer_tensor_impl->set_sizes_contiguous(c10::makeArrayRef(buffer_size)); |
105 | return Tensor(buffer_tensor_impl); |
106 | } |
107 | |
108 | int64_t get_buffer_size() const { |
109 | return storage_.nbytes() / data_type_.itemsize(); |
110 | } |
111 | |
112 | protected: |
113 | const char* tensorimpl_type_name() const override; |
114 | |
115 | // TODO: numel_custom and is_contiguous_custom can be profitably overridden |
116 | // with real implementations |
117 | int64_t numel_custom() const override; |
118 | c10::SymInt sym_numel_custom() const override; |
119 | bool is_contiguous_custom(MemoryFormat) const override; |
120 | int64_t size_custom(int64_t d) const override { |
121 | return this->size(d); |
122 | } |
123 | c10::SymInt sym_size_custom(int64_t d) const override { |
124 | return c10::SymInt{this->size(d)}; |
125 | } |
126 | IntArrayRef sizes_custom() const override; |
127 | c10::SymIntArrayRef sym_sizes_custom() const override; |
128 | IntArrayRef strides_custom() const override; |
129 | c10::SymIntArrayRef sym_strides_custom() const override; |
130 | |
// Unlike the overrides above, this one has a real implementation.
132 | int64_t dim_custom() const override; |
133 | |
134 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
135 | const c10::VariableVersion& version_counter, |
136 | bool allow_tensor_metadata_change) const override; |
137 | |
138 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
139 | c10::VariableVersion&& version_counter, |
140 | bool allow_tensor_metadata_change) const override; |
141 | |
142 | void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override { |
143 | copy_tensor_metadata( |
144 | /*src_impl=*/impl.get(), |
145 | /*dest_impl=*/this, |
146 | /*version_counter=*/version_counter(), |
147 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); |
148 | } |
149 | |
150 | private: |
151 | // Must be called after any changes to our dim() to sync the state |
152 | // to TensorImpl. |
153 | void refresh_dim(); |
154 | |
155 | const at::Tensor nested_size_tensor_, nested_stride_tensor_; |
// The starting positions of the underlying tensors in the contiguous
// buffer, i.e. the buffer memory offsets at which the underlying tensors
// begin.
// We keep this metadata because, without a strong enough constraint,
// it cannot be derived from `nested_size_tensor_` and
// `nested_stride_tensor_`:
// 1. when the buffer has gaps, e.g. [tensor1, blank, tensor2];
//    this can happen e.g. after slicing a nested tensor
// 2. when multiple tensors share the same memory
// 3. when the nesting order is changed, e.g. [tensor1, tensor3, tensor2]
// A constraint that would be strong enough: every underlying tensor is
// contiguous in memory and the tensors are nested in ascending buffer order.
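// Illustrative example (hypothetical values): a buffer of 12 elements
// holding two 2x3 tensors at offsets {0, 6} has no gaps; after slicing
// away the first tensor, the remaining tensor still starts at offset 6,
// which cannot be recovered from its sizes and strides alone.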
168 | std::vector<int64_t> storage_offsets_; |
169 | // NOTE: -1 here means the size is missing |
170 | // TODO: maybe we can remove this metadata since |
171 | // we can compute it from `nested_size_tensor_` |
172 | std::vector<int64_t> opt_sizes_; |
173 | |
174 | template <typename VariableVersion> |
175 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core( |
176 | VariableVersion&& version_counter, |
177 | bool allow_tensor_metadata_change) const; |
178 | |
179 | /** |
180 | * Generates a non-nested key_set from a nested tensor. |
181 | * |
 * For many nested tensor kernel implementations, a buffer tensor is
 * generated and redispatched to a non-nested kernel; this function
 * generates the key set used by that buffer tensor.
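 *
 * For example (illustrative): for a CPU nested tensor that requires grad,
 * the key set roughly maps from {CPU, NestedTensor, AutogradNestedTensor}
 * to {CPU, Dense} plus the generic Autograd key, so the buffer dispatches
 * like an ordinary dense CPU tensor.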
185 | * |
186 | * @return Appropriate key set for non-nested tensor |
187 | */ |
188 | inline c10::DispatchKeySet generate_buffer_key_set() const { |
189 | auto buffer_key_set = this->key_set(); |
190 | const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset); |
191 | // Remove nested tensor specific keys |
192 | buffer_key_set = buffer_key_set - |
193 | c10::DispatchKeySet{ |
194 | c10::DispatchKey::NestedTensor, |
195 | c10::DispatchKey::AutogradNestedTensor}; |
196 | |
197 | // Add dense tensor specific keys |
198 | buffer_key_set = |
199 | buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense}; |
200 | buffer_key_set = Autograd |
201 | ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set |
202 | : buffer_key_set; |
203 | |
204 | return buffer_key_set; |
205 | } |
206 | }; |
207 | |
208 | inline NestedTensorImpl* get_nested_tensor_impl_or_null( |
209 | const at::Tensor& tensor) { |
210 | if (tensor.is_nested()) { |
211 | return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl()); |
212 | } |
213 | return nullptr; |
214 | } |
215 | |
216 | inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) { |
217 | TORCH_CHECK( |
tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
219 | return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl()); |
220 | } |
221 | |
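// Worked example (hypothetical values): two 2x3 tensors with strides {3, 1}
// stored at offsets {0, 6} pass all checks below: the last stride is 1,
// stride[0] == size[1] * stride[1] == 3, offsets[0] == 0, and
// offsets[1] == offsets[0] + size[0] * stride[0] == 6, so the nested tensor
// is contiguous.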
222 | inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) { |
223 | int64_t ntensors = nt->size(0); |
224 | if (ntensors == 0) { |
225 | return true; |
226 | } |
227 | const Tensor &sizemat = nt->get_nested_size_tensor(), |
228 | &stridemat = nt->get_nested_stride_tensor(); |
229 | const auto& offsets = nt->get_storage_offsets(); |
230 | int64_t orig_dim = sizemat.size(1); |
231 | // nesting scalars |
232 | if (orig_dim == 0) { |
// each scalar must be contiguous, i.e. there must be no blank memory
// between the underlying scalars in the buffer
235 | for (int64_t i = 0; i < ntensors; i++) { |
236 | if (offsets[i] != i) { |
237 | return false; |
238 | } |
239 | } |
240 | } |
241 | // nesting tensors |
242 | else { |
// return false if any underlying tensor is noncontiguous
244 | const int64_t *sizemat_ptr = sizemat.data_ptr<int64_t>(), |
245 | *stridemat_ptr = stridemat.data_ptr<int64_t>(); |
246 | for (int64_t i = 0; i < ntensors; i++) { |
247 | if (stridemat_ptr[orig_dim - 1] != 1) { |
248 | return false; |
249 | } |
250 | int64_t product = sizemat_ptr[orig_dim - 1]; |
251 | for (int64_t j = orig_dim - 2; j >= 0; j--) { |
252 | if (stridemat_ptr[j] != product) { |
253 | return false; |
254 | } |
255 | product *= sizemat_ptr[j]; |
256 | } |
257 | sizemat_ptr += orig_dim; |
258 | stridemat_ptr += orig_dim; |
259 | } |
// return false if there is blank memory between underlying tensors
261 | if (offsets[0] != 0) { |
262 | return false; |
263 | } |
264 | sizemat_ptr = sizemat.data_ptr<int64_t>(); |
265 | stridemat_ptr = stridemat.data_ptr<int64_t>(); |
266 | for (int64_t i = 1; i < ntensors; i++) { |
267 | if (offsets[i] != offsets[i - 1] + *sizemat_ptr * *stridemat_ptr) { |
268 | return false; |
269 | } |
270 | sizemat_ptr += orig_dim; |
271 | stridemat_ptr += orig_dim; |
272 | } |
273 | } |
// all checks passed: the nested tensor is contiguous
275 | return true; |
276 | } |
277 | |
278 | inline const at::Tensor& get_nested_size_tensor(const at::Tensor& tensor) { |
279 | return get_nested_tensor_impl(tensor)->get_nested_size_tensor(); |
280 | } |
281 | |
282 | } // namespace native |
283 | } // namespace at |
284 | |