// Basic functions on sparse tensors
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/Layout.h>
#include <ATen/Parallel.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/LinearAlgebraUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_nnz_native.h>
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_csc_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_bsr_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_bsc_tensor_args_native.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/ccol_indices_native.h>
#include <ATen/ops/clone_native.h>
#include <ATen/ops/col_indices_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/crow_indices_native.h>
#include <ATen/ops/dense_dim_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/row_indices_native.h>
#include <ATen/ops/select_native.h>
#include <ATen/ops/select_copy.h>
#include <ATen/ops/select_copy_native.h>
#include <ATen/ops/sparse_compressed_tensor_native.h>
#include <ATen/ops/sparse_csr_tensor_native.h>
#include <ATen/ops/sparse_csc_tensor_native.h>
#include <ATen/ops/sparse_bsr_tensor_native.h>
#include <ATen/ops/sparse_bsc_tensor_native.h>
#include <ATen/ops/sparse_dim_native.h>
#include <ATen/ops/values_native.h>
#include <ATen/ops/_validate_compressed_sparse_indices.h>
#include <ATen/ops/where.h>
#endif

namespace at {
namespace native {

using namespace at::sparse_csr;

namespace {

bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& step) {
  /*
    This function solves the equation

      input == arange(start, end, step)

    for integers start, end, and step, if possible. If the solution
    exists, returns true.
  */
  int64_t n = input.numel();
  if (n == 0) {
    // a trivial solution
    start = end = 0;
    step = 1;
  } else if (n == 1) {
    // a simple solution
    start = input[0].item<int64_t>();
    end = start + 1;
    step = 1;
  } else {
    Tensor first_last = input.slice(0, 0, n, n - 1).cpu();
    int64_t start_candidate = first_last[0].item<int64_t>();
    int64_t end_candidate = first_last[1].item<int64_t>() + 1;
    if (end_candidate - start_candidate == n) {
      // a special solution
      start = start_candidate;
      end = end_candidate;
      step = 1;
    } else {
      // detect if a general solution exists
      Tensor possible_steps = input.slice(0, 1).sub(input.slice(0, 0, n - 1));
      Tensor possible_step = possible_steps[0];
      if ((possible_steps.eq(possible_step)).all().item<bool>()) {
        start = start_candidate;
        end = end_candidate;
        step = possible_step.item<int64_t>();
      } else {
        // no solution
        return false;
      }
    }
  }
  return true;
}
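
// A minimal usage sketch (hypothetical inputs, for illustration only):
//
//   int64_t start, end, step;
//   // input == tensor([2, 4, 6]): succeeds with start=2, end=7, step=2,
//   // since arange(2, 7, 2) reproduces the input.
//   // input == tensor([0, 1, 3]): the consecutive differences differ
//   // (1 vs 2), so solve_arange returns false.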

} // end anonymous namespace

/*
  Validate the arguments to sparse compressed (CSR, CSC, BSR, and BSC)
  tensor factory functions.

  The CSR and BSR invariants for PyTorch are outlined in

    https://pearu.github.io/csr_tensor_invariants.html
    https://pearu.github.io/bsr_tensor_invariants.html

  and are generalized in what follows to all sparse compressed
  formats, with support for batched and dense dimensions.
*/
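
// Illustrative example of arguments that satisfy the invariants checked
// below (hypothetical data, not taken from the library): for a 3 x 4 CSR
// tensor holding four specified elements,
//
//   compressed_indices (crow_indices) = [0, 2, 3, 4]   // shape (nrows + 1,)
//   plain_indices      (col_indices)  = [0, 2, 1, 3]   // shape (nnz,)
//   values                            = [1., 2., 3., 4.]
//   size                              = (3, 4)
//
// passes the worker with layout kSparseCsr, whereas e.g.
// crow_indices = [0, 2, 3, 5] would fail because crow_indices[-1] must
// equal nnz.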

void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) {
  // Layout must be Sparse Compressed, 2.4
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{});

  const std::string layout_name = layoutToString(layout, /*upper=*/ true);
  const std::string compressed_indices_name = compressedIndicesName(layout);
  const std::string plain_indices_name = plainIndicesName(layout);
  const std::string compressed_dim_name = compressedDimName(layout);
  const std::string plain_dim_name = plainDimName(layout);

  // Layout Invariants

  // Re 3.5 and 3.6: in the case of compressed/plain indices tensors,
  // we require contiguity on a per-batch basis, that is, the last
  // stride of these indices must be 1. The reasoning for this is that
  // indices tensors within a batch are "atomic" in the sense that
  // sliced compressed/plain indices would not represent the indices
  // of any sparse compressed tensor as the slicing would break the
  // description of the tensor index structure.

  // 2.1
  TORCH_CHECK(plain_indices.layout() == kStrided,
              "expected ", plain_indices_name, " to be a strided tensor but got ", plain_indices.layout(), " tensor");

  // 2.2
  TORCH_CHECK(compressed_indices.layout() == kStrided,
              "expected ", compressed_indices_name, " to be a strided tensor but got ", compressed_indices.layout(), " tensor");

  const int base_ndim = 2;  // corresponds to compressed and plain indices
  const int batch_ndim = compressed_indices.dim() - 1;
  const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
      layout, "validate_sparse_compressed_tensor_args",
      [&] { return 0; }, [&] { return 2; });
  const int dense_ndim = values.dim() - batch_ndim - block_ndim - 1;

  // 2.3
  TORCH_CHECK(values.layout() == kStrided,
              "expected values to be a strided tensor but got ", values.layout(), " tensor");

  // 3.7 is dropped, that is, the values tensor does not need to be
  // contiguous in general. Particular algorithms on sparse
  // compressed tensors may require contiguity, though.

  // Shape and Strides invariants

  // 3.2
  TORCH_CHECK(
      batch_ndim >= 0,
      compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim());

  // 3.3
  TORCH_CHECK(
      compressed_indices.dim() == plain_indices.dim(),
      compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ",
      compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");

  // 3.4
  TORCH_CHECK(
      dense_ndim >= 0,
      "values must have dimensionality > sum of batch and block dimensionalities (=",
      batch_ndim, " + ", block_ndim, ") but got ", values.dim());

  // 3.5
  TORCH_CHECK(plain_indices.stride(-1) == 1,
              "expected ", plain_indices_name, " to be a contiguous tensor per batch");

  // 3.6
  TORCH_CHECK(compressed_indices.stride(-1) == 1,
              "expected ", compressed_indices_name, " to be a contiguous tensor per batch");

  // 3.1
  TORCH_CHECK(
      static_cast<int>(size.size()) == batch_ndim + base_ndim + dense_ndim,
      "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=",
      batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size());

  // For CSR/CSC formats, we define blocksize=(1, 1) so that checking
  // the sparse compressed tensor invariants can be unified with the
  // BSR/BSC invariants.
  // 3.10
  DimVector blocksize{
      (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 1)) : 1),
      (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 2)) : 1),
  };
  TORCH_INTERNAL_ASSERT(blocksize.size() == 2 && blocksize[0] > 0 && blocksize[1] > 0);

  // All batch sizes must be the same and consistent with the tensor batchsize, 3.1, 3.8, 3.9, 3.10
  DimVector batchsize = DimVector(size.slice(0, batch_ndim));
  DimVector compressed_indices_batchsize = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
  DimVector plain_indices_batchsize = DimVector(plain_indices.sizes().slice(0, batch_ndim));
  DimVector values_batchsize = DimVector(values.sizes().slice(0, batch_ndim));
  const int values_nnz = (values.numel() ? values.size(batch_ndim) : 0);
  DimVector values_blocksize = DimVector(values.sizes().slice(batch_ndim + 1, block_ndim));
  DimVector values_densesize = DimVector(values.sizes().slice(batch_ndim + 1 + block_ndim, dense_ndim));
  TORCH_CHECK(
      batchsize == compressed_indices_batchsize && batchsize == plain_indices_batchsize && batchsize == values_batchsize,
      "all batch dimensions of ", compressed_indices_name, " (=", compressed_indices_batchsize, "), ", plain_indices_name, " (=",
      plain_indices_batchsize, "), and values (=", values_batchsize, ") must be equal to tensor batch dimensions (=",
      batchsize, ")");

  // A tensor consists of full blocks, 3.1
  for (int i = 0; i < block_ndim; i++) {
    TORCH_CHECK(size[batch_ndim + i] % blocksize[i] == 0,
                "tensor shape[", batch_ndim + i, "] (=", size[batch_ndim + i],
                ") must be divisible by blocksize[", i, "] (=", blocksize[i],
                ") as defined by the values shape");
  }
  const int nrows = size[batch_ndim] / blocksize[0];
  const int ncols = size[batch_ndim + 1] / blocksize[1];
  int compressed_dim_size, plain_dim_size;
  std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
      [&] { return std::make_tuple(nrows, ncols); },
      [&] { return std::make_tuple(ncols, nrows); });
  // 3.8
  TORCH_CHECK(
      compressed_indices.size(-1) == compressed_dim_size + 1,
      compressed_indices_name, ".shape[-1] must be equal to the number of ",
      compressed_dim_name, "s + 1 (=", compressed_dim_size + 1, "), but got ", compressed_indices.size(-1));
  // 3.9, 3.10
  TORCH_CHECK(
      plain_indices.size(-1) == values_nnz,
      plain_indices_name, ".shape[-1] must be equal to nnz (=", values_nnz,
      ") as defined by values.shape[", batch_ndim, "], but got ", plain_indices.size(-1));
  // Type Invariants
  auto compressed_indices_type = compressed_indices.scalar_type();
  auto plain_indices_type = plain_indices.scalar_type();
  // 1.1, 1.2, 1.3
  TORCH_CHECK(
      compressed_indices_type == plain_indices_type,
      compressed_indices_name, " and ", plain_indices_name, " must have the same dtype, but got ",
      compressed_indices_type, " and ", plain_indices_type, ", respectively");
  TORCH_CHECK(
      compressed_indices_type == kInt || compressed_indices_type == kLong,
      compressed_indices_name, " and ", plain_indices_name, " dtype must be Int or Long, but got ",
      compressed_indices_type);

  // Indices invariants
  if (plain_indices.numel() > 0) {
    at::_validate_compressed_sparse_indices(
        /*is_crow = */layout == kSparseCsr || layout == kSparseBsr,
        compressed_indices,
        plain_indices,
        compressed_dim_size,
        plain_dim_size,
        values_nnz);
  }

  // Device Invariants
  // 4.1
  TORCH_CHECK(
      values.device().type() == kCPU || values.device().type() == kCUDA,
      "device type of values (",
      values.device().type(),
      ") must be CPU or CUDA");
  // 4.2, 4.3, 4.4
  TORCH_CHECK(
      compressed_indices.get_device() == values.get_device(),
      "device of ", compressed_indices_name, " (=",
      compressed_indices.device(),
      ") must match device of values (=",
      values.device(),
      ")");
  TORCH_CHECK(
      compressed_indices.get_device() == plain_indices.get_device(),
      "device of ", compressed_indices_name, " (=",
      compressed_indices.device(),
      ") must match device of ", plain_indices_name, " (=",
      plain_indices.device(),
      ")");
}

void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) {
  _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout);
}

void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
  _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
}

void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
  _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc);
}

void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
  _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr);
}

void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
  _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc);
}

// Construction of CSR, CSC, BSR, and BSC tensors.

// Note: "Csr" appears in names like SparseCsrTensor, SparseCsrCPU,
// SparseCsrCUDA, and SparseCsrTensorImpl for historical reasons (which
// ought to be cleaned up in the future); it does not mean that the
// corresponding functionality is specific to the CSR layout.
SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
  // TODO: remove this comment after enabling autograd support for CSR tensor
  // constructor.
  // TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
  Layout layout = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(options.layout(), "new_compressed_tensor", [&] { return the_layout; });
  DispatchKey dispatch_key;

  TORCH_CHECK_NOT_IMPLEMENTED(
      options.device().type() == kCPU || options.device().type() == kCUDA,
      "Could not run 'new_compressed_tensor' from the '", options.device(), "' device.");

  if (options.device().is_cuda()) {
    dispatch_key = DispatchKey::SparseCsrCUDA;
  } else {
    dispatch_key = DispatchKey::SparseCsrCPU;
  }

  return detail::make_tensor<SparseCsrTensorImpl>(DispatchKeySet(dispatch_key), options.device(), layout, options.dtype());
}


Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices,
                                        const Tensor& plain_indices,
                                        const Tensor& values,
                                        IntArrayRef size,
                                        c10::optional<ScalarType> dtype,
                                        c10::optional<Layout> layout,
                                        c10::optional<Device> device,
                                        c10::optional<bool> pin_memory) {
  if (!layout) {
    AT_ERROR("sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none");
  }
  Layout layout_ = layout.value();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
  if (at::globalContext().checkSparseTensorInvariants()) {
    _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
  }
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
  SparseCsrTensor self = new_compressed_tensor(options);
  get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
  return self;
}

template <Layout required_layout>
Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices,
                                                 const Tensor& plain_indices,
                                                 const Tensor& values,
                                                 IntArrayRef size,
                                                 c10::optional<ScalarType> dtype,
                                                 c10::optional<Layout> layout,
                                                 c10::optional<Device> device,
                                                 c10::optional<bool> pin_memory) {
  Layout layout_ = layout.value_or(required_layout);
  TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ", required_layout, " but got ", layout_);
  if (at::globalContext().checkSparseTensorInvariants()) {
    _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
  }
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
  SparseCsrTensor self = new_compressed_tensor(options);
  get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
  return self;
}

#define SPARSE_COMPRESSED_TENSOR_UNSAFE(KIND, REQUIRED_LAYOUT)            \
  Tensor _sparse_##KIND##_tensor_unsafe(const Tensor& compressed_indices, \
                                        const Tensor& plain_indices,      \
                                        const Tensor& values,             \
                                        IntArrayRef size,                 \
                                        c10::optional<ScalarType> dtype,  \
                                        c10::optional<Layout> layout,     \
                                        c10::optional<Device> device,     \
                                        c10::optional<bool> pin_memory) { \
    return _sparse_compressed_tensor_unsafe_template<REQUIRED_LAYOUT>(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \
  }

SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
SPARSE_COMPRESSED_TENSOR_UNSAFE(csc, kSparseCsc);
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsr, kSparseBsr);
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsc, kSparseBsc);
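
// For illustration, SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr) above
// expands (mechanically) to a function
//
//   Tensor _sparse_csr_tensor_unsafe(const Tensor& compressed_indices,
//                                    const Tensor& plain_indices,
//                                    const Tensor& values,
//                                    IntArrayRef size,
//                                    c10::optional<ScalarType> dtype,
//                                    c10::optional<Layout> layout,
//                                    c10::optional<Device> device,
//                                    c10::optional<bool> pin_memory)
//
// that forwards to _sparse_compressed_tensor_unsafe_template<kSparseCsr>, so
// a mismatching layout argument (e.g. kSparseCsc) trips the TORCH_CHECK there.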

DimVector _estimate_sparse_compressed_tensor_size(
    const Tensor& compressed_indices,
    const Tensor& plain_indices,
    const Tensor& values,
    Layout layout) {
  const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size", [&] { return 0; }, [&] { return 2; });
  const int base_ndim = 2;  // corresponds to compressed and plain indices
  const int batch_ndim = compressed_indices.dim() - 1;
  const std::string compressed_indices_name = compressedIndicesName(layout);
  const std::string plain_indices_name = plainIndicesName(layout);
  TORCH_CHECK(
      batch_ndim >= 0,
      compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim());
  TORCH_CHECK(
      compressed_indices.dim() == plain_indices.dim(),
      compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ",
      compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");
  const int dense_ndim = values.dim() - batch_ndim - block_ndim - 1;
  TORCH_CHECK(
      dense_ndim >= 0,
      "values must have dimensionality > sum of batch and block dimensionalities (=",
      batch_ndim, " + ", block_ndim, ") but got ", values.dim());
  DimVector blocksize{
      (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 1)) : 1),
      (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 2)) : 1)
  };
  DimVector size = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
  int64_t compressed_dim_size = (compressed_indices.dim() > 0 && compressed_indices.size(-1) > 0 ? compressed_indices.size(-1) - 1 : 0);
  int64_t plain_dim_size = AT_DISPATCH_INTEGRAL_TYPES(plain_indices.scalar_type(), "estimate_sparse_compressed_tensor_size",
      [&]() -> int64_t {
        if (plain_indices.numel() > 0) {
          return plain_indices.max().item<scalar_t>() + 1;
        } else {
          return 0;
        }
      });
  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size",
      [&]{
        size.push_back(compressed_dim_size * blocksize[0]);
        size.push_back(plain_dim_size * blocksize[1]);
      },
      [&]{
        size.push_back(plain_dim_size * blocksize[0]);
        size.push_back(compressed_dim_size * blocksize[1]);
      });
  for (int i = 0; i < dense_ndim; i++) {
    int64_t j = batch_ndim + 1 + block_ndim + i;
    size.push_back((j < values.dim() ? values.size(j) : 1));
  }
  TORCH_CHECK(
      static_cast<int>(size.size()) == batch_ndim + base_ndim + dense_ndim,
      "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=",
      batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size());
  return size;
}
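
// Worked example for the size estimate above (hypothetical CSR data, for
// illustration only): with crow_indices = [0, 2, 3, 4] (shape (4,)),
// col_indices = [0, 2, 1, 3], and values of shape (4,), we get
// compressed_dim_size = 4 - 1 = 3, plain_dim_size = max(col_indices) + 1 = 4,
// blocksize = (1, 1), hence the estimated size is (3, 4). The number of
// columns is only a lower bound: trailing all-zero columns cannot be
// recovered from the indices alone.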

// TODO: This constructor should probably use an ATen abstract method in order
// to make autograd dispatch available for the CSR constructor. See the relevant
// note in native_functions.yaml.
Tensor sparse_compressed_tensor(
    const Tensor& compressed_indices,
    const Tensor& plain_indices,
    const Tensor& values,
    IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {

  if (!layout) {
    AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
  }
  Layout layout_ = layout.value();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});

  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);

  return at::native::_sparse_compressed_tensor_unsafe(
      compressed_indices,
      plain_indices,
      values,
      size,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt());
}

Tensor sparse_compressed_tensor(
    const Tensor& compressed_indices,
    const Tensor& plain_indices,
    const Tensor& values,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {

  if (!layout) {
    AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
  }
  Layout layout_ = layout.value();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});

  DimVector size = _estimate_sparse_compressed_tensor_size(compressed_indices, plain_indices, values, layout_);

  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);

  return at::native::_sparse_compressed_tensor_unsafe(
      compressed_indices,
      plain_indices,
      values,
      size,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt());
}

#define SPARSE_COMPRESSED_TENSOR(KIND, REQUIRED_LAYOUT)                  \
  Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices,        \
                                const Tensor& plain_indices,             \
                                const Tensor& values,                    \
                                c10::optional<ScalarType> dtype,         \
                                c10::optional<Layout> layout,            \
                                c10::optional<Device> device,            \
                                c10::optional<bool> pin_memory) {        \
    if (layout) {                                                        \
      TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " #KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
    }                                                                    \
    c10::optional<Layout> layout_(REQUIRED_LAYOUT);                      \
    return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, dtype, layout_, device, pin_memory); \
  }                                                                      \
  Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices,        \
                                const Tensor& plain_indices,             \
                                const Tensor& values,                    \
                                IntArrayRef size,                        \
                                c10::optional<ScalarType> dtype,         \
                                c10::optional<Layout> layout,            \
                                c10::optional<Device> device,            \
                                c10::optional<bool> pin_memory) {        \
    if (layout) {                                                        \
      TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " #KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
    }                                                                    \
    c10::optional<Layout> layout_(REQUIRED_LAYOUT);                      \
    return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, size, dtype, layout_, device, pin_memory); \
  }

SPARSE_COMPRESSED_TENSOR(csr, kSparseCsr)
SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc)
SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)

Tensor empty_sparse_compressed(
    IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<MemoryFormat> optional_memory_format) {
  check_size_nonnegative(size);
  TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size ", size);

  // Strided is the default layout for torch.empty.
  Layout layout_ = layout.value_or(Layout::Strided);

  // torch.empty cannot be used to create blocked tensors because its
  // API lacks a method to specify the block size.
  AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(layout_, "empty_sparse_compressed", [&]{});

  int64_t nnz = 0;
  auto compressed_indices_size = DimVector(size.slice(0, size.size() - 2));
  auto plain_indices_and_values_size = DimVector(size.slice(0, size.size() - 2));
  compressed_indices_size.push_back(size[compressedDimension(layout_, size)] + 1);
  plain_indices_and_values_size.push_back(nnz);

  TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
  auto compressed_indices = at::empty(compressed_indices_size, options);
  auto plain_indices = at::empty(plain_indices_and_values_size, options);
  auto values = at::empty(plain_indices_and_values_size, options.dtype(dtype));

  return at::native::_sparse_compressed_tensor_unsafe(compressed_indices,
                                                      plain_indices,
                                                      values,
                                                      size,
                                                      dtype,
                                                      layout,
                                                      device,
                                                      pin_memory);
}
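
// For example (illustrative): requesting an empty 3 x 4 tensor with the
// kSparseCsr layout produces a tensor with nnz == 0, i.e. crow_indices of
// shape (4,) (nrows + 1 uninitialized entries, as with any at::empty result)
// and col_indices and values of shape (0,). Block layouts are rejected above
// because the block size cannot be expressed through this API.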

const Tensor& resize_sparse_csr_(
    const Tensor& self,
    IntArrayRef size,
    c10::optional<MemoryFormat> optional_memory_format) {
  check_size_nonnegative(size);
  TORCH_CHECK(size.size() >= 2, "torch.resize_: Only batched sparse CSR matrices are supported, but got size ", size);
  TORCH_CHECK(
      self.size(-1) <= size[size.size() - 1],
      "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
      "The original number of columns is ",
      self.size(-1),
      " while the requested new number of columns is ", size[size.size() - 1], ".");
  get_sparse_csr_impl(self)->resize_(self._nnz(), size);
  return self;
}

Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocking) {
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{});
  TORCH_CHECK(
      self.layout() == src.layout(),
      "torch.copy_: copy of sparse compressed tensors having different layouts is not supported.",
      " self layout is ", self.layout(), " and src layout is ", src.layout());
  TORCH_CHECK(
      self._nnz() == src._nnz(),  // actually, values copy allows different shapes as long as operands are broadcastable
      "torch.copy_: only sparse compressed tensors with the same number of specified elements are supported.");
  auto self_compressed_dim = compressedDimension(self.layout(), self.sizes());
  auto src_compressed_dim = compressedDimension(src.layout(), src.sizes());
  auto self_compressed_dims = self.size(self_compressed_dim);
  auto src_compressed_dims = src.size(compressedDimension(src.layout(), src.sizes()));
  if (self_compressed_dim == src_compressed_dim) {
    TORCH_CHECK(self_compressed_dims == src_compressed_dims,
                "torch.copy_: expected shapes of self and src to match along dimension ",
                self_compressed_dim, " for ",
                self.layout(), " layout but the corresponding dimensions of self and src are ",
                self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
  } else {
    TORCH_CHECK(self_compressed_dims == src_compressed_dims,
                "torch.copy_: expected shapes of self and src to match along dimensions ",
                self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ",
                self.layout(), " layout but the corresponding dimensions of self and src are ",
                self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
  }
  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
      [&]{},
      [&]{
        auto self_values = self.values();
        auto src_values = src.values();
        auto self_blocksize = DimVector(self_values.sizes().slice(self_values.dim() - 2, 2));
        auto src_blocksize = DimVector(src_values.sizes().slice(src_values.dim() - 2, 2));
        TORCH_CHECK(self_blocksize == src_blocksize,
                    "torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.",
                    " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectively.");
      });
  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
      [&]{
        self.crow_indices().copy_(src.crow_indices(), non_blocking);
        self.col_indices().copy_(src.col_indices(), non_blocking);
      },
      [&]{
        self.ccol_indices().copy_(src.ccol_indices(), non_blocking);
        self.row_indices().copy_(src.row_indices(), non_blocking);
      });
  self.values().copy_(src.values(), non_blocking);
  return self;
}
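
// Behavior sketch (illustrative, not exhaustive): copying a CSR tensor with
// nnz == 2 into a CSR tensor with nnz == 3 fails the _nnz() check above, and
// copying a BSR tensor with block size (2, 3) into one with block size (3, 2)
// fails the block-size check, even if the overall shapes match.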

// Access members of CSR tensors.
int64_t _nnz_sparse_csr(const SparseCsrTensor& self) {
  return get_sparse_csr_impl(self)->nnz();
}

Tensor values_sparse_csr(const Tensor& self) {
  return get_sparse_csr_impl(self)->values().alias();
}

Tensor crow_indices_sparse_csr(const Tensor& self) {
  return AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(self.layout(),
      "crow_indices",
      [&]{ return get_sparse_csr_impl(self)->compressed_indices().alias(); });
}

Tensor col_indices_sparse_csr(const Tensor& self) {
  return AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(self.layout(),
      "col_indices",
      [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); });
}

Tensor ccol_indices_sparse_csr(const Tensor& self) {
  return AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(self.layout(),
      "ccol_indices",
      [&]{ return get_sparse_csr_impl(self)->compressed_indices().alias(); });
}

Tensor row_indices_sparse_csr(const Tensor& self) {
  return AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(self.layout(),
      "row_indices",
      [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); });
}

Tensor crow_indices_default(const Tensor& self) {
  TORCH_CHECK(false, "crow_indices expected sparse row compressed tensor layout but got ", self.layout());
}

Tensor col_indices_default(const Tensor& self) {
  TORCH_CHECK(false, "col_indices expected sparse row compressed tensor layout but got ", self.layout());
}

Tensor ccol_indices_default(const Tensor& self) {
  TORCH_CHECK(false, "ccol_indices expected sparse column compressed tensor layout but got ", self.layout());
}

Tensor row_indices_default(const Tensor& self) {
  TORCH_CHECK(false, "row_indices expected sparse column compressed tensor layout but got ", self.layout());
}

int64_t sparse_dim_sparse_csr(const SparseCsrTensor& self) {
  return get_sparse_csr_impl(self)->sparse_dim();
}

int64_t dense_dim_sparse_csr(const SparseCsrTensor& self) {
  return get_sparse_csr_impl(self)->dense_dim();
}

bool _is_same_size_as_sparse_compressed(
    const SparseCsrTensor& self,
    const SparseCsrTensor& src) {
  return self.sizes().equals(src.sizes());
}

const SparseCsrTensor& resize_as_sparse_compressed_(
    const SparseCsrTensor& self,
    const SparseCsrTensor& src) {
  auto src_layout = src.layout();
  auto self_layout = self.layout();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      src_layout, "resize_as_sparse_compressed_: src ", []() {});
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self_layout, "resize_as_sparse_compressed_: self ", []() {});
  // Note: The impl method does all required checking to see if resize/data copy
  // on member tensors is required.
  get_sparse_csr_impl(self)->resize_as_sparse_compressed_tensor_(src);
  return self;
}

SparseCsrTensor clone_sparse_compressed(
    const SparseCsrTensor& self,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "unsupported memory format option ",
      optional_memory_format.value());
  TensorOptions options = self.options();
  auto compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
      "clone_sparse_compressed",
      [&]{ return self.crow_indices(); },
      [&]{ return self.ccol_indices(); });
  auto plain_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
      "clone_sparse_compressed",
      [&]{ return self.col_indices(); },
      [&]{ return self.row_indices(); });
  return at::native::_sparse_compressed_tensor_unsafe(
      compressed_indices.clone(),
      plain_indices.clone(),
      self.values().clone(),
      self.sizes(),
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt());
}

Tensor empty_like_sparse_csr(
    const Tensor& self,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
  TensorOptions options =
      self.options()
          .merge_in(options_)
          .merge_memory_format(optional_memory_format);

  TORCH_CHECK(options.layout() == self.layout(),
              "empty_like with different sparse layout is not supported (self is ",
              self.layout(), " but you requested ", options.layout(), ")");
  if (options.layout() == kSparseCsr) {
    auto result = at::native::_sparse_csr_tensor_unsafe(
        self.crow_indices().clone(),
        self.col_indices().clone(),
        at::empty(self.values().sizes(), options.layout(kStrided)),
        self.sizes(),
        optTypeMetaToScalarType(options.dtype()),
        self.layout(),
        options.device());
    return result;
  } else if (options.layout() == kSparseCsc) {
    auto result = at::native::_sparse_csc_tensor_unsafe(
        self.ccol_indices().clone(),
        self.row_indices().clone(),
        at::empty(self.values().sizes(), options.layout(kStrided)),
        self.sizes(),
        optTypeMetaToScalarType(options.dtype()),
        self.layout(),
        options.device());
    return result;
  } else if (options.layout() == kSparseBsr) {
    auto result = at::native::_sparse_bsr_tensor_unsafe(
        self.crow_indices().clone(),
        self.col_indices().clone(),
        at::empty(self.values().sizes(), options.layout(kStrided)),
        self.sizes(),
        optTypeMetaToScalarType(options.dtype()),
        self.layout(),
        options.device());
    return result;
  } else if (options.layout() == kSparseBsc) {
    auto result = at::native::_sparse_bsc_tensor_unsafe(
        self.ccol_indices().clone(),
        self.row_indices().clone(),
        at::empty(self.values().sizes(), options.layout(kStrided)),
        self.sizes(),
        optTypeMetaToScalarType(options.dtype()),
        self.layout(),
        options.device());
    return result;
  } else if (options.layout() == kStrided) {
    return at::native::empty_like(self, dtype, layout, device, pin_memory, optional_memory_format);
  } else {
    TORCH_CHECK(false, "Layout ", options.layout(), " is not supported");
  }
}
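
// Note (summary of the branches above): empty_like on a sparse compressed
// tensor clones the index tensors, so the result has the same sparsity
// pattern as self, while its values tensor is freshly allocated with
// at::empty and is therefore uninitialized.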

template <bool require_view, bool require_copy>
Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index) {
  constexpr const char* select_name = (require_view ? "select()" : "select_copy()");
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(), "select", []() { return; });
  TORCH_CHECK_INDEX(
      self.dim() != 0, select_name, " cannot be applied to a 0-dim tensor.");
  dim = maybe_wrap_dim(dim, self.dim());
  auto size = self.size(dim);
  if (index < -size || index >= size) {
    TORCH_CHECK_INDEX(
        false,
        select_name, ": index ",
        index,
        " out of range for tensor of size ",
        self.sizes(),
        " at dimension ",
        dim);
  }
  if (index < 0) {
    index += size;
  }

  auto select_strided = [](const Tensor& self, int64_t dim, int64_t index) {
    if (require_copy) {
      return at::select_copy(self, dim, index);
    } else {
      return self.select(dim, index);
    }
  };

  TORCH_INTERNAL_ASSERT(dim >= 0 && dim < self.dim());

  auto new_sizes = DimVector(self.sizes());
  new_sizes.erase(new_sizes.begin() + dim);
  auto options = self.options();

  Tensor plain_indices;
  Tensor compressed_indices;
  std::tie(compressed_indices, plain_indices) =
      AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
          self.layout(),
          "select",
          [&]() {
            return std::make_pair(self.crow_indices(), self.col_indices());
          },
          [&]() {
            return std::make_pair(self.ccol_indices(), self.row_indices());
          });
  auto n_batch = compressed_indices.dim() - 1;

  if (dim < n_batch) {
    // Selecting a batch dimension
    return at::native::_sparse_compressed_tensor_unsafe(
        compressed_indices.select(dim, index),
        plain_indices.select(dim, index),
        select_strided(self.values(), dim, index),
        new_sizes,
        optTypeMetaToScalarType(options.dtype_opt()),
        options.layout_opt(),
        options.device_opt(),
        options.pinned_memory_opt());
  } else if (dim < n_batch + 2) {
    // Selecting a sparse dimension
    TORCH_CHECK(
        n_batch == 0,
        select_name, ": selecting sparse dimensions is not implemented for batched sparse compressed tensors.");
    TORCH_INTERNAL_ASSERT(dim == 0 || dim == 1);

    DimVector blocksize{1, 1};
    AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&] {}, [&] {
      blocksize[0] = std::max<int64_t>(1, self.values().size(n_batch + 1));
      blocksize[1] = std::max<int64_t>(1, self.values().size(n_batch + 2));
    });

    auto indices_options = compressed_indices.options();
    int64_t fast_dim = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() { return 0; }, [&]() { return 1; });
    int64_t other_dim = (dim == 0 ? 1 : 0);
    Tensor indices;
    Tensor values;
    bool is_view = dim == fast_dim;
    if (is_view) {
      // select along the fast (compressed) dimension is always a view operation
      Tensor start_end = compressed_indices.narrow(0, index / blocksize[dim], 2).cpu();
      int64_t start = start_end[0].item<int64_t>();
      int64_t end = start_end[1].item<int64_t>();
      indices = plain_indices.slice(0, start, end);
      values = self.values().slice(0, start, end);
    } else {
      Tensor decompressed_indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
          .select(0, 0);

      Tensor dim_indices = at::where(plain_indices.eq(index / blocksize[dim]))[0];
      // Notice that dim_indices is a sorted sequence of non-negative
      // distinct integers. Below we'll try to solve `dim_indices ==
      // arange(start, stop, step)`. If the solution exists then the
      // select will be a view operation also for the `dim !=
      // fast_dim` case.
      int64_t start{}, end{}, step{};
      if (solve_arange(dim_indices, start, end, step)) {
        indices = decompressed_indices.slice(0, start, end, step);
        values = self.values().slice(0, start, end, step);
        is_view = true;
      } else {
        // select will be a copy operation due to index_select!
        indices = decompressed_indices.index_select(0, dim_indices);
        values = self.values().index_select(0, dim_indices);
      }
    }

    AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() {},
        [&]() {
          /*
            The formulas for the select indices and values below are best
            explained by an example. Consider a BSR tensor with a
            block size (2, 3) having four blocks (the other two blocks
            contain all zeros and hence will not be specified):

              [ 1  2  3] | [ 7  8  9]
              [ 4  5  6] | [10 11 12]
              -----------------------
              [13 14 15] | [ 0  0  0]
              [16 17 18] | [ 0  0  0]
              -----------------------
              [ 0  0  0] | [19 20 21]
              [ 0  0  0] | [22 23 24]

            that represents a 6 x 6 tensor:

              [  1  2  3  7  8  9 ]
              [  4  5  6 10 11 12 ]
              [ 13 14 15  0  0  0 ]
              [ 16 17 18  0  0  0 ]
              [  0  0  0 19 20 21 ]
              [  0  0  0 22 23 24 ]

            The corresponding data for the BSR representation is:

              crow_indices = [0 2 3 4]
              col_indices = [0 1 0 1]
              values = [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]], [[13 14 15], [16 17 18]], [[19 20 21], [22 23 24]] ]
              shape = (6, 6)

            From crow_indices, we can find that

              row_indices = [0 0 1 2]

            In the following, we'll illustrate the details of
            computing the result of torch.select_copy(input, dim,
            index) where dim is 0 or 1, and index is in
            range(shape[dim]).

            Select a row of a BSR tensor
            ----------------------------

            We will first consider the dim=0 case that corresponds to
            selecting the index-th row of the tensor. For instance, for
            dim=0 and index=1, the expected result would represent a
            1D tensor:

              [ 4 5 6 10 11 12 ]

            that is a concatenation of certain slices from the
            first and the second block, computed as follows:

              values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
              -> values[[0, 1]][:, 1 % 2].flatten(0, 1)
              -> [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]] ][:, 1].flatten(0, 1)
              -> [ [4 5 6], [10 11 12]].flatten(0, 1)
              -> [ 4 5 6 10 11 12]

            where dim_indices is found as

              where(row_indices == index//blocksize[dim])
              -> where([0 0 1 2] == 1//2)
              -> [0 1]

            The corresponding column indices are computed as

              (col_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)

            where other_dim is 1 if dim is 0, and 0 if dim is 1. Let's
            expand the above expression with the data in the example:

              -> (col_indices[[0, 1]].mul(3).unsqueeze(1) + arange(3).unsqueeze(0)).flatten(0, 1)
              -> ([0 1].mul(3).unsqueeze(1) + [[0 1 2]]).flatten(0, 1)
              -> ([[0], [3]] + [[0 1 2]]).flatten(0, 1) <- here addition will use broadcasting rules!
              -> ([[0 1 2], [3 4 5]]).flatten(0, 1)
              -> [0 1 2 3 4 5]

            Finally, the select(dim=0, index=1) op on the given sparse
            compressed tensor will return a COO tensor:

              sparse_coo_tensor([0 1 2 3 4 5].unsqueeze(0), [4 5 6 10 11 12], (6,))

            that represents the expected result: [ 4 5 6 10 11 12 ]

            Select a column of a BSR tensor
            -------------------------------

            Next, we'll consider the dim=1 case that corresponds to
            selecting the index-th column of the tensor. For instance,
            for dim=1 and index=4, the expected result would represent
            a 1D tensor:

              [ 8 11 0 0 20 23]

            that is a concatenation of certain slices from the
            second and the last block:

              values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
              -> values[[1, 3]][:, :, 4 % 3].flatten(0, 1)
              -> [ [[7 8 9], [10 11 12]], [[19 20 21], [22 23 24]] ][:, :, 1].flatten(0, 1)
              -> [ [8 11], [20 23]].flatten(0, 1)
              -> [ 8 11 20 23 ]

            The corresponding row indices are computed as

              (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)

            where dim_indices is

              where(col_indices == index//blocksize[dim])
              -> where([0 1 0 1] == 4//3)
              -> [1 3]

            and we have

              (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
              -> (row_indices[[1 3]].mul(2).unsqueeze(1) + arange(2).unsqueeze(0)).flatten(0, 1)
              -> ([0 4].unsqueeze(1) + [0 1].unsqueeze(0)).flatten(0, 1)
              -> ([[0], [4]] + [[0 1]]).flatten(0, 1) <- here addition will use broadcasting rules!
              -> ([[0 1], [4 5]]).flatten(0, 1)
              -> [ 0 1 4 5 ]

            Finally, the select(dim=1, index=4) op on the given sparse
            compressed tensor will return a COO tensor:

              sparse_coo_tensor([0 1 4 5].unsqueeze(0), [8 11 20 23], (6,))

            that represents the expected result: [ 8 11 0 0 20 23 ]

          */
          Tensor subblock_indices = at::arange(0, blocksize[other_dim], indices_options);
          indices = indices.mul(blocksize[other_dim]).unsqueeze(1).add(subblock_indices.unsqueeze(0)).flatten(0, 1);
          values = values.select(dim + 1, index % blocksize[dim]).flatten(0, 1);
          // flatten(0, 1) can be a view or a copy operation. If a view
          // is required, it is checked below via is_alias_of;
          // otherwise, we check here whether a copy was made to avoid
          // an unnecessary clone below:
          if (require_copy) {
            is_view = values.is_alias_of(self.values());
          }
        });

    if (require_view) {
      TORCH_CHECK(values.is_alias_of(self.values()), select_name,
                  ": no view exists for the given input, consider using torch.select_copy.");
    }

    indices = indices.unsqueeze(0).to(kLong);
    if (require_copy && is_view) {
      values = values.clone();
    }
    return at::_sparse_coo_tensor_unsafe(indices, values, new_sizes)._coalesced_(true);
  } else {
    // Selecting a dense dimension
    Tensor new_values = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
        self.layout(),
        "select",
        // Non-blocked layout (2 sparse dims become 1 nnz dim in values, so dim
        // is found one position to the left)
        [&]() { return select_strided(self.values(), dim - 1, index); },
        // Blocked layout (2 sparse dims become 1 nnz dim + 2 block-shape dims in
        // values, so dim is found 1 position to the right)
        [&]() { return select_strided(self.values(), dim + 1, index); });
    return at::native::_sparse_compressed_tensor_unsafe(
        compressed_indices,
        plain_indices,
        new_values,
        new_sizes,
        optTypeMetaToScalarType(options.dtype_opt()),
        options.layout_opt(),
        options.device_opt(),
        options.pinned_memory_opt());
  }
}

Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
  return select_sparse_csr_worker<true, false>(self, dim, index);
}

Tensor select_copy_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
  return select_sparse_csr_worker<false, true>(self, dim, index);
}
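
// Behavior sketch (hypothetical example, not a test): for a CSR matrix `a`,
// a.select(0, i) returns a sparse COO vector holding row i and is a view into
// a.values() whenever possible; a.select(1, j) may require a copy, in which
// case select() raises an error suggesting torch.select_copy, while
// select_copy() always returns values that do not alias the input.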

} // namespace native
} // namespace at