1 | #pragma once |
2 | |
3 | #include <ATen/SparseCsrTensorImpl.h> |
4 | #include <ATen/SparseTensorImpl.h> |
5 | #include <ATen/SparseTensorUtils.h> |
6 | #include <ATen/core/Tensor.h> |
7 | |
// Invokes the zero-argument callable __VA_ARGS__ when LAYOUT is any of the
// four sparse compressed layouts (CSR, CSC, BSR, BSC); otherwise raises a
// runtime error whose message starts with NAME. Expands to an
// immediately-invoked lambda, so the whole macro is an expression that
// yields the callable's result.
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...)       \
  [&] {                                                                    \
    const auto& the_layout = LAYOUT;                                       \
    switch (the_layout) {                                                  \
      case kSparseCsr:                                                     \
      case kSparseCsc:                                                     \
      case kSparseBsr:                                                     \
      case kSparseBsc:                                                     \
        return __VA_ARGS__();                                              \
      default:                                                             \
        AT_ERROR(                                                          \
            NAME,                                                          \
            " expected sparse compressed tensor layout but got ",          \
            the_layout);                                                   \
    }                                                                      \
  }()
24 | |
// Dispatches on which dimension LAYOUT compresses: row-compressed layouts
// (CSR, BSR) run ROW_DIM_ACTION, column-compressed layouts (CSC, BSC) run
// COLUMN_DIM_ACTION; any other layout raises a runtime error that includes
// NAME. Both actions are zero-argument callables and must return the same
// type; the macro is an expression yielding that result.
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION)                       \
  [&]() {                                                                  \
    const auto& the_layout = LAYOUT;                                       \
    switch (the_layout) {                                                  \
      case kSparseCsr:                                                     \
      case kSparseBsr:                                                     \
        return (ROW_DIM_ACTION)();                                         \
      case kSparseCsc:                                                     \
      case kSparseBsc:                                                     \
        return (COLUMN_DIM_ACTION)();                                      \
      default:                                                             \
        AT_ERROR(                                                          \
            NAME,                                                          \
            " expected sparse compressed tensor layout but got ",          \
            the_layout);                                                   \
    }                                                                      \
  }()
43 | |
// Dispatches on whether LAYOUT is blocked: non-block layouts (CSR, CSC) run
// NO_BLOCK_ACTION, block layouts (BSR, BSC) run BLOCK_ACTION; any other
// layout raises a runtime error that includes NAME. Both actions are
// zero-argument callables and must return the same type; the macro is an
// expression yielding that result.
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION)                           \
  [&]() {                                                                  \
    const auto& the_layout = LAYOUT;                                       \
    switch (the_layout) {                                                  \
      case kSparseCsr:                                                     \
      case kSparseCsc:                                                     \
        return (NO_BLOCK_ACTION)();                                        \
      case kSparseBsr:                                                     \
      case kSparseBsc:                                                     \
        return (BLOCK_ACTION)();                                           \
      default:                                                             \
        AT_ERROR(                                                          \
            NAME,                                                          \
            " expected sparse compressed tensor layout but got ",          \
            the_layout);                                                   \
    }                                                                      \
  }()
62 | |
// Invokes ROW_DIM_ACTION only for the row-compressed layouts (CSR, BSR);
// every other layout — including CSC/BSC — raises a runtime error that
// includes NAME. The macro is an expression yielding the action's result.
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION)                                          \
  [&]() {                                                                  \
    const auto& the_layout = LAYOUT;                                       \
    switch (the_layout) {                                                  \
      case kSparseCsr:                                                     \
      case kSparseBsr:                                                     \
        return (ROW_DIM_ACTION)();                                         \
      default:                                                             \
        AT_ERROR(                                                          \
            NAME,                                                          \
            " expected sparse row compressed tensor layout but got ",      \
            the_layout);                                                   \
    }                                                                      \
  }()
78 | |
// Invokes COL_DIM_ACTION only for the column-compressed layouts (CSC, BSC);
// every other layout — including CSR/BSR — raises a runtime error that
// includes NAME. The macro is an expression yielding the action's result.
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, COL_DIM_ACTION)                                          \
  [&]() {                                                                  \
    const auto& the_layout = LAYOUT;                                       \
    switch (the_layout) {                                                  \
      case kSparseCsc:                                                     \
      case kSparseBsc:                                                     \
        return (COL_DIM_ACTION)();                                         \
      default:                                                             \
        AT_ERROR(                                                          \
            NAME,                                                          \
            " expected sparse column compressed tensor layout but got ",   \
            the_layout);                                                   \
    }                                                                      \
  }()
94 | |
// Invokes ACTION only for the non-block compressed layouts (CSR, CSC);
// every other layout — including BSR/BSC — raises a runtime error that
// includes NAME. The macro is an expression yielding the action's result.
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() {                                                                    \
    const auto& the_layout = LAYOUT;                                         \
    switch (the_layout) {                                                    \
      case kSparseCsr:                                                       \
      case kSparseCsc:                                                       \
        return (ACTION)();                                                   \
      default:                                                               \
        AT_ERROR(                                                            \
            NAME,                                                            \
            " expected sparse compressed (non-block) tensor layout but got ", \
            the_layout);                                                     \
    }                                                                        \
  }()
109 | |
// Invokes ACTION only for the block compressed layouts (BSR, BSC); every
// other layout — including CSR/CSC — raises a runtime error that includes
// NAME. The macro is an expression yielding the action's result.
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION)  \
  [&]() {                                                                  \
    const auto& the_layout = LAYOUT;                                       \
    switch (the_layout) {                                                  \
      case kSparseBsr:                                                     \
      case kSparseBsc:                                                     \
        return (ACTION)();                                                 \
      default:                                                             \
        AT_ERROR(                                                          \
            NAME,                                                          \
            " expected sparse compressed block tensor layout but got ",    \
            the_layout);                                                   \
    }                                                                      \
  }()
124 | |
// Scalar-type dispatch over every dtype a sparse compressed tensor's
// values() may hold: all standard integral/floating/complex types plus
// ComplexHalf, Half, Bool and BFloat16. Thin wrapper around
// AT_DISPATCH_SWITCH with the corresponding CASE macro.
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(      \
          kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
131 | |
132 | namespace at { |
133 | namespace sparse_csr { |
134 | |
// Alias used throughout this header to signal that a Tensor is expected to
// have a sparse compressed layout; it is not a distinct type, so the
// expectation is enforced at runtime (see get_sparse_csr_impl).
using SparseCsrTensor = Tensor;
136 | |
137 | inline bool is_sparse_compressed(const Layout& layout) { |
138 | switch (layout) { |
139 | case kSparseCsr: |
140 | case kSparseCsc: |
141 | case kSparseBsr: |
142 | case kSparseBsc: |
143 | return true; |
144 | default:; |
145 | } |
146 | return false; |
147 | } |
148 | |
149 | inline bool is_sparse_compressed(const Tensor& self) { |
150 | return is_sparse_compressed(self.layout()); |
151 | } |
152 | |
// Returns the underlying SparseCsrTensorImpl of a sparse compressed tensor.
// The dispatch with an empty action is used purely as a layout check: it
// errors (with message prefix "get_sparse_csr_impl") unless self's layout
// is CSR/CSC/BSR/BSC, after which the impl pointer is cast unconditionally.
// The returned pointer is non-owning and valid only while `self` is alive.
inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(), "get_sparse_csr_impl", [&] {});
  return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}
158 | |
159 | inline std::string layoutToString( |
160 | Layout layout, |
161 | bool upper = false, |
162 | bool lower = false) { |
163 | switch (layout) { |
164 | case kSparseCsr: |
165 | return (upper ? "CSR" : (lower ? "csr" : "Csr" )); |
166 | case kSparseCsc: |
167 | return (upper ? "CSC" : (lower ? "csc" : "Csc" )); |
168 | case kSparseBsr: |
169 | return (upper ? "BSR" : (lower ? "bsr" : "Bsr" )); |
170 | case kSparseBsc: |
171 | return (upper ? "BSC" : (lower ? "bsc" : "Bsc" )); |
172 | default: |
173 | TORCH_CHECK(false, "Not a sparse compressed layout:" , layout); |
174 | return "" ; |
175 | } |
176 | } |
177 | |
178 | inline bool isCompressedRow(Layout layout) { |
179 | return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( |
180 | layout, "isCompressedRow" , [&] { return true; }, [&] { return false; }); |
181 | } |
182 | |
183 | inline bool isCompressedColumn(Layout layout) { |
184 | return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( |
185 | layout, |
186 | "isCompressedColumn" , |
187 | [&] { return false; }, |
188 | [&] { return true; }); |
189 | } |
190 | |
191 | inline std::string compressedIndicesName(Layout layout) { |
192 | return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( |
193 | layout, |
194 | "compressedIndicesName" , |
195 | [&] { return "crow_indices" ; }, |
196 | [&] { return "ccol_indices" ; }); |
197 | } |
198 | |
199 | inline std::string plainIndicesName(Layout layout) { |
200 | return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( |
201 | layout, |
202 | "plainIndicesName" , |
203 | [&] { return "col_indices" ; }, |
204 | [&] { return "row_indices" ; }); |
205 | } |
206 | |
207 | inline std::string compressedDimName(Layout layout) { |
208 | switch (layout) { |
209 | case kSparseCsr: |
210 | return "row" ; |
211 | case kSparseCsc: |
212 | return "column" ; |
213 | case kSparseBsr: |
214 | return "row block" ; |
215 | case kSparseBsc: |
216 | return "column block" ; |
217 | default: |
218 | TORCH_CHECK(false, "Not a sparse compressed layout:" , layout); |
219 | return "" ; |
220 | } |
221 | } |
222 | |
223 | inline std::string plainDimName(Layout layout) { |
224 | switch (layout) { |
225 | case kSparseCsr: |
226 | return "column" ; |
227 | case kSparseCsc: |
228 | return "row" ; |
229 | case kSparseBsr: |
230 | return "column block" ; |
231 | case kSparseBsc: |
232 | return "row block" ; |
233 | default: |
234 | TORCH_CHECK(false, "Not a sparse compressed layout:" , layout); |
235 | return "" ; |
236 | } |
237 | } |
238 | |
239 | inline int rowDimension(Layout layout, IntArrayRef size) { |
240 | return size.size() - (isCompressedRow(layout) ? 2 : 1); |
241 | } |
242 | |
243 | inline int columnDimension(Layout layout, IntArrayRef size) { |
244 | return size.size() - (isCompressedColumn(layout) ? 2 : 1); |
245 | } |
246 | |
247 | inline int compressedDimension( |
248 | Layout layout, |
249 | IntArrayRef size, |
250 | size_t dense_ndim = 0) { |
251 | return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1); |
252 | } |
253 | |
254 | inline int plainDimension( |
255 | Layout layout, |
256 | IntArrayRef size, |
257 | size_t dense_ndim = 0) { |
258 | return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2); |
259 | } |
260 | |
// Number of batch dimensions of a sparse compressed tensor, derived from
// its compressed indices: the compressed index tensor carries one leading
// dimension per batch dimension plus a single indices dimension, hence
// dim() - 1. Reads crow_indices for row-compressed layouts and
// ccol_indices for column-compressed ones; errors on other layouts.
inline int64_t numBatchDimensions(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "numBatchDimensions",
      [&self] { return self.crow_indices().dim() - 1; },
      [&self] { return self.ccol_indices().dim() - 1; });
}
268 | |
269 | inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) { |
270 | return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( |
271 | self.layout(), |
272 | "getCompressedPlainIndices" , |
273 | [&self] { |
274 | return std::make_pair(self.crow_indices(), self.col_indices()); |
275 | }, |
276 | [&self] { |
277 | return std::make_pair(self.ccol_indices(), self.row_indices()); |
278 | }); |
279 | } |
280 | |
281 | inline Layout flip_compressed_layout(Layout layout) { |
282 | switch (layout) { |
283 | case kSparseCsr: |
284 | return kSparseCsc; |
285 | case kSparseCsc: |
286 | return kSparseCsr; |
287 | case kSparseBsr: |
288 | return kSparseBsc; |
289 | case kSparseBsc: |
290 | return kSparseBsr; |
291 | default: |
292 | TORCH_CHECK(false, "Not a sparse compressed layout:" , layout); |
293 | return kSparseCsr; |
294 | } |
295 | } |
296 | |
// Block shape (height, width) of a blocked sparse compressed tensor, read
// from its values: the two sizes immediately after the batch dimensions
// and the nnz dimension are taken as the block dimensions.
// NOTE(review): assumes `self` has a BSR/BSC layout — for CSR/CSC this
// slice would silently pick up dense (or out-of-range) dimensions instead
// of block sizes; confirm callers guarantee a block layout (compare
// getSymIntBlockSize, which checks the layout explicitly).
inline DimVector getBlockSize(Tensor const& self) {
  int64_t n_batch = numBatchDimensions(self);
  return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
}
301 | |
302 | inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) { |
303 | if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) { |
304 | int64_t n_batch = numBatchDimensions(self); |
305 | return self.values().sym_sizes().slice(n_batch + 1, 2).vec(); |
306 | } else { |
307 | return {}; |
308 | } |
309 | } |
310 | |
311 | } // namespace sparse_csr |
312 | } // namespace at |
313 | |