1#pragma once
2
3#include <ATen/SparseCsrTensorImpl.h>
4#include <ATen/SparseTensorImpl.h>
5#include <ATen/SparseTensorUtils.h>
6#include <ATen/core/Tensor.h>
7
// Checks that LAYOUT is one of the sparse compressed layouts (CSR, CSC, BSR,
// BSC). If it is, the macro invokes the callable given as the variadic
// argument and evaluates to that callable's result. Any other layout is
// reported via AT_ERROR(NAME, <message>, layout).
//
// The callable is taken through __VA_ARGS__, so a lambda whose body contains
// unparenthesized commas is still accepted as a single argument. The callable
// is parenthesized before it is called, matching the ACTION arguments of the
// sibling dispatch macros. This way a compound callable expression (for
// example `cond ? f : g`) is invoked as a whole instead of only its last
// operand.
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  [&]() {                                                            \
    const auto& the_layout = LAYOUT;                                 \
    switch (the_layout) {                                            \
      case kSparseCsr:                                               \
      case kSparseCsc:                                               \
      case kSparseBsr:                                               \
      case kSparseBsc:                                               \
        return (__VA_ARGS__)();                                      \
      default:                                                       \
        AT_ERROR(                                                    \
            NAME,                                                    \
            " expected sparse compressed tensor layout but got ",    \
            the_layout);                                             \
    }                                                                \
  }()
24
// Two-way dispatch on the compression direction of a sparse compressed LAYOUT:
// - ROW_DIM_ACTION is invoked for CSR and BSR.
// - COLUMN_DIM_ACTION is invoked for CSC and BSC.
// The macro evaluates to the invoked action's result. Any other layout is
// reported via AT_ERROR(NAME, <message>, layout).
// No locals other than `the_layout` are introduced, so the action arguments
// cannot accidentally capture a macro-internal name.
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(                \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION)              \
  [&]() {                                                         \
    const auto& the_layout = LAYOUT;                              \
    if (the_layout == kSparseCsr || the_layout == kSparseBsr) {   \
      return (ROW_DIM_ACTION)();                                  \
    }                                                             \
    if (the_layout == kSparseCsc || the_layout == kSparseBsc) {   \
      return (COLUMN_DIM_ACTION)();                               \
    }                                                             \
    AT_ERROR(                                                     \
        NAME,                                                     \
        " expected sparse compressed tensor layout but got ",     \
        the_layout);                                              \
  }()
43
// Two-way dispatch on whether a sparse compressed LAYOUT is blocked:
// - BLOCK_ACTION is invoked for BSR and BSC.
// - NO_BLOCK_ACTION is invoked for CSR and CSC.
// The macro evaluates to the invoked action's result. Any other layout is
// reported via AT_ERROR(NAME, <message>, layout).
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(              \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION)                  \
  [&]() {                                                         \
    const auto& the_layout = LAYOUT;                              \
    if (the_layout == kSparseBsr || the_layout == kSparseBsc) {   \
      return (BLOCK_ACTION)();                                    \
    }                                                             \
    if (the_layout == kSparseCsr || the_layout == kSparseCsc) {   \
      return (NO_BLOCK_ACTION)();                                 \
    }                                                             \
    AT_ERROR(                                                     \
        NAME,                                                     \
        " expected sparse compressed tensor layout but got ",     \
        the_layout);                                              \
  }()
62
// Accepts only the row-compressed layouts, CSR and BSR. For those it invokes
// ROW_DIM_ACTION and evaluates to its result. Every other layout (including
// CSC/BSC) is reported via AT_ERROR(NAME, <message>, layout).
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(                  \
    LAYOUT, NAME, ROW_DIM_ACTION)                                   \
  [&]() {                                                           \
    const auto& the_layout = LAYOUT;                                \
    if (the_layout == kSparseCsr || the_layout == kSparseBsr) {     \
      return (ROW_DIM_ACTION)();                                    \
    }                                                               \
    AT_ERROR(                                                       \
        NAME,                                                       \
        " expected sparse row compressed tensor layout but got ",   \
        the_layout);                                                \
  }()
78
// Accepts only the column-compressed layouts, CSC and BSC. For those it
// invokes COL_DIM_ACTION and evaluates to its result. Every other layout
// (including CSR/BSR) is reported via AT_ERROR(NAME, <message>, layout).
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(                     \
    LAYOUT, NAME, COL_DIM_ACTION)                                      \
  [&]() {                                                              \
    const auto& the_layout = LAYOUT;                                   \
    if (the_layout == kSparseCsc || the_layout == kSparseBsc) {        \
      return (COL_DIM_ACTION)();                                       \
    }                                                                  \
    AT_ERROR(                                                          \
        NAME,                                                          \
        " expected sparse column compressed tensor layout but got ",   \
        the_layout);                                                   \
  }()
94
// Accepts only the non-blocked compressed layouts, CSR and CSC. For those it
// invokes ACTION and evaluates to its result. Every other layout (including
// BSR/BSC) is reported via AT_ERROR(NAME, <message>, layout).
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() {                                                                    \
    const auto& the_layout = LAYOUT;                                         \
    if (the_layout == kSparseCsr || the_layout == kSparseCsc) {              \
      return (ACTION)();                                                     \
    }                                                                        \
    AT_ERROR(                                                                \
        NAME,                                                                \
        " expected sparse compressed (non-block) tensor layout but got ",    \
        the_layout);                                                         \
  }()
109
// Accepts only the blocked compressed layouts, BSR and BSC. For those it
// invokes ACTION and evaluates to its result. Every other layout (including
// CSR/CSC) is reported via AT_ERROR(NAME, <message>, layout).
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() {                                                                 \
    const auto& the_layout = LAYOUT;                                      \
    if (the_layout == kSparseBsr || the_layout == kSparseBsc) {           \
      return (ACTION)();                                                  \
    }                                                                     \
    AT_ERROR(                                                             \
        NAME,                                                             \
        " expected sparse compressed block tensor layout but got ",       \
        the_layout);                                                      \
  }()
124
// Scalar-type dispatch for sparse values. It expands to AT_DISPATCH_SWITCH on
// TYPE, with NAME forwarded to it. The case list comes from
// AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4, extended with kComplexHalf,
// kHalf, kBool and kBFloat16. The variadic arguments are forwarded to that
// case macro, presumably as the per-type body; the macro is defined
// elsewhere.
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(      \
          kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
131
132namespace at {
133namespace sparse_csr {
134
// Plain alias of Tensor that signals, in signatures, that a sparse compressed
// layout is expected. Because it is only an alias, it performs no layout
// check by itself.
using SparseCsrTensor = Tensor;
136
137inline bool is_sparse_compressed(const Layout& layout) {
138 switch (layout) {
139 case kSparseCsr:
140 case kSparseCsc:
141 case kSparseBsr:
142 case kSparseBsc:
143 return true;
144 default:;
145 }
146 return false;
147}
148
149inline bool is_sparse_compressed(const Tensor& self) {
150 return is_sparse_compressed(self.layout());
151}
152
153inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
154 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
155 self.layout(), "get_sparse_csr_impl", [&] {});
156 return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
157}
158
159inline std::string layoutToString(
160 Layout layout,
161 bool upper = false,
162 bool lower = false) {
163 switch (layout) {
164 case kSparseCsr:
165 return (upper ? "CSR" : (lower ? "csr" : "Csr"));
166 case kSparseCsc:
167 return (upper ? "CSC" : (lower ? "csc" : "Csc"));
168 case kSparseBsr:
169 return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
170 case kSparseBsc:
171 return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
172 default:
173 TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
174 return "";
175 }
176}
177
178inline bool isCompressedRow(Layout layout) {
179 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
180 layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
181}
182
183inline bool isCompressedColumn(Layout layout) {
184 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
185 layout,
186 "isCompressedColumn",
187 [&] { return false; },
188 [&] { return true; });
189}
190
191inline std::string compressedIndicesName(Layout layout) {
192 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
193 layout,
194 "compressedIndicesName",
195 [&] { return "crow_indices"; },
196 [&] { return "ccol_indices"; });
197}
198
199inline std::string plainIndicesName(Layout layout) {
200 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
201 layout,
202 "plainIndicesName",
203 [&] { return "col_indices"; },
204 [&] { return "row_indices"; });
205}
206
207inline std::string compressedDimName(Layout layout) {
208 switch (layout) {
209 case kSparseCsr:
210 return "row";
211 case kSparseCsc:
212 return "column";
213 case kSparseBsr:
214 return "row block";
215 case kSparseBsc:
216 return "column block";
217 default:
218 TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
219 return "";
220 }
221}
222
223inline std::string plainDimName(Layout layout) {
224 switch (layout) {
225 case kSparseCsr:
226 return "column";
227 case kSparseCsc:
228 return "row";
229 case kSparseBsr:
230 return "column block";
231 case kSparseBsc:
232 return "row block";
233 default:
234 TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
235 return "";
236 }
237}
238
239inline int rowDimension(Layout layout, IntArrayRef size) {
240 return size.size() - (isCompressedRow(layout) ? 2 : 1);
241}
242
243inline int columnDimension(Layout layout, IntArrayRef size) {
244 return size.size() - (isCompressedColumn(layout) ? 2 : 1);
245}
246
247inline int compressedDimension(
248 Layout layout,
249 IntArrayRef size,
250 size_t dense_ndim = 0) {
251 return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
252}
253
254inline int plainDimension(
255 Layout layout,
256 IntArrayRef size,
257 size_t dense_ndim = 0) {
258 return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
259}
260
261inline int64_t numBatchDimensions(Tensor const& self) {
262 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
263 self.layout(),
264 "numBatchDimensions",
265 [&self] { return self.crow_indices().dim() - 1; },
266 [&self] { return self.ccol_indices().dim() - 1; });
267}
268
269inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
270 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
271 self.layout(),
272 "getCompressedPlainIndices",
273 [&self] {
274 return std::make_pair(self.crow_indices(), self.col_indices());
275 },
276 [&self] {
277 return std::make_pair(self.ccol_indices(), self.row_indices());
278 });
279}
280
281inline Layout flip_compressed_layout(Layout layout) {
282 switch (layout) {
283 case kSparseCsr:
284 return kSparseCsc;
285 case kSparseCsc:
286 return kSparseCsr;
287 case kSparseBsr:
288 return kSparseBsc;
289 case kSparseBsc:
290 return kSparseBsr;
291 default:
292 TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
293 return kSparseCsr;
294 }
295}
296
297inline DimVector getBlockSize(Tensor const& self) {
298 int64_t n_batch = numBatchDimensions(self);
299 return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
300}
301
302inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
303 if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
304 int64_t n_batch = numBatchDimensions(self);
305 return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
306 } else {
307 return {};
308 }
309}
310
311} // namespace sparse_csr
312} // namespace at
313