#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Exception.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Logging.h>
#include <c10/util/safe_numerics.h> // for c10::safe_multiplies_u64

#include <algorithm> // for std::min
#include <functional>
#include <limits> // for std::numeric_limits
#include <numeric>

namespace {
inline void validate_nested_tensor_metadata(
    const at::Tensor& nested_sizes,
    const at::Tensor& nested_strides,
    const std::vector<int64_t>& offsets) {
  TORCH_INTERNAL_ASSERT(nested_sizes.is_contiguous());
  int64_t size_dim = nested_sizes.dim();
  TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
  TORCH_INTERNAL_ASSERT(nested_strides.is_contiguous());
  TORCH_INTERNAL_ASSERT(nested_strides.dim() == size_dim);
  TORCH_INTERNAL_ASSERT(nested_sizes.sizes() == nested_strides.sizes());
  TORCH_INTERNAL_ASSERT(
      (size_dim == 0 && offsets.empty()) ||
      (size_dim == 2 && nested_sizes.size(0) == (int64_t)offsets.size()));
}
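
// Illustrative example: a nested tensor with two constituents of shapes
// (2, 3) and (4, 3) carries nested_sizes = [[2, 3], [4, 3]] (a 2 x 2 int64
// tensor), nested_strides of the same shape, and a 2-element offsets vector.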

/**
 * Generates a nested key_set from a non-nested tensor.
 *
 * When creating a nested tensor from a non-nested tensor, we want to
 * maintain the same key set as the buffer but swap the non-nested keys
 * for their nested counterparts.
 *
 * @return Appropriate key set for a nested tensor
 */
inline c10::DispatchKeySet generate_nested_key_set_from_buffer(
    const at::Tensor& buffer) {
  auto nested_key_set = buffer.key_set();
  const bool has_autograd = nested_key_set.has_any(c10::autograd_dispatch_keyset);
  // Remove non-nested tensor specific keys
  nested_key_set = nested_key_set -
      c10::DispatchKeySet{c10::DispatchKey::Dense, c10::DispatchKey::Autograd};

  // Add nested tensor specific keys
  nested_key_set =
      nested_key_set | c10::DispatchKeySet{c10::DispatchKey::NestedTensor};
  nested_key_set =
      has_autograd ? nested_key_set | c10::autograd_nested : nested_key_set;
  return nested_key_set;
}
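
// Illustrative sketch (the exact DispatchKeySet bit layout aside): a plain
// CPU tensor whose key set contains {CPU, AutogradCPU} maps to a nested key
// set containing {NestedTensorCPU, AutogradNestedTensor} -- the backend bit
// is preserved while the dense and autograd keys are swapped for their
// nested counterparts.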

/**
 * Generates the correct view key set.
 *
 * When creating a nested tensor view of a base tensor, the appropriate
 * key set depends on whether the base is itself nested.
 *
 * @return Appropriate key set for a nested tensor view
 */
c10::DispatchKeySet get_view_key_set(const at::Tensor& base) {
  return base.is_nested() ? base.key_set()
                          : generate_nested_key_set_from_buffer(base);
}

} // namespace

namespace at {
namespace native {

inline std::vector<int64_t> construct_opt_sizes(const at::Tensor& sizes) {
  // torch.tensor([]) is considered to have `dim() = 1` and `size(0) = 0`
  // torch.nested_tensor([]) should also have `dim() = 1` and `size(0) = 0`
  if (sizes.dim() == 0) {
    return std::vector<int64_t>({0});
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  std::vector<int64_t> result(1, sizes.sizes()[0]);
  if (sizes.dim() > 0) {
    size_t nested_dim = result.size();
    int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
    result.resize(nested_dim + sizes.sizes()[1]);
    int64_t sizes_size_0 = sizes.sizes()[0];
    int64_t sizes_size_1 = sizes.sizes()[1];
    for (const auto i : c10::irange(sizes_size_1)) {
      result[nested_dim + i] = sizes_ptr[i];
    }
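    // A per-tensor dimension is kept only when it is identical across all
    // constituent tensors; any mismatch below marks it as ragged with -1.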
    for (const auto j : c10::irange(sizes_size_1)) {
      for (const auto i : c10::irange(sizes_size_0)) {
        if (result[nested_dim + j] &&
            (result[nested_dim + j] != sizes_ptr[i * sizes_size_1 + j])) {
          result[nested_dim + j] = -1;
        }
      }
    }
  }
  return result;
}
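
// Illustrative example: for nested sizes [[2, 3], [4, 3]] the result is
// {2, -1, 3} -- two constituent tensors, a ragged first dimension, and a
// consistent trailing dimension of 3.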

// Assuming contiguity, the stride tensor can be constructed from the size
// tensor.
inline at::Tensor construct_nested_stride_tensor(const at::Tensor& sizes) {
  // empty `sizes` means empty nested tensor, so return empty strides
  if (sizes.dim() == 0) {
    return sizes;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  int64_t orig_dim = sizes.size(1);
  // `sizes`.sizes() = ntensors x 0 means empty but shaped `sizes`;
  // in this case the strides are also empty but shaped
  if (orig_dim == 0) {
    return sizes;
  }
  at::Tensor strides = sizes.new_empty(sizes.sizes());
  const int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
  int64_t* strides_ptr = strides.data_ptr<int64_t>();
  for (int64_t i = 0; i < sizes.size(0); i++) {
    strides_ptr[orig_dim - 1] = 1;
    int64_t product = sizes_ptr[orig_dim - 1];
    for (int64_t j = orig_dim - 2; j >= 0; j--) {
      strides_ptr[j] = product;
      product *= sizes_ptr[j];
    }
    sizes_ptr += orig_dim;
    strides_ptr += orig_dim;
  }
  return strides;
}
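
// Illustrative example: for sizes [[2, 3], [4, 5]] the contiguous strides
// are [[3, 1], [5, 1]] -- each row is the standard row-major stride of the
// corresponding constituent tensor.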

/**
 * Create a vector of offsets assuming the nested tensor is contiguous.
 *
 * This function iterates over the implicit ntensor outer dimension,
 * populating a vector with the element offset of each implicit tensor
 * within the contiguous buffer. The first element is always 0 and the
 * length of the returned vector is ntensor.
 *
 * @return A vector of offsets
 */
inline std::vector<int64_t> construct_offsets(const at::Tensor& sizes) {
  // empty `sizes` means empty nested tensor, so return empty offsets
  if (sizes.dim() == 0) {
    return std::vector<int64_t>();
  }
  int64_t ntensors = sizes.size(0), orig_dim = sizes.size(1);
  std::vector<int64_t> offsets(ntensors);
  // nested scalars: each constituent occupies one element, so the offsets
  // are simply 0, 1, 2, ...
  if (orig_dim == 0) {
    std::iota(offsets.begin(), offsets.end(), 0);
    return offsets;
  }
  const int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
  offsets[0] = 0;
  for (const auto i : c10::irange(ntensors - 1)) {
    const int64_t row_product = std::accumulate(
        sizes_ptr, sizes_ptr + orig_dim, 1, std::multiplies<int64_t>());
    offsets[i + 1] = offsets[i] + row_product;
    sizes_ptr += orig_dim;
  }
  return offsets;
}
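
// Illustrative example: for sizes [[2, 3], [4, 5]] the offsets are {0, 6};
// the second constituent tensor starts 2 * 3 = 6 elements into the buffer.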

NestedTensorImpl::NestedTensorImpl(
    Storage storage,
    c10::DispatchKeySet key_set,
    const caffe2::TypeMeta data_type,
    at::Tensor nested_size_tensor,
    at::Tensor nested_stride_tensor,
    std::vector<int64_t>&& offsets)
    : TensorImpl(std::move(storage), key_set, data_type),
      nested_size_tensor_(std::move(nested_size_tensor)),
      nested_stride_tensor_(std::move(nested_stride_tensor)),
      storage_offsets_(std::move(offsets)),
      opt_sizes_(construct_opt_sizes(nested_size_tensor_)) {
  C10_LOG_API_USAGE_ONCE("torch.NestedTensor");
  TORCH_WARN_ONCE(
      "The PyTorch API of nested tensors is in prototype stage and will change "
      "in the near future.");
  auto storage_device = storage_.device();
  TORCH_INTERNAL_ASSERT(
      storage_device.is_cpu() || storage_device.is_cuda(),
      "NestedTensorImpl storage must be either CUDA or CPU but got ",
      storage_device);
  validate_nested_tensor_metadata(
      nested_size_tensor_, nested_stride_tensor_, storage_offsets_);
  refresh_dim();
  set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}

NestedTensorImpl::NestedTensorImpl(
    at::Tensor buffer,
    at::Tensor nested_size_tensor,
    at::Tensor nested_stride_tensor,
    std::vector<int64_t>&& offsets)
    : NestedTensorImpl(
          buffer.storage(),
          generate_nested_key_set_from_buffer(buffer),
          buffer.dtype(),
          nested_size_tensor,
          nested_stride_tensor,
          std::move(offsets)) {
  TORCH_INTERNAL_ASSERT(
      buffer.dim() == 1,
      "NestedTensorImpl buffer is required to be 1 dimensional but got a buffer with ",
      buffer.dim(),
      " dimensions.");
}

// Assuming contiguity, `nested_stride_tensor` and `offsets`
// can be inferred from `nested_size_tensor`.
NestedTensorImpl::NestedTensorImpl(
    at::Tensor buffer,
    at::Tensor nested_size_tensor)
    : NestedTensorImpl(
          buffer,
          nested_size_tensor,
          construct_nested_stride_tensor(nested_size_tensor),
          construct_offsets(nested_size_tensor)) {}
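
// Illustrative sketch (hypothetical values, not part of this file's API): a
// contiguous nested tensor holding two tensors of shapes (2, 3) and (4, 3)
// can be built from a flat 1-D buffer of 6 + 12 = 18 elements plus the 2 x 2
// size tensor [[2, 3], [4, 3]]; the strides [[3, 1], [3, 1]] and offsets
// {0, 6} are then derived by the helpers above.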

NestedTensorImpl::NestedTensorImpl(
    c10::TensorImpl::ImplType impl_type,
    const at::Tensor& base_tensor,
    at::Tensor nested_size_tensor,
    at::Tensor nested_stride_tensor,
    std::vector<int64_t>&& offsets)
    : TensorImpl(
          impl_type,
          Storage(base_tensor.storage()),
          get_view_key_set(base_tensor),
          base_tensor.dtype()),
      nested_size_tensor_(std::move(nested_size_tensor)),
      nested_stride_tensor_(std::move(nested_stride_tensor)),
      storage_offsets_(std::move(offsets)),
      opt_sizes_(construct_opt_sizes(nested_size_tensor_)) {
  validate_nested_tensor_metadata(
      nested_size_tensor_, nested_stride_tensor_, storage_offsets_);
  refresh_dim();
  set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}

void NestedTensorImpl::refresh_dim() {
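  // A non-empty nested tensor has one implicit outer "ntensor" dimension on
  // top of the per-tensor dimensionality recorded in `nested_size_tensor_`.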
  const auto my_dim =
      nested_size_tensor_.dim() ? nested_size_tensor_.sizes()[1] + 1 : 1;
  sizes_and_strides_.resize(my_dim);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim() == my_dim);
}

int64_t NestedTensorImpl::dim_custom() const {
  return dim_default();
}

// Currently sizes and strides assume a contiguous layout
int64_t NestedTensorImpl::numel_custom() const {
  if (nested_size_tensor_.dim() == 0) {
    return 0;
  }
  constexpr auto numel_max = std::min(
      static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
      static_cast<uint64_t>(std::numeric_limits<size_t>::max()));

  const auto nt_dim = nested_size_tensor_.size(1);
  const int64_t* sizes_ptr = nested_size_tensor_.data_ptr<int64_t>();
  uint64_t num_elements{0};

  for (const auto i : c10::irange(nested_size_tensor_.size(0))) {
    uint64_t n = 1;
    const auto start{sizes_ptr + i * nt_dim};
    const auto end{start + nt_dim};
    bool overflows = c10::safe_multiplies_u64(start, end, &n);
    num_elements += n;
    overflows |= (num_elements > numel_max);
    TORCH_CHECK(!overflows, "numel: integer multiplication overflow");
  }
  return static_cast<int64_t>(num_elements);
}
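
// Illustrative example: with nested sizes [[2, 3], [4, 5]], numel_custom()
// returns 2 * 3 + 4 * 5 = 26.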

c10::SymInt NestedTensorImpl::sym_numel_custom() const {
  return NestedTensorImpl::numel_custom();
}

bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const {
  return nested_tensor_impl_is_contiguous(this);
}

IntArrayRef NestedTensorImpl::sizes_custom() const {
  TORCH_CHECK(
      false,
      "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}

c10::SymIntArrayRef NestedTensorImpl::sym_sizes_custom() const {
  TORCH_CHECK(
      false,
      "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}

c10::SymIntArrayRef NestedTensorImpl::sym_strides_custom() const {
  TORCH_CHECK(
      false,
      "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
}

IntArrayRef NestedTensorImpl::strides_custom() const {
  TORCH_CHECK(
      false,
      "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
}

const char* NestedTensorImpl::tensorimpl_type_name() const {
  return "NestedTensorImpl";
}

template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach_core(
    VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_slot_.load_pyobj_interpreter()->detach(this);
    if (r) {
      r->set_version_counter(std::forward<VariableVersion>(version_counter));
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
    // otherwise just copy the TensorImpl and not the PyObject. Since
    // the interpreter is dead, no one can call us out on it.
  }
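  // The detached copy shares the storage and the nested size/stride/offset
  // metadata with this impl; only the version counter and the
  // metadata-mutability flag are replaced below.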
  auto impl = c10::make_intrusive<NestedTensorImpl>(
      storage_,
      key_set_,
      data_type_,
      nested_size_tensor_,
      nested_stride_tensor_,
      std::vector<int64_t>(storage_offsets_));

  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      version_counter, allow_tensor_metadata_change);
}

c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      std::move(version_counter), allow_tensor_metadata_change);
}

} // namespace native
} // namespace at