#pragma once

#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/resize.h>
#endif

namespace at {
struct TORCH_API SparseTensorImpl : public TensorImpl {
  // Stored in COO format, indices + values.

  // INVARIANTS:
  // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
  // _values.shape: dimensionality: 1 + dense_dim. shape: (nnz,
  // shape[sparse_dim:])
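  //
  // Illustrative example (hypothetical shapes, not from any call site): a
  // sparse tensor of shape (4, 5, 2) with sparse_dim = 2, dense_dim = 1 and
  // nnz = 3 stores _indices of shape (2, 3) -- one column of sparse
  // coordinates per specified element -- and _values of shape (3, 2) -- one
  // dense slice of shape shape[sparse_dim:] == (2,) per specified element.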

  int64_t sparse_dim_ = 0; // number of sparse dimensions
  int64_t dense_dim_ = 0; // number of dense dimensions

  Tensor indices_; // always a LongTensor
  Tensor values_;

  // A sparse tensor is 'coalesced' if every index occurs at most once in
  // the indices tensor, and the indices are in sorted order. (This means
  // that it is very easy to convert a coalesced tensor to CSR format: you
  // need only compute CSR format indices.)
  //
  // Most math operations can only be performed on coalesced sparse tensors,
  // because many algorithms proceed by merging two sorted lists (of indices).
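  //
  // For intuition (illustrative values only): an uncoalesced 1-d sparse
  // tensor might hold indices [[0, 3, 0]] and values [1., 2., 3.]; after
  // coalescing, duplicate indices are merged (their values summed), leaving
  // indices [[0, 3]] and values [4., 2.] in sorted index order.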
  bool coalesced_ = false;

  // compute_numel with integer multiplication overflow check, see gh-57542
  void refresh_numel() {
    TensorImpl::safe_refresh_numel();
  }
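  // (For a sense of scale -- hypothetical sizes, not from any real call
  // site -- a shape like {1 << 32, 1 << 32} has numel 2^64, which overflows
  // int64_t; the safe variant errors out on such products rather than
  // silently wrapping the way a plain multiplication would.)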

 public:
  // Public for now...
  explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);

  void release_resources() override;

  int64_t nnz() const {
    return values_.size(0);
  }

  c10::SymInt sym_nnz() const {
    return values_.sym_size(0);
  }
  int64_t sparse_dim() const {
    return sparse_dim_;
  }
  int64_t dense_dim() const {
    return dense_dim_;
  }
  bool coalesced() const {
    return coalesced_;
  }
  Tensor indices() const {
    return indices_;
  }
  Tensor values() const {
    return values_;
  }

  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

#ifdef DEBUG
  bool has_storage() const override;
#endif

  // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
  // with respect to indices and values.
  void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "raw_resize_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "raw_resize_ called on tensor with symbolic shape");
    set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;
    refresh_numel();
  }

  // NOTE: This function preserves invariants of sparse_dim/dense_dim with
  // respect to indices and values.
  //
  // NOTE: This function supports the following cases:
  // 1. When we keep the number of dense dimensions unchanged, and do not
  // shrink the size of any of the dense dimensions.
  // 2. When we keep the number of sparse dimensions unchanged, and do not
  // shrink the size of any of the sparse dimensions.
  // 3. When the sparse tensor has zero nnz, in which case we are free to
  // change the shapes of both its sparse and dense dimensions.
  //
  // This function DOESN'T support (and will throw an error for) the
  // following cases:
  // 1. When we attempt to change the number of sparse dimensions on a
  // non-empty sparse tensor (such an operation will invalidate the indices
  // stored).
  // 2. When we attempt to change the number of dense dimensions on a
  // non-empty sparse tensor (such an operation will behave differently from
  // an equivalent dense tensor's resize method, and for API consistency we
  // don't support it).
  // 3. When we attempt to shrink the size of any of the dense dimensions on
  // a non-empty sparse tensor (such an operation will behave differently
  // from an equivalent dense tensor's resize method, and for API consistency
  // we don't support it).
  // 4. When we attempt to shrink the size of any of the sparse dimensions on
  // a non-empty sparse tensor (this could make some of the stored indices
  // out-of-bounds and thus unsafe).
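  //
  // A few concrete cases (purely illustrative shapes), assuming a non-empty
  // tensor with sparse_dim = 1, dense_dim = 1 and size (4, 3):
  //   - resize to (4, 5), sparse_dim = 1, dense_dim = 1: OK (the dense
  //     dimension grows, nothing shrinks);
  //   - resize to (6, 3), sparse_dim = 1, dense_dim = 1: OK (the sparse
  //     dimension grows, so the stored indices remain in bounds);
  //   - resize to (4, 2), sparse_dim = 1, dense_dim = 1: error (a dense
  //     dimension shrinks);
  //   - resize to (4, 3, 3), sparse_dim = 2, dense_dim = 1: error (the
  //     number of sparse dimensions changes).
  // With nnz == 0, all of the above would be accepted.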
  template <typename T>
  void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "resize_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "resize_ called on tensor with symbolic shape");
    TORCH_CHECK(
        sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
        "number of dimensions must be sparse_dim (",
        sparse_dim,
        ") + dense_dim (",
        dense_dim,
        "), but got ",
        size.size());
    if (nnz() > 0) {
      auto alt_options_msg =
          "You could try the following options:\n\
1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
2. If you need to resize this tensor, you have the following options:\n\
    1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
    2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";

      TORCH_CHECK(
          sparse_dim == sparse_dim_,
          "changing the number of sparse dimensions (from ",
          sparse_dim_,
          " to ",
          sparse_dim,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      TORCH_CHECK(
          dense_dim == dense_dim_,
          "changing the number of dense dimensions (from ",
          dense_dim_,
          " to ",
          dense_dim,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      bool shrinking_sparse_dims = false;
      bool shrinking_dense_dim = false;
      auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
      auto sparse_size_new = size.slice(0, sparse_dim);
      for (const auto i : c10::irange(sparse_dim)) {
        if (sparse_size_new[i] < sparse_size_original[i]) {
          shrinking_sparse_dims = true;
          break;
        }
      }
      auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
      auto dense_size_new = size.slice(sparse_dim);
      for (const auto i : c10::irange(dense_dim)) {
        if (dense_size_new[i] < dense_size_original[i]) {
          shrinking_dense_dim = true;
          break;
        }
      }

      TORCH_CHECK(
          !shrinking_sparse_dims,
          "shrinking the size of sparse dimensions (from ",
          sparse_size_original,
          " to ",
          sparse_size_new,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      TORCH_CHECK(
          !shrinking_dense_dim,
          "shrinking the size of dense dimensions (from ",
          dense_size_original,
          " to ",
          dense_size_new,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);
    }

    auto sizes_and_strides = generic_sizes<T>();
    const bool size_equals_sizes = std::equal(
        size.begin(),
        size.end(),
        sizes_and_strides.begin(),
        sizes_and_strides.end());
    if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
        (dense_dim != dense_dim_)) {
      auto nnz = at::symint::sizes<T>(values())[0];
      std::vector<T> values_size = {nnz};
      auto dense_size = size.slice(sparse_dim);
      values_size.insert(
          values_size.end(), dense_size.begin(), dense_size.end());
      at::symint::resize_<T>(values_, values_size);
      at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
    }

    if (!size_equals_sizes) {
      set_sizes_and_strides(size, std::vector<T>(size.size()));
    }
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;
    refresh_numel();
  }

  void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
    return _resize_(sparse_dim, dense_dim, size);
  }

  void resize_(
      int64_t sparse_dim,
      int64_t dense_dim,
      ArrayRef<c10::SymInt> size) {
    return _resize_(sparse_dim, dense_dim, size);
  }

  // NOTE: this function will resize the sparse tensor and also set `indices`
  // and `values` to empty.
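  //
  // For example (illustrative shapes only): resize_and_clear_(1, 1, {4, 3})
  // leaves the tensor with size (4, 3), an empty indices tensor of shape
  // (1, 0) and an empty values tensor of shape (0, 3), i.e. nnz == 0.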
  void resize_and_clear_(
      int64_t sparse_dim,
      int64_t dense_dim,
      IntArrayRef size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "resize_and_clear_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "resize_and_clear_ called on tensor with symbolic shape");
    TORCH_CHECK(
        sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
        "number of dimensions must be sparse_dim (",
        sparse_dim,
        ") + dense_dim (",
        dense_dim,
        "), but got ",
        size.size());

    set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;

    auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
    std::vector<int64_t> values_size = {0};
    auto dense_size = sizes().slice(sparse_dim);
    values_size.insert(
        values_size.end(), dense_size.begin(), dense_size.end());
    auto empty_values = at::empty(values_size, values().options());
    set_indices_and_values_unsafe(empty_indices, empty_values);
    refresh_numel();
  }

  void set_coalesced(bool coalesced) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "set_coalesced ",
        err_msg_tensor_metadata_change_not_allowed);
    coalesced_ = coalesced;
  }

  // NOTE: this function is only used internally and not exposed to the
  // Python frontend.
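  // For example (illustrative only): with nnz() == 5, set_nnz_and_narrow(3)
  // keeps the first 3 columns of indices_ (narrowed along dim 1) and the
  // first 3 rows of values_ (narrowed along dim 0); narrowing down to fewer
  // than 2 entries also marks the tensor as coalesced, since 0 or 1 entries
  // are trivially sorted and duplicate-free.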
  void set_nnz_and_narrow(int64_t new_nnz) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "set_nnz_and_narrow ",
        err_msg_tensor_metadata_change_not_allowed);
    AT_ASSERT(new_nnz <= nnz());
    indices_ = indices_.narrow(1, 0, new_nnz);
    values_ = values_.narrow(0, 0, new_nnz);
    if (new_nnz < 2) {
      coalesced_ = true;
    }
  }

  // Takes indices and values and directly puts them into the sparse tensor,
  // no copy.
  //
  // NOTE: this function is unsafe because it doesn't check whether any of
  // the indices are out of bounds of `sizes`, so it should ONLY be used
  // where we know that the indices are guaranteed to be within bounds.
  // This used to be called THSTensor_(_move).
  // NB: This used to be able to avoid a refcount bump, but I was too lazy to
  // make it happen.
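  //
  // A typical (illustrative) safe use is passing tensors derived from this
  // tensor's own already-valid indices and values (e.g. narrowed or permuted
  // views of them), where the bounds invariant is known to hold by
  // construction.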
  void set_indices_and_values_unsafe(
      const Tensor& indices,
      const Tensor& values);

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/version_counter,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/std::move(version_counter),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Shallow-copies data from another TensorImpl into this TensorImpl.
   *
   * For why this function doesn't check this TensorImpl's
   * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
   */
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
    auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
    copy_tensor_metadata(
        /*src_impl=*/sparse_impl,
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
    refresh_numel();
  }

 private:
  explicit SparseTensorImpl(
      at::DispatchKeySet,
      const caffe2::TypeMeta,
      at::Tensor indices,
      at::Tensor values);

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see
   * NOTE [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const SparseTensorImpl* src_sparse_impl,
      SparseTensorImpl* dest_sparse_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_sparse_impl,
        dest_sparse_impl,
        version_counter,
        allow_tensor_metadata_change);

    // Sparse-specific fields
    dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
    dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
    dest_sparse_impl->indices_ = src_sparse_impl->indices();
    dest_sparse_impl->values_ = src_sparse_impl->values();
    dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
  }

  const char* tensorimpl_type_name() const override;
};

} // namespace at