#include <ATen/SparseTensorUtils.h>

#include <ATen/ATen.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

namespace at { namespace sparse {

// NOTE [ Flatten Sparse Indices ]
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
// indices tensor. E.g.,
//   input = [[2, 4, 0],
//            [3, 1, 10]]
//   full_size = [2, 12]
//   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming that each `indices[:, i]` is a valid index into a
// tensor `t` of shape `full_size`, this returns the corresponding indices
// into the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result will be forced to be a clone of `indices`.
Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone /*= false*/) {
  int64_t sparse_dim = indices.size(0);
  if (sparse_dim == 1) {
    if (force_clone) {
      return indices.squeeze(0).clone(at::MemoryFormat::Contiguous);
    } else {
      return indices.squeeze(0);
    }
  } else {
    std::vector<int64_t> indices_mult_cpu_vec;
    indices_mult_cpu_vec.resize(sparse_dim);
    int64_t mult = 1;
    for (int64_t i = sparse_dim - 1; i >= 0; i--) {
      indices_mult_cpu_vec[i] = mult;
      mult *= full_size[i];
    }
    Tensor indices_mult_cpu = at::from_blob(
        indices_mult_cpu_vec.data(),
        // NOLINTNEXTLINE(bugprone-argument-comment)
        /*size=*/{sparse_dim, 1},
        indices.options().device(kCPU).dtype(kLong));
    // NB: must be blocking because this blob may be freed after this
    // closure, and a non_blocking copy will see garbage.
    Tensor indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false);
    // Ideally we want matmul, but matmul is slow on CPU Long and not
    // implemented on CUDA Long, so mul + sum is faster.
    return indices.mul(indices_mult).sum(0);
  }
}
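// Illustrative usage (a sketch, not part of this file's API surface),
// reproducing the example from the NOTE above:
//
//   int64_t data[] = {2, 4, 0, 3, 1, 10};
//   Tensor indices = at::from_blob(data, {2, 3}, at::kLong).clone();
//   Tensor flat = flatten_indices(indices, /*full_size=*/{2, 12});
//   // flat is [27, 49, 10]: column j maps to indices[0][j] * 12 + indices[1][j],
//   // i.e. strides {12, 1} applied via the broadcasted mul + sum above.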

// Flatten a sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ],
// except this one allows partial flattening: only flatten on the specified dims. Note that
// the flattened indices might be uncoalesced if dims_to_flatten.size() < sparse_dim.
// Also, if the input indices are already coalesced, the flattened indices will also be sorted.
//
// args:
//   indices: sparse tensor indices
//   sizes: sparse tensor sizes
//   dims_to_flatten: a list of dim indices to flatten
//
// Ex1:
//   indices = [[2, 4, 0],
//              [3, 1, 3]]
//   sizes = [2, 12]
//   dims_to_flatten = [0, 1]
//   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
//
// Ex2:
//   dims_to_flatten = [1]
//   new_indices = [ 3, 1, 3 ] # uncoalesced
Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten) {
  Tensor new_indices = at::zeros({indices.size(1)}, indices.options());
  for (auto d : dims_to_flatten) {
    new_indices.mul_(sizes[d]);
    new_indices.add_(indices.select(0, d));
  }
  return new_indices;
}
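// Worked trace (illustrative) of Ex1 above: new_indices starts as [0, 0, 0];
// flattening dim 0 gives 0 * 2 + [2, 4, 0] = [2, 4, 0], and then dim 1 gives
// [2, 4, 0] * 12 + [3, 1, 3] = [27, 49, 3]. This is Horner-style accumulation
// over the selected dims, so the order of dims_to_flatten matters.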

Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
  /*
    Find the CSR representation for the row `indices` from the COO format
    Inputs:
      `indices` is the array of row indices from the COO indices (assumed sorted)
      `dim` is the row dimensionality (number of rows)
      `nnz` is the number of non-zeros

    Output:
      `csr` is the compressed row array in CSR format
  */
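  // Worked example (illustrative): row indices [0, 0, 1, 3] with dim = 4 and
  // nnz = 4 yield csr = [0, 2, 3, 3, 4]: row 0 owns entries [0, 2), row 1
  // owns [2, 3), row 2 is empty, and row 3 owns [3, 4).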
  Tensor csr = at::zeros({dim + 1}, kLong);

  // TODO: eliminate this conditional when zero-size dims are supported correctly
  if (nnz > 0) {
    auto csr_accessor = csr.accessor<int64_t, 1>();
    // Convert the sparse matrix to CSR format
    at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int64_t h, hp0, hp1;
      for (const auto i : c10::irange(start, end)) {
        hp0 = indices[i];
        hp1 = (i + 1 == nnz) ? dim : indices[i + 1];
        if (hp0 != hp1) {
          for (h = hp0; h < hp1; h++) {
            csr_accessor[h + 1] = i + 1;
          }
        }
      }
    });
  }
  return csr;
}

}} // namespace at::sparse