1 | #pragma once |
2 | |
3 | #include <ATen/Tensor.h> |
4 | #include <c10/core/TensorImpl.h> |
5 | #include <c10/util/Exception.h> |
6 | #include <c10/util/irange.h> |
7 | |
8 | #ifndef AT_PER_OPERATOR_HEADERS |
9 | #include <ATen/Functions.h> |
10 | #else |
11 | #include <ATen/ops/empty.h> |
12 | #include <ATen/ops/resize.h> |
13 | #endif |
14 | |
15 | namespace at { |
struct TORCH_API SparseTensorImpl : public TensorImpl {
  // Stored in COO format, indices + values.

  // INVARIANTS:
  // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // _indices.shape: dimensionality: 2,  shape: (sparse_dim, nnz)
  // _values.shape:  dimensionality: 1 + dense_dim.  shape: (nnz,
  // shape[sparse_dim:])

  int64_t sparse_dim_ = 0; // number of sparse dimensions
  int64_t dense_dim_ = 0; // number of dense dimensions

  Tensor indices_; // always a LongTensor
  Tensor values_;

  // A sparse tensor is 'coalesced' if every index occurs at most once in
  // the indices tensor, and the indices are in sorted order.  (This means
  // that it is very easy to convert a coalesced tensor to CSR format: you
  // need only compute CSR format indices.)
  //
  // Most math operations can only be performed on coalesced sparse tensors,
  // because many algorithms proceed by merging two sorted lists (of indices).
  bool coalesced_ = false;

  // compute_numel with integer multiplication overflow check, see gh-57542
  void refresh_numel() {
    TensorImpl::safe_refresh_numel();
  }

 public:
  // Public for now...
  explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);

  void release_resources() override;

  // Number of specified (stored) elements.  By the invariant above this is
  // the length of dim 0 of values_ (and dim 1 of indices_).  Note: if the
  // tensor is not coalesced, duplicate indices may be present, so nnz() can
  // exceed the number of *distinct* specified elements.
  int64_t nnz() const {
    return values_.size(0);
  }

  // Symbolic-int flavor of nnz(), for use under symbolic shape tracing.
  c10::SymInt sym_nnz() const {
    return values_.sym_size(0);
  }
  int64_t sparse_dim() const {
    return sparse_dim_;
  }
  int64_t dense_dim() const {
    return dense_dim_;
  }
  bool coalesced() const {
    return coalesced_;
  }
  Tensor indices() const {
    return indices_;
  }
  Tensor values() const {
    return values_;
  }

  // Dense-tensor metadata mutators are overridden; presumably they TORCH_CHECK
  // (sparse tensors have no strides/storage) -- implementation is out of view.
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

#ifdef DEBUG
  bool has_storage() const override;
#endif

  // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
  // with respect to indices and values: indices_/values_ are left untouched,
  // only the shape metadata and the sparse/dense split are updated.
  void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "raw_resize_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "raw_resize_ called on tensor with symbolic shape")
    // Sparse tensors carry no meaningful strides; a value-initialized vector
    // gives all-zero strides of the right rank.
    set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;
    refresh_numel();
  }

  // NOTE: This function preserves invariants of sparse_dim/dense_dim with
  // respect to indices and values.
  //
  // NOTE: This function supports the following cases:
  // 1. When we keep the number of dense dimensions unchanged, and NOT shrinking
  // the size of any of the dense dimensions.
  // 2. When we keep the number of sparse dimensions unchanged, and NOT
  // shrinking the size of any of the sparse dimensions.
  // 3. When the sparse tensor has zero nnz, in which case we are free to change
  // the shapes of both its sparse and dense dimensions.
  //
  // This function DOESN'T support (and will throw an error) the following
  // cases:
  // 1. When we attempt to change the number of sparse dimensions on a non-empty
  // sparse tensor (such an operation will invalidate the indices stored).
  // 2. When we attempt to change the number of dense dimensions on a non-empty
  // sparse tensor (such an operation will behave differently from an equivalent
  // dense tensor's resize method, and for API consistency we don't support it).
  // 3. When we attempt to shrink the size of any of the dense dimensions on a
  // non-empty sparse tensor (such an operation will behave differently from an
  // equivalent dense tensor's resize method, and for API consistency we don't
  // support it).
  // 4. When we attempt to shrink the size of any of the sparse dimensions on a
  // non-empty sparse tensor (this could make some of the stored indices
  // out-of-bound and thus unsafe).
  //
  // T is the integer type of `size`'s elements: int64_t for the concrete
  // overload, c10::SymInt for the symbolic one (see resize_ below).
  template <typename T>
  void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "resize_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "resize_ called on tensor with symbolic shape")
    TORCH_CHECK(
        sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
        "number of dimensions must be sparse_dim (",
        sparse_dim,
        ") + dense_dim (",
        dense_dim,
        "), but got ",
        size.size());
    if (nnz() > 0) {
      auto alt_options_msg =
          "You could try the following options:\n\
1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
2. If you need to resize this tensor, you have the following options:\n\
    1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
    2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";

      TORCH_CHECK(
          sparse_dim == sparse_dim_,
          "changing the number of sparse dimensions (from ",
          sparse_dim_,
          " to ",
          sparse_dim,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      TORCH_CHECK(
          dense_dim == dense_dim_,
          "changing the number of dense dimensions (from ",
          dense_dim_,
          " to ",
          dense_dim,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      // Reject any per-dimension shrink: a smaller sparse extent could leave
      // stored indices out of bounds, and a smaller dense extent would diverge
      // from dense resize_ semantics (see the NOTE above).
      bool shrinking_sparse_dims = false;
      bool shrinking_dense_dim = false;
      auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
      auto sparse_size_new = size.slice(0, sparse_dim);
      for (const auto i : c10::irange(sparse_dim)) {
        if (sparse_size_new[i] < sparse_size_original[i]) {
          shrinking_sparse_dims = true;
          break;
        }
      }
      auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
      auto dense_size_new = size.slice(sparse_dim);
      for (const auto i : c10::irange(dense_dim)) {
        if (dense_size_new[i] < dense_size_original[i]) {
          shrinking_dense_dim = true;
          break;
        }
      }

      TORCH_CHECK(
          !shrinking_sparse_dims,
          "shrinking the size of sparse dimensions (from ",
          sparse_size_original,
          " to ",
          sparse_size_new,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      TORCH_CHECK(
          !shrinking_dense_dim,
          "shrinking the size of dense dimensions (from ",
          dense_size_original,
          " to ",
          dense_size_new,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);
    }

    auto sizes_and_strides = generic_sizes<T>();
    const bool size_equals_sizes = std::equal(
        size.begin(),
        size.end(),
        sizes_and_strides.begin(),
        sizes_and_strides.end());
    // Keep indices_/values_ shapes consistent with the new metadata: values
    // gets (nnz, new dense dims), indices gets (sparse_dim, nnz).  nnz is
    // preserved; resize_ only grows/keeps the trailing dense extents here
    // (shrinks were rejected above for nnz > 0).
    if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
        (dense_dim != dense_dim_)) {
      auto nnz = at::symint::sizes<T>(values())[0];
      std::vector<T> values_size = {nnz};
      auto dense_size = size.slice(sparse_dim);
      values_size.insert(
          values_size.end(), dense_size.begin(), dense_size.end());
      at::symint::resize_<T>(values_, values_size);
      at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
    }

    if (!size_equals_sizes) {
      // All-zero strides of the right rank (sparse tensors have no strides).
      set_sizes_and_strides(size, std::vector<T>(size.size()));
    }
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;
    refresh_numel();
  }

  // Concrete-shape entry point; see _resize_ for the full contract.
  void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
    return _resize_(sparse_dim, dense_dim, size);
  }

  // Symbolic-shape entry point; see _resize_ for the full contract.
  void resize_(
      int64_t sparse_dim,
      int64_t dense_dim,
      ArrayRef<c10::SymInt> size) {
    return _resize_(sparse_dim, dense_dim, size);
  }

  // NOTE: this function will resize the sparse tensor and also set `indices`
  // and `values` to empty (nnz becomes 0), so unlike resize_ it places no
  // restrictions on the new sparse/dense split or per-dimension sizes.
  void resize_and_clear_(
      int64_t sparse_dim,
      int64_t dense_dim,
      IntArrayRef size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "resize_and_clear_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "resize_and_clear_ called on tensor with symbolic shape")
    TORCH_CHECK(
        sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
        "number of dimensions must be sparse_dim (",
        sparse_dim,
        ") + dense_dim (",
        dense_dim,
        "), but got ",
        size.size());

    set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;

    // Fresh (sparse_dim, 0) indices and (0, dense dims...) values, preserving
    // the existing dtype/device options of the old tensors.
    auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
    std::vector<int64_t> values_size = {0};
    auto dense_size = sizes().slice(sparse_dim);
    values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
    auto empty_values = at::empty(values_size, values().options());
    set_indices_and_values_unsafe(empty_indices, empty_values);
    refresh_numel();
  }

  void set_coalesced(bool coalesced) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "set_coalesced ",
        err_msg_tensor_metadata_change_not_allowed);
    coalesced_ = coalesced;
  }

  // NOTE: this function is only used internally and not exposed to Python
  // frontend.  Truncates both indices_ and values_ to the first new_nnz
  // entries (new_nnz must not exceed the current nnz).
  void set_nnz_and_narrow(int64_t new_nnz) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "set_nnz_and_narrow ",
        err_msg_tensor_metadata_change_not_allowed);
    AT_ASSERT(new_nnz <= nnz());
    indices_ = indices_.narrow(1, 0, new_nnz);
    values_ = values_.narrow(0, 0, new_nnz);
    // With 0 or 1 stored elements there can be no duplicate or out-of-order
    // indices, so the result is trivially coalesced.
    if (new_nnz < 2) {
      coalesced_ = true;
    }
  }

  // Takes indices and values and directly puts them into the sparse tensor, no
  // copy. NOTE: this function is unsafe because it doesn't check whether any
  // indices are out of boundaries of `sizes`, so it should ONLY be used where
  // we know that the indices are guaranteed to be within bounds. This used to
  // be called THSTensor_(_move) NB: This used to be able to avoid a refcount
  // bump, but I was too lazy to make it happen
  void set_indices_and_values_unsafe(
      const Tensor& indices,
      const Tensor& values);

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/version_counter,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/std::move(version_counter),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Shallow-copies data from another TensorImpl into this TensorImpl.
   *
   * For why this function doesn't check this TensorImpl's
   * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
   */
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
    // Safe: the key-set compatibility assert above guarantees impl is a
    // SparseTensorImpl, so a static_cast suffices (no RTTI needed).
    auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
    copy_tensor_metadata(
        /*src_impl=*/sparse_impl,
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
    refresh_numel();
  }

 private:
  explicit SparseTensorImpl(
      at::DispatchKeySet,
      const caffe2::TypeMeta,
      at::Tensor indices,
      at::Tensor values);

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
   * [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const SparseTensorImpl* src_sparse_impl,
      SparseTensorImpl* dest_sparse_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_sparse_impl,
        dest_sparse_impl,
        version_counter,
        allow_tensor_metadata_change);

    // Sparse-specific fields.  indices_/values_ are shared (shallow copy),
    // not cloned.
    dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
    dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
    dest_sparse_impl->indices_ = src_sparse_impl->indices();
    dest_sparse_impl->values_ = src_sparse_impl->values();
    dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
  }

  const char* tensorimpl_type_name() const override;
};
399 | |
400 | } // namespace at |
401 | |