// Basic functions on sparse tensors
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/Layout.h>
#include <ATen/Parallel.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/LinearAlgebraUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_nnz_native.h>
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_csc_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_bsr_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_bsc_tensor_args_native.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/ccol_indices_native.h>
#include <ATen/ops/clone_native.h>
#include <ATen/ops/col_indices_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/crow_indices_native.h>
#include <ATen/ops/dense_dim_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/row_indices_native.h>
#include <ATen/ops/select_native.h>
#include <ATen/ops/select_copy.h>
#include <ATen/ops/select_copy_native.h>
#include <ATen/ops/sparse_compressed_tensor_native.h>
#include <ATen/ops/sparse_csr_tensor_native.h>
#include <ATen/ops/sparse_csc_tensor_native.h>
#include <ATen/ops/sparse_bsr_tensor_native.h>
#include <ATen/ops/sparse_bsc_tensor_native.h>
#include <ATen/ops/sparse_dim_native.h>
#include <ATen/ops/values_native.h>
#include <ATen/ops/_validate_compressed_sparse_indices.h>
#include <ATen/ops/where.h>
#endif

namespace at {
namespace native {

using namespace at::sparse_csr;

namespace {

bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& step) {
  /*
    This function solves the equation

      input == arange(start, end, step)

    for integers start, end, and step, if possible. If the solution
    exists, returns true.
  */
  int64_t n = input.numel();
  if (n == 0) {
    // a trivial solution
    start = end = 0;
    step = 1;
  } else if (n == 1) {
    // a simple solution
    start = input[0].item<int64_t>();
    end = start + 1;
    step = 1;
  } else {
    Tensor first_last = input.slice(0, 0, n, n - 1).cpu();
    int64_t start_candidate = first_last[0].item<int64_t>();
    int64_t end_candidate = first_last[1].item<int64_t>() + 1;
    if (end_candidate - start_candidate == n) {
      // a special solution
      start = start_candidate;
      end = end_candidate;
      step = 1;
    } else {
      // detect if a general solution exists
      Tensor possible_steps = input.slice(0, 1).sub(input.slice(0, 0, n - 1));
      Tensor possible_step = possible_steps[0];
      if ((possible_steps.eq(possible_step)).all().item<bool>()) {
        start = start_candidate;
        end = end_candidate;
        step = possible_step.item<int64_t>();
      } else {
        // no solution
        return false;
      }
    }
  }
  return true;
}
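
// For illustration only: how solve_arange behaves on a few hypothetical
// 1-D integer inputs (derived by tracing the branches above):
//
//   input = [3, 5, 7]    -> start=3, end=8, step=2, returns true
//                           (general branch: differences [2, 2] are equal)
//   input = [0, 1, 2, 3] -> start=0, end=4, step=1, returns true
//                           (special branch: end - start == numel)
//   input = [0, 1, 3]    -> returns false
//                           (differences [1, 2] are not all equal)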

} // end anonymous namespace

/*
  Validate the arguments to sparse compressed (CSR, CSC, BSR, and BSC)
  tensor factory functions.

  The CSR and BSR invariants for PyTorch are outlined in

    https://pearu.github.io/csr_tensor_invariants.html
    https://pearu.github.io/bsr_tensor_invariants.html

  and are generalized in what follows to all sparse compressed formats,
  with support for batched and dense dimensions.
*/
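
/*
  A hedged, illustrative example of the dimension bookkeeping used below
  (the tensor itself is hypothetical): a batched BSR tensor with batch
  shape (4,), sparse shape (6, 6), block size (2, 3), and one dense
  dimension of size 5 would have

    compressed_indices of shape (4, 4)            // 6/2 + 1 = 4 entries
    plain_indices      of shape (4, nnz)
    values             of shape (4, nnz, 2, 3, 5)

  so batch_ndim = 1, block_ndim = 2, dense_ndim = values.dim() -
  batch_ndim - block_ndim - 1 = 1, and the tensor dimensionality is
  batch_ndim + base_ndim + dense_ndim = 1 + 2 + 1 = 4, matching
  size = (4, 6, 6, 5).
*/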

void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) {
  // Layout must be Sparse Compressed, 2.4
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{});

  const std::string layout_name = layoutToString(layout, /*upper=*/ true);
  const std::string compressed_indices_name = compressedIndicesName(layout);
  const std::string plain_indices_name = plainIndicesName(layout);
  const std::string compressed_dim_name = compressedDimName(layout);
  const std::string plain_dim_name = plainDimName(layout);

  // Layout Invariants

  // Re 3.5 and 3.6: for the compressed/plain indices tensors we
  // require contiguity on a per-batch basis, that is, the last stride
  // of these indices must be 1. The reasoning for this is that
  // indices tensors within a batch are "atomic" in the sense that
  // sliced compressed/plain indices would not represent the indices
  // of any sparse compressed tensor, as the slicing would break the
  // description of the tensor index structure.

  // 2.1
  TORCH_CHECK(plain_indices.layout() == kStrided,
              "expected ", plain_indices_name, " to be a strided tensor but got ", plain_indices.layout(), " tensor");

  // 2.2
  TORCH_CHECK(compressed_indices.layout() == kStrided,
              "expected ", compressed_indices_name, " to be a strided tensor but got ", compressed_indices.layout(), " tensor");

  const int base_ndim = 2;  // corresponds to compressed and plain indices
  const int batch_ndim = compressed_indices.dim() - 1;
  const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
      layout, "validate_sparse_compressed_tensor_args",
      [&] { return 0; }, [&] { return 2; });
  const int dense_ndim = values.dim() - batch_ndim - block_ndim - 1;

  // 2.3
  TORCH_CHECK(values.layout() == kStrided,
              "expected values to be a strided tensor but got ", values.layout(), " tensor");

  // 3.7 is dropped, that is, the values tensor does not need to be
  // contiguous, in general. Particular algorithms on sparse
  // compressed tensors may require contiguity though.

  // Shape and Strides invariants

  // 3.2
  TORCH_CHECK(
      batch_ndim >= 0,
      compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim());

  // 3.3
  TORCH_CHECK(
      compressed_indices.dim() == plain_indices.dim(),
      compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ",
      compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");

  // 3.4
  TORCH_CHECK(
      dense_ndim >= 0,
      "values must have dimensionality > sum of batch and block dimensionalities (=",
      batch_ndim, " + ", block_ndim, ") but got ", values.dim());

  // 3.5
  TORCH_CHECK(plain_indices.stride(-1) == 1,
              "expected ", plain_indices_name, " to be a contiguous tensor per batch");

  // 3.6
  TORCH_CHECK(compressed_indices.stride(-1) == 1,
              "expected ", compressed_indices_name, " to be a contiguous tensor per batch");

  // 3.1
  TORCH_CHECK(
      static_cast<int>(size.size()) == batch_ndim + base_ndim + dense_ndim,
      "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=",
      batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size());

  // For CSR/CSC formats, we define blocksize=(1, 1) so that checking
  // the sparse compressed tensor invariants can be unified with the
  // BSR/BSC invariants.
  // 3.10
  DimVector blocksize{
      (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 1)) : 1),
      (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 2)) : 1),
  };
  TORCH_INTERNAL_ASSERT(blocksize.size() == 2 && blocksize[0] > 0 && blocksize[1] > 0);

  // All batch sizes must be the same and consistent with the tensor batchsize, 3.1, 3.8, 3.9, 3.10
  DimVector batchsize = DimVector(size.slice(0, batch_ndim));
  DimVector compressed_indices_batchsize = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
  DimVector plain_indices_batchsize = DimVector(plain_indices.sizes().slice(0, batch_ndim));
  DimVector values_batchsize = DimVector(values.sizes().slice(0, batch_ndim));
  const int values_nnz = (values.numel() ? values.size(batch_ndim) : 0);
  DimVector values_blocksize = DimVector(values.sizes().slice(batch_ndim + 1, block_ndim));
  DimVector values_densesize = DimVector(values.sizes().slice(batch_ndim + 1 + block_ndim, dense_ndim));
  TORCH_CHECK(
      batchsize == compressed_indices_batchsize && batchsize == plain_indices_batchsize && batchsize == values_batchsize,
      "all batch dimensions of ", compressed_indices_name, " (=", compressed_indices_batchsize, "), ", plain_indices_name, " (=",
      plain_indices_batchsize, "), and values (=", values_batchsize, ") must be equal to tensor batch dimensions (=",
      batchsize, ")");

  // A tensor consists of full blocks, 3.1
  for (int i = 0; i < block_ndim; i++) {
    TORCH_CHECK(size[batch_ndim + i] % blocksize[i] == 0,
                "tensor shape[", batch_ndim + i, "] (=", size[batch_ndim + i],
                ") must be divisible by blocksize[", i, "] (=", blocksize[i],
                ") as defined by the values shape");
  }
  const int nrows = size[batch_ndim] / blocksize[0];
  const int ncols = size[batch_ndim + 1] / blocksize[1];
  int compressed_dim_size, plain_dim_size;
  std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
      [&] { return std::make_tuple(nrows, ncols); },
      [&] { return std::make_tuple(ncols, nrows); });
  // 3.8
  TORCH_CHECK(
      compressed_indices.size(-1) == compressed_dim_size + 1,
      compressed_indices_name, ".shape[-1] must be equal to the number of ",
      compressed_dim_name, "s + 1 (=", compressed_dim_size + 1, "), but got ", compressed_indices.size(-1));
  // 3.9, 3.10
  TORCH_CHECK(
      plain_indices.size(-1) == values_nnz,
      plain_indices_name, ".shape[-1] must be equal to nnz (=", values_nnz,
      ") as defined by values.shape[", batch_ndim, "], but got ", plain_indices.size(-1));
  // Type Invariants
  auto compressed_indices_type = compressed_indices.scalar_type();
  auto plain_indices_type = plain_indices.scalar_type();
  // 1.1, 1.2, 1.3
  TORCH_CHECK(
      compressed_indices_type == plain_indices_type,
      compressed_indices_name, " and ", plain_indices_name, " must have the same dtype, but got ",
      compressed_indices_type, " and ", plain_indices_type, ", respectively");
  TORCH_CHECK(
      compressed_indices_type == kInt || compressed_indices_type == kLong,
      compressed_indices_name, " and ", plain_indices_name, " dtype must be Int or Long, but got ",
      compressed_indices_type);

  // Indices invariants
  if (plain_indices.numel() > 0) {
    at::_validate_compressed_sparse_indices(
        /*is_crow=*/ layout == kSparseCsr || layout == kSparseBsr,
        compressed_indices,
        plain_indices,
        compressed_dim_size,
        plain_dim_size,
        values_nnz);
  }

  // Device Invariants
  // 4.1
  TORCH_CHECK(
      values.device().type() == kCPU || values.device().type() == kCUDA,
      "device type of values (",
      values.device().type(),
      ") must be CPU or CUDA");
  // 4.2, 4.3, 4.4
  TORCH_CHECK(
      compressed_indices.get_device() == values.get_device(),
      "device of ", compressed_indices_name, " (=",
      compressed_indices.device(),
      ") must match device of values (=",
      values.device(),
      ")");
  TORCH_CHECK(
      compressed_indices.get_device() == plain_indices.get_device(),
      "device of ", compressed_indices_name, " (=",
      compressed_indices.device(),
      ") must match device of ", plain_indices_name, " (=",
      plain_indices.device(),
      ")");
}

void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) {
  _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout);
}

void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
  _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
}

void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
  _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc);
}

void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
  _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr);
}

void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
  _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc);
}
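
// A hedged usage sketch with hypothetical CSR components for a 3 x 4
// matrix with 4 specified elements; the validation functions throw on
// an invariant violation instead of returning a status:
//
//   // crow_indices = [0, 1, 3, 4], col_indices = [0, 1, 3, 2],
//   // values = [v0, v1, v2, v3]
//   at::native::_validate_sparse_csr_tensor_args(
//       crow_indices, col_indices, values, {3, 4});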

// Construction of CSR, CSC, BSR, and BSC tensors.

// Note: The "Csr" in names like SparseCsrTensor, SparseCsrCPU,
// SparseCsrCUDA, and SparseCsrTensorImpl exists for historical reasons
// (and ought to be removed in the future); it does not mean that the
// corresponding functionality is specific to the CSR layout.
SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
  // TODO: remove this comment after enabling autograd support for CSR tensor
  // constructor.
  // TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
  Layout layout = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(options.layout(), "new_compressed_tensor", [&] { return the_layout; });
  DispatchKey dispatch_key;

  TORCH_CHECK_NOT_IMPLEMENTED(
      options.device().type() == kCPU || options.device().type() == kCUDA,
      "Could not run 'new_compressed_tensor' from the '", options.device(), "' device.");

  if (options.device().is_cuda()) {
    dispatch_key = DispatchKey::SparseCsrCUDA;
  } else {
    dispatch_key = DispatchKey::SparseCsrCPU;
  }

  return detail::make_tensor<SparseCsrTensorImpl>(DispatchKeySet(dispatch_key), options.device(), layout, options.dtype());
}

Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices,
                                        const Tensor& plain_indices,
                                        const Tensor& values,
                                        IntArrayRef size,
                                        c10::optional<ScalarType> dtype,
                                        c10::optional<Layout> layout,
                                        c10::optional<Device> device,
                                        c10::optional<bool> pin_memory) {
  if (!layout) {
    AT_ERROR("sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none");
  }
  Layout layout_ = layout.value();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
  if (at::globalContext().checkSparseTensorInvariants()) {
    _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
  }
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
  SparseCsrTensor self = new_compressed_tensor(options);
  get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
  return self;
}
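
// A hedged usage note (not authoritative): the "unsafe" constructor above
// skips the invariant validation unless the global check is enabled, e.g.
// (assuming the Context setter that pairs with checkSparseTensorInvariants()
// used above; in Python this corresponds to
// torch.sparse.check_sparse_tensor_invariants):
//
//   at::globalContext().setCheckSparseTensorInvariants(true);
//   // With the flag set, this call also runs
//   // _validate_sparse_compressed_tensor_args_worker:
//   Tensor t = at::native::_sparse_compressed_tensor_unsafe(
//       crow_indices, col_indices, values, /*size=*/{3, 4},
//       at::kDouble, at::kSparseCsr, at::kCPU, /*pin_memory=*/false);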

template <Layout required_layout>
Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices,
                                                 const Tensor& plain_indices,
                                                 const Tensor& values,
                                                 IntArrayRef size,
                                                 c10::optional<ScalarType> dtype,
                                                 c10::optional<Layout> layout,
                                                 c10::optional<Device> device,
                                                 c10::optional<bool> pin_memory) {
  Layout layout_ = layout.value_or(required_layout);
  TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ", required_layout, " but got ", layout_);
  if (at::globalContext().checkSparseTensorInvariants()) {
    _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
  }
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
  SparseCsrTensor self = new_compressed_tensor(options);
  get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
  return self;
}

#define SPARSE_COMPRESSED_TENSOR_UNSAFE(KIND, REQUIRED_LAYOUT)             \
  Tensor _sparse_##KIND##_tensor_unsafe(const Tensor& compressed_indices, \
                                        const Tensor& plain_indices,      \
                                        const Tensor& values,             \
                                        IntArrayRef size,                 \
                                        c10::optional<ScalarType> dtype,  \
                                        c10::optional<Layout> layout,     \
                                        c10::optional<Device> device,     \
                                        c10::optional<bool> pin_memory) { \
    return _sparse_compressed_tensor_unsafe_template<REQUIRED_LAYOUT>(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \
  }

SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
SPARSE_COMPRESSED_TENSOR_UNSAFE(csc, kSparseCsc);
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsr, kSparseBsr);
SPARSE_COMPRESSED_TENSOR_UNSAFE(bsc, kSparseBsc);
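
// For instance, the "csr" instantiation above expands to
//
//   Tensor _sparse_csr_tensor_unsafe(const Tensor& compressed_indices,
//                                    const Tensor& plain_indices,
//                                    const Tensor& values,
//                                    IntArrayRef size, ...) {
//     return _sparse_compressed_tensor_unsafe_template<kSparseCsr>(...);
//   }
//
// so the per-layout entry points differ only in the required layout
// baked into the template argument.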

DimVector _estimate_sparse_compressed_tensor_size(
    const Tensor& compressed_indices,
    const Tensor& plain_indices,
    const Tensor& values,
    Layout layout) {
  const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size", [&] { return 0; }, [&] { return 2; });
  const int base_ndim = 2;  // corresponds to compressed and plain indices
  const int batch_ndim = compressed_indices.dim() - 1;
  const std::string compressed_indices_name = compressedIndicesName(layout);
  const std::string plain_indices_name = plainIndicesName(layout);
  TORCH_CHECK(
      batch_ndim >= 0,
      compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim());
  TORCH_CHECK(
      compressed_indices.dim() == plain_indices.dim(),
      compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ",
      compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");
  const int dense_ndim = values.dim() - batch_ndim - block_ndim - 1;
  TORCH_CHECK(
      dense_ndim >= 0,
      "values must have dimensionality > sum of batch and block dimensionalities (=",
      batch_ndim, " + ", block_ndim, ") but got ", values.dim());
  DimVector blocksize{
      (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 1)) : 1),
      (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 2)) : 1)
  };
  DimVector size = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
  int64_t compressed_dim_size = (compressed_indices.dim() > 0 && compressed_indices.size(-1) > 0 ? compressed_indices.size(-1) - 1 : 0);
  int64_t plain_dim_size = AT_DISPATCH_INTEGRAL_TYPES(plain_indices.scalar_type(), "estimate_sparse_compressed_tensor_size",
      [&]() -> int64_t {
        if (plain_indices.numel() > 0) {
          return plain_indices.max().item<scalar_t>() + 1;
        } else {
          return 0;
        }
      });
  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size",
      [&]{
        size.push_back(compressed_dim_size * blocksize[0]);
        size.push_back(plain_dim_size * blocksize[1]);
      },
      [&]{
        size.push_back(plain_dim_size * blocksize[0]);
        size.push_back(compressed_dim_size * blocksize[1]);
      });
  for (int i = 0; i < dense_ndim; i++) {
    int64_t j = batch_ndim + 1 + block_ndim + i;
    size.push_back((j < values.dim() ? values.size(j) : 1));
  }
  TORCH_CHECK(
      static_cast<int>(size.size()) == batch_ndim + base_ndim + dense_ndim,
      "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=",
      batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size());
  return size;
}
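
// A hedged worked example of the estimation above for a hypothetical
// CSR input:
//
//   crow_indices = [0, 2, 3]    -> compressed_dim_size = 3 - 1 = 2
//   col_indices  = [0, 2, 1]    -> plain_dim_size = max + 1 = 3
//   values       = [v0, v1, v2] -> no batch, block, or dense dimensions
//
// so blocksize = (1, 1) and the estimated size is (2 * 1, 3 * 1) = (2, 3).
// Note that plain_dim_size is only a lower bound: trailing all-zero
// columns of the original matrix cannot be recovered from the indices
// alone.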

// TODO: This constructor should probably use an ATen abstract method in order
// to make autograd dispatch available for the CSR constructor. See the relevant
// note in native_functions.yaml.
Tensor sparse_compressed_tensor(
    const Tensor& compressed_indices,
    const Tensor& plain_indices,
    const Tensor& values,
    IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {

  if (!layout) {
    AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
  }
  Layout layout_ = layout.value();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});

  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);

  return at::native::_sparse_compressed_tensor_unsafe(
      compressed_indices,
      plain_indices,
      values,
      size,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt());
}

Tensor sparse_compressed_tensor(
    const Tensor& compressed_indices,
    const Tensor& plain_indices,
    const Tensor& values,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {

  if (!layout) {
    AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
  }
  Layout layout_ = layout.value();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});

  DimVector size = _estimate_sparse_compressed_tensor_size(compressed_indices, plain_indices, values, layout_);

  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);

  return at::native::_sparse_compressed_tensor_unsafe(
      compressed_indices,
      plain_indices,
      values,
      size,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt());
}

#define SPARSE_COMPRESSED_TENSOR(KIND, REQUIRED_LAYOUT)              \
  Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices,   \
                                const Tensor& plain_indices,        \
                                const Tensor& values,               \
                                c10::optional<ScalarType> dtype,    \
                                c10::optional<Layout> layout,       \
                                c10::optional<Device> device,       \
                                c10::optional<bool> pin_memory) {   \
    if (layout) {                                                   \
      TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " #KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
    }                                                               \
    c10::optional<Layout> layout_(REQUIRED_LAYOUT);                 \
    return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, dtype, layout_, device, pin_memory); \
  }                                                                 \
  Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices,   \
                                const Tensor& plain_indices,        \
                                const Tensor& values,               \
                                IntArrayRef size,                   \
                                c10::optional<ScalarType> dtype,    \
                                c10::optional<Layout> layout,       \
                                c10::optional<Device> device,       \
                                c10::optional<bool> pin_memory) {   \
    if (layout) {                                                   \
      TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " #KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
    }                                                               \
    c10::optional<Layout> layout_(REQUIRED_LAYOUT);                 \
    return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, size, dtype, layout_, device, pin_memory); \
  }

SPARSE_COMPRESSED_TENSOR(csr, kSparseCsr)
SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc)
SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)
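
// A hedged usage sketch of the factory functions defined above, with
// hypothetical index/value tensors (the exact C++ overload set is
// generated from native_functions.yaml; in Python these correspond to
// torch.sparse_csr_tensor and friends):
//
//   Tensor crow = at::tensor({0, 2, 4}, at::kLong);
//   Tensor col  = at::tensor({0, 1, 0, 1}, at::kLong);
//   Tensor vals = at::tensor({1., 2., 3., 4.});
//   // with an explicit size:
//   Tensor a = at::sparse_csr_tensor(crow, col, vals, {2, 2},
//                                    vals.options().layout(at::kSparseCsr));
//   // or with the size estimated from the indices, here (2, 2):
//   Tensor b = at::sparse_csr_tensor(crow, col, vals,
//                                    vals.options().layout(at::kSparseCsr));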

Tensor empty_sparse_compressed(
    IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<MemoryFormat> optional_memory_format) {
  check_size_nonnegative(size);
  TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size ", size);

  // Strided is the default layout for torch.empty.
  Layout layout_ = layout.value_or(Layout::Strided);

  // torch.empty cannot be used to create blocked tensors because its
  // API lacks a method to specify the block size.
  AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(layout_, "empty_sparse_compressed", [&]{});

  int64_t nnz = 0;
  auto compressed_indices_size = DimVector(size.slice(0, size.size() - 2));
  auto plain_indices_and_values_size = DimVector(size.slice(0, size.size() - 2));
  compressed_indices_size.push_back(size[compressedDimension(layout_, size)] + 1);
  plain_indices_and_values_size.push_back(nnz);

  TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
  auto compressed_indices = at::empty(compressed_indices_size, options);
  auto plain_indices = at::empty(plain_indices_and_values_size, options);
  auto values = at::empty(plain_indices_and_values_size, options.dtype(dtype));

  return at::native::_sparse_compressed_tensor_unsafe(compressed_indices,
                                                      plain_indices,
                                                      values,
                                                      size,
                                                      dtype,
                                                      layout,
                                                      device,
                                                      pin_memory);
}
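
// A hedged worked example: for size = (2, 3) and layout = kSparseCsr,
// the code above allocates
//
//   compressed_indices of shape (3,)  // nrows + 1 = size[0] + 1
//   plain_indices      of shape (0,)  // nnz == 0
//   values             of shape (0,)
//
// i.e. an empty sparse CSR matrix with no specified elements. For
// kSparseCsc the compressed dimension is the column dimension, so
// compressed_indices would have shape (4,) instead.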

const Tensor& resize_sparse_csr_(
    const Tensor& self,
    IntArrayRef size,
    c10::optional<MemoryFormat> optional_memory_format) {
  check_size_nonnegative(size);
  TORCH_CHECK(size.size() >= 2, "torch.resize_: Only batched sparse CSR matrices are supported, but got size ", size);
  TORCH_CHECK(
      self.size(-1) <= size[size.size() - 1],
      "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
      "The original number of columns is ",
      self.size(-1),
      " while the requested new number of columns is ", size[size.size() - 1], ".");
  get_sparse_csr_impl(self)->resize_(self._nnz(), size);
  return self;
}

Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocking) {
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{});
  TORCH_CHECK(
      self.layout() == src.layout(),
      "torch.copy_: copy of sparse compressed tensors having different layouts is not supported.",
      " self layout is ", self.layout(), " and src layout is ", src.layout());
  TORCH_CHECK(
      self._nnz() == src._nnz(),  // actually, the values copy allows different shapes as long as the operands are broadcastable
      "torch.copy_: only sparse compressed tensors with the same number of specified elements are supported.");
  auto self_compressed_dim = compressedDimension(self.layout(), self.sizes());
  auto src_compressed_dim = compressedDimension(src.layout(), src.sizes());
  auto self_compressed_dims = self.size(self_compressed_dim);
  auto src_compressed_dims = src.size(src_compressed_dim);
  if (self_compressed_dim == src_compressed_dim) {
    TORCH_CHECK(self_compressed_dims == src_compressed_dims,
                "torch.copy_: expected shapes of self and src to match along dimension ",
                self_compressed_dim, " for ",
                self.layout(), " layout but the corresponding dimensions of self and src are ",
                self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
  } else {
    TORCH_CHECK(self_compressed_dims == src_compressed_dims,
                "torch.copy_: expected shapes of self and src to match along dimensions ",
                self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ",
                self.layout(), " layout but the corresponding dimensions of self and src are ",
                self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
  }
  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
      [&]{},
      [&]{
        auto self_values = self.values();
        auto src_values = src.values();
        auto self_blocksize = DimVector(self_values.sizes().slice(self_values.dim() - 2, 2));
        auto src_blocksize = DimVector(src_values.sizes().slice(src_values.dim() - 2, 2));
        TORCH_CHECK(self_blocksize == src_blocksize,
                    "torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.",
                    " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectively.");
      });
  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
      [&]{
        self.crow_indices().copy_(src.crow_indices(), non_blocking);
        self.col_indices().copy_(src.col_indices(), non_blocking);
      },
      [&]{
        self.ccol_indices().copy_(src.ccol_indices(), non_blocking);
        self.row_indices().copy_(src.row_indices(), non_blocking);
      });
  self.values().copy_(src.values(), non_blocking);
  return self;
}

// Access members of CSR tensors.
int64_t _nnz_sparse_csr(const SparseCsrTensor& self) {
  return get_sparse_csr_impl(self)->nnz();
}

Tensor values_sparse_csr(const Tensor& self) {
  return get_sparse_csr_impl(self)->values().alias();
}

Tensor crow_indices_sparse_csr(const Tensor& self) {
  return AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(self.layout(),
      "crow_indices",
      [&]{ return get_sparse_csr_impl(self)->compressed_indices().alias(); });
}

Tensor col_indices_sparse_csr(const Tensor& self) {
  return AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(self.layout(),
      "col_indices",
      [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); });
}

Tensor ccol_indices_sparse_csr(const Tensor& self) {
  return AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(self.layout(),
      "ccol_indices",
      [&]{ return get_sparse_csr_impl(self)->compressed_indices().alias(); });
}

Tensor row_indices_sparse_csr(const Tensor& self) {
  return AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(self.layout(),
      "row_indices",
      [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); });
}

Tensor crow_indices_default(const Tensor& self) {
  TORCH_CHECK(false, "crow_indices expected sparse row compressed tensor layout but got ", self.layout());
}

Tensor col_indices_default(const Tensor& self) {
  TORCH_CHECK(false, "col_indices expected sparse row compressed tensor layout but got ", self.layout());
}

Tensor ccol_indices_default(const Tensor& self) {
  TORCH_CHECK(false, "ccol_indices expected sparse column compressed tensor layout but got ", self.layout());
}

Tensor row_indices_default(const Tensor& self) {
  TORCH_CHECK(false, "row_indices expected sparse column compressed tensor layout but got ", self.layout());
}

int64_t sparse_dim_sparse_csr(const SparseCsrTensor& self) {
  return get_sparse_csr_impl(self)->sparse_dim();
}

int64_t dense_dim_sparse_csr(const SparseCsrTensor& self) {
  return get_sparse_csr_impl(self)->dense_dim();
}

bool _is_same_size_as_sparse_compressed(
    const SparseCsrTensor& self,
    const SparseCsrTensor& src) {
  return self.sizes().equals(src.sizes());
}

const SparseCsrTensor& resize_as_sparse_compressed_(
    const SparseCsrTensor& self,
    const SparseCsrTensor& src) {
  auto src_layout = src.layout();
  auto self_layout = self.layout();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      src_layout, "resize_as_sparse_compressed_: src ", []() {});
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self_layout, "resize_as_sparse_compressed_: self ", []() {});
  // Note: The impl method does all required checking to see if resize/data copy
  // on member tensors is required.
  get_sparse_csr_impl(self)->resize_as_sparse_compressed_tensor_(src);
  return self;
}

SparseCsrTensor clone_sparse_compressed(
    const SparseCsrTensor& self,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "unsupported memory format option ",
      optional_memory_format.value());
  TensorOptions options = self.options();
  auto compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
      "clone_sparse_compressed",
      [&]{ return self.crow_indices(); },
      [&]{ return self.ccol_indices(); });
  auto plain_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
      "clone_sparse_compressed",
      [&]{ return self.col_indices(); },
      [&]{ return self.row_indices(); });
  return at::native::_sparse_compressed_tensor_unsafe(
      compressed_indices.clone(),
      plain_indices.clone(),
      self.values().clone(),
      self.sizes(),
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt());
}

Tensor empty_like_sparse_csr(
    const Tensor& self,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
  TensorOptions options =
      self.options()
          .merge_in(options_)
          .merge_memory_format(optional_memory_format);

  TORCH_CHECK(options.layout() == self.layout(),
              "empty_like with different sparse layout is not supported (self is ",
              self.layout(), " but you requested ", options.layout(), ")");
  if (options.layout() == kSparseCsr) {
    auto result = at::native::_sparse_csr_tensor_unsafe(
        self.crow_indices().clone(),
        self.col_indices().clone(),
        at::empty(self.values().sizes(), options.layout(kStrided)),
        self.sizes(),
        optTypeMetaToScalarType(options.dtype()),
        self.layout(),
        options.device());
    return result;
  } else if (options.layout() == kSparseCsc) {
    auto result = at::native::_sparse_csc_tensor_unsafe(
        self.ccol_indices().clone(),
        self.row_indices().clone(),
        at::empty(self.values().sizes(), options.layout(kStrided)),
        self.sizes(),
        optTypeMetaToScalarType(options.dtype()),
        self.layout(),
        options.device());
    return result;
  } else if (options.layout() == kSparseBsr) {
    auto result = at::native::_sparse_bsr_tensor_unsafe(
        self.crow_indices().clone(),
        self.col_indices().clone(),
        at::empty(self.values().sizes(), options.layout(kStrided)),
        self.sizes(),
        optTypeMetaToScalarType(options.dtype()),
        self.layout(),
        options.device());
    return result;
  } else if (options.layout() == kSparseBsc) {
    auto result = at::native::_sparse_bsc_tensor_unsafe(
        self.ccol_indices().clone(),
        self.row_indices().clone(),
        at::empty(self.values().sizes(), options.layout(kStrided)),
        self.sizes(),
        optTypeMetaToScalarType(options.dtype()),
        self.layout(),
        options.device());
    return result;
  } else if (options.layout() == kStrided) {
    return at::native::empty_like(self, dtype, layout, device, pin_memory, optional_memory_format);
  } else {
    TORCH_CHECK(false, "Layout ", options.layout(), " is not supported");
  }
}

template <bool require_view, bool require_copy>
Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index) {
  constexpr const char* select_name = (require_view ? "select()" : "select_copy()");
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(), "select", []() { return; });
  TORCH_CHECK_INDEX(
      self.dim() != 0, select_name, " cannot be applied to a 0-dim tensor.");
  dim = maybe_wrap_dim(dim, self.dim());
  auto size = self.size(dim);
  if (index < -size || index >= size) {
    TORCH_CHECK_INDEX(
        false,
        select_name, ": index ",
        index,
        " out of range for tensor of size ",
        self.sizes(),
        " at dimension ",
        dim);
  }
  if (index < 0) {
    index += size;
  }

  auto select_strided = [](const Tensor& self, int64_t dim, int64_t index) {
    if (require_copy) {
      return at::select_copy(self, dim, index);
    } else {
      return self.select(dim, index);
    }
  };

  TORCH_INTERNAL_ASSERT(dim >= 0 && dim < self.dim());

  auto new_sizes = DimVector(self.sizes());
  new_sizes.erase(new_sizes.begin() + dim);
  auto options = self.options();

  Tensor plain_indices;
  Tensor compressed_indices;
  std::tie(compressed_indices, plain_indices) =
      AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
          self.layout(),
          "select",
          [&]() {
            return std::make_pair(self.crow_indices(), self.col_indices());
          },
          [&]() {
            return std::make_pair(self.ccol_indices(), self.row_indices());
          });
  auto n_batch = compressed_indices.dim() - 1;

  if (dim < n_batch) {
    // Selecting a batch dimension
    return at::native::_sparse_compressed_tensor_unsafe(
        compressed_indices.select(dim, index),
        plain_indices.select(dim, index),
        select_strided(self.values(), dim, index),
        new_sizes,
        optTypeMetaToScalarType(options.dtype_opt()),
        options.layout_opt(),
        options.device_opt(),
        options.pinned_memory_opt());
  } else if (dim < n_batch + 2) {
    // Selecting a sparse dimension
    TORCH_CHECK(
        n_batch == 0,
        select_name, ": selecting sparse dimensions is not implemented for batched sparse compressed tensors.");
    TORCH_INTERNAL_ASSERT(dim == 0 || dim == 1);

    DimVector blocksize{1, 1};
    AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&] {}, [&] {
      blocksize[0] = std::max<int64_t>(1, self.values().size(n_batch + 1));
      blocksize[1] = std::max<int64_t>(1, self.values().size(n_batch + 2));
    });

    auto indices_options = compressed_indices.options();
    int64_t fast_dim = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() { return 0; }, [&]() { return 1; });
    int64_t other_dim = (dim == 0 ? 1 : 0);
    Tensor indices;
    Tensor values;
    bool is_view = dim == fast_dim;
    if (is_view) {
      // selecting along the fast dimension is always a view operation
      Tensor start_end = compressed_indices.narrow(0, index / blocksize[dim], 2).cpu();
      int64_t start = start_end[0].item<int64_t>();
      int64_t end = start_end[1].item<int64_t>();
      indices = plain_indices.slice(0, start, end);
      values = self.values().slice(0, start, end);
    } else {
      Tensor decompressed_indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
                                        .select(0, 0);

      Tensor dim_indices = at::where(plain_indices.eq(index / blocksize[dim]))[0];
      // Notice that dim_indices is a sorted sequence of non-negative
      // distinct integers. Below we'll try to solve `dim_indices ==
      // arange(start, stop, step)`. If a solution exists, then the
      // select will be a view operation also for the `dim !=
      // fast_dim` case.
      int64_t start{}, end{}, step{};
      if (solve_arange(dim_indices, start, end, step)) {
        indices = decompressed_indices.slice(0, start, end, step);
        values = self.values().slice(0, start, end, step);
        is_view = true;
      } else {
        // select will be a copy operation due to index_select!
        indices = decompressed_indices.index_select(0, dim_indices);
        values = self.values().index_select(0, dim_indices);
      }
    }

    AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() {},
        [&]() {
          /*
            The formulas for the selected indices and values below are
            best explained by an example. Consider a BSR tensor with
            block size (2, 3) having four specified blocks (the other
            two blocks contain all zeros and hence are not specified):

              [ 1  2  3] | [ 7  8  9]
              [ 4  5  6] | [10 11 12]
              -----------------------
              [13 14 15] | [ 0  0  0]
              [16 17 18] | [ 0  0  0]
              -----------------------
              [ 0  0  0] | [19 20 21]
              [ 0  0  0] | [22 23 24]

            that represents a 6 x 6 tensor:

              [  1  2  3  7  8  9 ]
              [  4  5  6 10 11 12 ]
              [ 13 14 15  0  0  0 ]
              [ 16 17 18  0  0  0 ]
              [  0  0  0 19 20 21 ]
              [  0  0  0 22 23 24 ]

            The corresponding data for the BSR representation is:

              crow_indices = [0 2 3 4]
              col_indices = [0 1 0 1]
              values = [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]], [[13 14 15], [16 17 18]], [[19 20 21], [22 23 24]] ]
              shape = (6, 6)

            From crow_indices, we can find that

              row_indices = [0 0 1 2]

            In the following, we'll illustrate the details of
            computing the result of torch.select_copy(input, dim,
            index) where dim is 0 or 1, and index is in
            range(shape[dim]).

            Select a row of a BSR tensor
            ----------------------------

            We will consider first the dim=0 case that corresponds to
            selecting the index-th row of the tensor. For instance, for
            dim=0 and index=1, the expected result would represent a
            1D tensor:

              [ 4 5 6 10 11 12 ]

            that is a concatenated tensor of certain slices from the
            first and the second block, computed as follows:

              values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
              -> values[[0, 1]][:, 1 % 2].flatten(0, 1)
              -> [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]] ][:, 1].flatten(0, 1)
              -> [ [4 5 6], [10 11 12] ].flatten(0, 1)
              -> [ 4 5 6 10 11 12 ]

            where dim_indices is found as

              where(row_indices == index//blocksize[dim])
              -> where([0 0 1 2] == 1//2)
              -> [0 1]

            The corresponding column indices are computed as

              (col_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)

            where other_dim is 1 if dim is 0, and 0 if dim is 1. Let's
            expand the above expression with the data in the example:

              -> (col_indices[[0, 1]].mul(3).unsqueeze(1) + arange(3).unsqueeze(0)).flatten(0, 1)
              -> ([0 1].mul(3).unsqueeze(1) + [[0 1 2]]).flatten(0, 1)
              -> ([[0], [3]] + [[0 1 2]]).flatten(0, 1)  <- here the addition uses broadcasting rules!
              -> ([[0 1 2], [3 4 5]]).flatten(0, 1)
              -> [0 1 2 3 4 5]

            Finally, the select(dim=0, index=1) op on the given sparse
            compressed tensor will return a COO tensor:

              sparse_coo_tensor([0 1 2 3 4 5].unsqueeze(0), [4 5 6 10 11 12], (6,))

            that represents the expected result: [ 4 5 6 10 11 12 ]

            Select a column of a BSR tensor
            -------------------------------

            Next, we'll consider the dim=1 case that corresponds to
            selecting the index-th column of the tensor. For instance,
            for dim=1 and index=4, the expected result would represent
            a 1D tensor:

              [ 8 11 0 0 20 23 ]

            that is a concatenated tensor of certain slices from the
            second and the last block:

              values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
              -> values[[1, 3]][:, :, 4 % 3].flatten(0, 1)
              -> [ [[7 8 9], [10 11 12]], [[19 20 21], [22 23 24]] ][:, :, 1].flatten(0, 1)
              -> [ [8 11], [20 23] ].flatten(0, 1)
              -> [ 8 11 20 23 ]

            The corresponding row indices are computed as

              (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)

            where dim_indices is

              where(col_indices == index//blocksize[dim])
              -> where([0 1 0 1] == 4//3)
              -> [1 3]

            and we have

              (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
              -> (row_indices[[1 3]].mul(2).unsqueeze(1) + arange(2).unsqueeze(0)).flatten(0, 1)
              -> ([0 4].unsqueeze(1) + [0 1].unsqueeze(0)).flatten(0, 1)
              -> ([[0], [4]] + [[0 1]]).flatten(0, 1)  <- here the addition uses broadcasting rules!
              -> ([[0 1], [4 5]]).flatten(0, 1)
              -> [ 0 1 4 5 ]

            Finally, the select(dim=1, index=4) op on the given sparse
            compressed tensor will return a COO tensor:

              sparse_coo_tensor([0 1 4 5].unsqueeze(0), [8 11 20 23], (6,))

            that represents the expected result: [ 8 11 0 0 20 23 ]
          */
          Tensor subblock_indices = at::arange(0, blocksize[other_dim], indices_options);
          indices = indices.mul(blocksize[other_dim]).unsqueeze(1).add(subblock_indices.unsqueeze(0)).flatten(0, 1);
          values = values.select(dim + 1, index % blocksize[dim]).flatten(0, 1);
          // flatten(0, 1) can be a view or a copy operation. If a view
          // is required, it will be checked below via is_alias_of;
          // otherwise, we check here whether a copy was made, to avoid
          // an unnecessary clone below:
          if (require_copy) {
            is_view = values.is_alias_of(self.values());
          }
        });

    if (require_view) {
      TORCH_CHECK(values.is_alias_of(self.values()), select_name,
                  ": no view exists for the given input, consider using torch.select_copy.");
    }

    indices = indices.unsqueeze(0).to(kLong);
    if (require_copy && is_view) {
      values = values.clone();
    }
    return at::_sparse_coo_tensor_unsafe(indices, values, new_sizes)._coalesced_(true);
  } else {
    // Selecting a dense dimension
    Tensor new_values = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
        self.layout(),
        "select",
        // Non-blocked layout (2 sparse dims become 1 nnz dim in values, so dim
        // is found one position to the left)
        [&]() { return select_strided(self.values(), dim - 1, index); },
        // Blocked layout (2 sparse dims become 1 nnz dim + 2 block-shape dims in
        // values, so dim is found one position to the right)
        [&]() { return select_strided(self.values(), dim + 1, index); });
    return at::native::_sparse_compressed_tensor_unsafe(
        compressed_indices,
        plain_indices,
        new_values,
        new_sizes,
        optTypeMetaToScalarType(options.dtype_opt()),
        options.layout_opt(),
        options.device_opt(),
        options.pinned_memory_opt());
  }
}

Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
  return select_sparse_csr_worker<true, false>(self, dim, index);
}

Tensor select_copy_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
  return select_sparse_csr_worker<false, true>(self, dim, index);
}
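
// A hedged illustration of the view/copy distinction above, assuming a
// hypothetical non-batched CSR matrix c: selecting along the fast (row)
// dimension, e.g. c.select(0, i), slices col_indices and values and thus
// returns tensors that alias c.values(); selecting along the other
// dimension, e.g. c.select(1, j), generally requires index_select, so
// select() throws the "no view exists" error there unless the selected
// positions happen to form an arange (see solve_arange above), while
// select_copy() always succeeds.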

} // namespace native
} // namespace at