1 | #pragma once |
2 | |
3 | #include <ATen/Parallel.h> |
4 | #include <ATen/SparseTensorImpl.h> |
5 | #include <ATen/core/Tensor.h> |
6 | |
7 | #ifndef AT_PER_OPERATOR_HEADERS |
8 | #include <ATen/Functions.h> |
9 | #else |
10 | #include <ATen/ops/empty.h> |
11 | #endif |
12 | |
13 | namespace at { |
14 | namespace sparse { |
15 | |
16 | // Just for documentary purposes |
17 | using SparseTensor = Tensor; |
18 | using SparseType = Type; |
19 | |
20 | // This is an internal utility function for getting at the SparseTensorImpl, |
21 | // so that we can write sparse tensor specific accessors for special fields |
22 | // in SparseTensor. You should only use this for writing low level |
23 | // setters/getters for SparseTensorImpl fields; otherwise, you should use |
24 | // the low level setters/getters that were implemented using this. |
25 | // |
26 | // This may be called repeatedly, so make sure it's pretty cheap. |
27 | inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) { |
28 | TORCH_INTERNAL_ASSERT( |
29 | self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor" ); |
30 | return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl()); |
31 | } |
32 | |
33 | // Takes indices and values and directly puts them into the sparse tensor, no |
34 | // copy. This used to be called THSTensor_(_move) |
35 | inline void alias_into_sparse( |
36 | const SparseTensor& self, |
37 | const Tensor& indices, |
38 | const Tensor& values) { |
39 | get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values); |
40 | } |
41 | |
42 | // Take indices and values and makes a (data) copy of them to put into the |
43 | // sparse indices/values. This used to be called THSTensor_(_set) |
44 | inline void copy_into_sparse( |
45 | const SparseTensor& self, |
46 | const Tensor& indices, |
47 | const Tensor& values, |
48 | bool non_blocking) { |
49 | alias_into_sparse( |
50 | self, |
51 | indices.to(self._indices().options(), non_blocking, /*copy=*/true), |
52 | values.to(self._values().options(), non_blocking, /*copy=*/true)); |
53 | } |
54 | |
55 | // TODO: put this into the public API |
56 | inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) { |
57 | return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl(); |
58 | } |
59 | |
60 | inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) { |
61 | return self.sparse_dim() == src.sparse_dim() && |
62 | self.dense_dim() == src.dense_dim(); |
63 | } |
64 | |
65 | // Give us a new values tensor, with the same dimensionality |
66 | // as 'values' but with a new number of non-zero elements. |
67 | // TODO: Expose this for real in ATen, some day? |
68 | // NB: Doesn't preserve data. |
69 | inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) { |
70 | std::vector<int64_t> size = values.sizes().vec(); |
71 | size[0] = nnz; |
72 | return at::empty(size, values.options()); |
73 | } |
74 | |
75 | // NOTE [ Flatten Sparse Indices ] |
76 | // This helper function flattens a sparse indices tensor (a Tensor) into a 1D |
77 | // indices tensor. E.g., |
78 | // input = [[2, 4, 0], |
79 | // [3, 1, 10]] |
80 | // full_size = [2, 12] |
81 | // output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10] |
82 | // |
83 | // In other words, assuming that each `indices[i, :]` is a valid index to a |
84 | // tensor `t` of shape `full_size`. This returns the corresponding indices to |
85 | // the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`. |
86 | // if forceClone is true, the result will forced to be a clone of self. |
87 | // if force_clone is true, the result will forced to be a clone of self. |
88 | TORCH_API Tensor flatten_indices( |
89 | const Tensor& indices, |
90 | IntArrayRef full_size, |
91 | bool force_clone = false); |
92 | |
93 | // Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten |
94 | // Sparse Indices ], except this one allows partial flatten: only flatten on |
95 | // specified dims. Note that the flatten indices might be uncoalesced if |
96 | // dims_to_flatten.size() < sparse_dim. Also if input indices is already |
97 | // coalesced, the flattened indices will also be sorted. |
98 | // |
99 | // args: |
100 | // indices: sparse tensor indices |
101 | // sizes: sparse tensor sizes |
102 | // dims_to_flatten: a list of dim index to flatten |
103 | // |
104 | // Ex1: |
105 | // indices = [[2, 4, 0], |
106 | // [3, 1, 3]] |
107 | // sizes = [2, 12] |
108 | // dims_to_flatten = [0, 1] |
109 | // new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3] |
110 | // |
111 | // Ex2: |
112 | // dims_to_flatten = [1] |
113 | // new_indices = [ 3, 1, 3 ] # uncoalesced |
114 | TORCH_API Tensor flatten_indices_by_dims( |
115 | const Tensor& indices, |
116 | const IntArrayRef& sizes, |
117 | const IntArrayRef& dims_to_flatten); |
118 | |
119 | // Find the CSR representation for a row `indices` from the COO format |
120 | TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz); |
121 | |
122 | } // namespace sparse |
123 | } // namespace at |
124 | |