#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "taichi/common/core.h"
#include "taichi/inc/constants.h"
#include "taichi/ir/type_utils.h"
#include "taichi/program/ndarray.h"
#include "taichi/program/program.h"
#include "taichi/rhi/cuda/cuda_driver.h"

#include "Eigen/Sparse"
11 | |
12 | namespace taichi::lang { |
13 | |
14 | class SparseMatrix; |
15 | |
16 | class SparseMatrixBuilder { |
17 | public: |
18 | SparseMatrixBuilder(int rows, |
19 | int cols, |
20 | int max_num_triplets, |
21 | DataType dtype, |
22 | const std::string &storage_format, |
23 | Program *prog); |
24 | |
25 | void print_triplets_eigen(); |
26 | void print_triplets_cuda(); |
27 | |
28 | intptr_t get_ndarray_data_ptr() const; |
29 | |
30 | std::unique_ptr<SparseMatrix> build(); |
31 | |
32 | std::unique_ptr<SparseMatrix> build_cuda(); |
33 | |
34 | void clear(); |
35 | |
36 | private: |
37 | template <typename T, typename G> |
38 | void build_template(std::unique_ptr<SparseMatrix> &); |
39 | |
40 | template <typename T, typename G> |
41 | void print_triplets_template(); |
42 | |
43 | private: |
44 | uint64 num_triplets_{0}; |
45 | std::unique_ptr<Ndarray> ndarray_data_base_ptr_{nullptr}; |
46 | int rows_{0}; |
47 | int cols_{0}; |
48 | uint64 max_num_triplets_{0}; |
49 | bool built_{false}; |
50 | DataType dtype_{PrimitiveType::f32}; |
51 | std::string storage_format_{"col_major" }; |
52 | Program *prog_{nullptr}; |
53 | }; |
54 | |
55 | class SparseMatrix { |
56 | public: |
57 | SparseMatrix() : rows_(0), cols_(0), dtype_(PrimitiveType::f32){}; |
58 | SparseMatrix(int rows, int cols, DataType dt = PrimitiveType::f32) |
59 | : rows_{rows}, cols_(cols), dtype_(dt){}; |
60 | SparseMatrix(SparseMatrix &sm) |
61 | : rows_(sm.rows_), cols_(sm.cols_), dtype_(sm.dtype_) { |
62 | } |
63 | SparseMatrix(SparseMatrix &&sm) |
64 | : rows_(sm.rows_), cols_(sm.cols_), dtype_(sm.dtype_) { |
65 | } |
66 | virtual ~SparseMatrix() = default; |
67 | |
68 | virtual void build_triplets(void *triplets_adr) { |
69 | TI_NOT_IMPLEMENTED; |
70 | }; |
71 | |
72 | virtual void build_csr_from_coo(void *coo_row_ptr, |
73 | void *coo_col_ptr, |
74 | void *coo_values_ptr, |
75 | int nnz) { |
76 | TI_NOT_IMPLEMENTED; |
77 | } |
78 | inline const int num_rows() const { |
79 | return rows_; |
80 | } |
81 | |
82 | inline const int num_cols() const { |
83 | return cols_; |
84 | } |
85 | |
86 | virtual const std::string to_string() const { |
87 | return "" ; |
88 | } |
89 | |
90 | virtual const void *get_matrix() const { |
91 | return nullptr; |
92 | } |
93 | |
94 | inline const DataType get_data_type() const { |
95 | return dtype_; |
96 | } |
97 | |
98 | template <class T> |
99 | T get_element(int row, int col) { |
100 | std::cout << "get_element not implemented" << std::endl; |
101 | return 0; |
102 | } |
103 | |
104 | template <class T> |
105 | void set_element(int row, int col, T value) { |
106 | std::cout << "set_element not implemented" << std::endl; |
107 | return; |
108 | } |
109 | |
110 | protected: |
111 | int rows_{0}; |
112 | int cols_{0}; |
113 | DataType dtype_{PrimitiveType::f32}; |
114 | }; |
115 | |
116 | template <class EigenMatrix> |
117 | class EigenSparseMatrix : public SparseMatrix { |
118 | public: |
119 | explicit EigenSparseMatrix(int rows, int cols, DataType dt) |
120 | : SparseMatrix(rows, cols, dt), matrix_(rows, cols) { |
121 | } |
122 | EigenSparseMatrix(EigenSparseMatrix &sm) |
123 | : SparseMatrix(sm.num_rows(), sm.num_cols(), sm.dtype_), |
124 | matrix_(sm.matrix_) { |
125 | } |
126 | EigenSparseMatrix(EigenSparseMatrix &&sm) |
127 | : SparseMatrix(sm.num_rows(), sm.num_cols(), sm.dtype_), |
128 | matrix_(sm.matrix_) { |
129 | } |
130 | explicit EigenSparseMatrix(const EigenMatrix &em) |
131 | : SparseMatrix(em.rows(), em.cols()), matrix_(em) { |
132 | } |
133 | |
134 | ~EigenSparseMatrix() override = default; |
135 | |
136 | void build_triplets(void *triplets_adr) override; |
137 | const std::string to_string() const override; |
138 | |
139 | const void *get_matrix() const override { |
140 | return &matrix_; |
141 | }; |
142 | |
143 | virtual EigenSparseMatrix &operator+=(const EigenSparseMatrix &other) { |
144 | this->matrix_ += other.matrix_; |
145 | return *this; |
146 | }; |
147 | |
148 | friend EigenSparseMatrix operator+(const EigenSparseMatrix &lhs, |
149 | const EigenSparseMatrix &rhs) { |
150 | return EigenSparseMatrix(lhs.matrix_ + rhs.matrix_); |
151 | }; |
152 | |
153 | virtual EigenSparseMatrix &operator-=(const EigenSparseMatrix &other) { |
154 | this->matrix_ -= other.matrix_; |
155 | return *this; |
156 | } |
157 | |
158 | friend EigenSparseMatrix operator-(const EigenSparseMatrix &lhs, |
159 | const EigenSparseMatrix &rhs) { |
160 | return EigenSparseMatrix(lhs.matrix_ - rhs.matrix_); |
161 | }; |
162 | |
163 | virtual EigenSparseMatrix &operator*=(float scale) { |
164 | this->matrix_ *= scale; |
165 | return *this; |
166 | } |
167 | |
168 | friend EigenSparseMatrix operator*(const EigenSparseMatrix &sm, float scale) { |
169 | return EigenSparseMatrix(sm.matrix_ * scale); |
170 | } |
171 | |
172 | friend EigenSparseMatrix operator*(float scale, const EigenSparseMatrix &sm) { |
173 | return EigenSparseMatrix(sm.matrix_ * scale); |
174 | } |
175 | |
176 | friend EigenSparseMatrix operator*(const EigenSparseMatrix &lhs, |
177 | const EigenSparseMatrix &rhs) { |
178 | return EigenSparseMatrix(lhs.matrix_.cwiseProduct(rhs.matrix_)); |
179 | } |
180 | |
181 | EigenSparseMatrix transpose() { |
182 | return EigenSparseMatrix(matrix_.transpose()); |
183 | } |
184 | |
185 | EigenSparseMatrix matmul(const EigenSparseMatrix &sm) { |
186 | return EigenSparseMatrix(matrix_ * sm.matrix_); |
187 | } |
188 | |
189 | template <typename T> |
190 | T get_element(int row, int col) { |
191 | return matrix_.coeff(row, col); |
192 | } |
193 | |
194 | template <typename T> |
195 | void set_element(int row, int col, T value) { |
196 | matrix_.coeffRef(row, col) = value; |
197 | } |
198 | |
199 | template <class VT> |
200 | VT mat_vec_mul(const Eigen::Ref<const VT> &b) { |
201 | return matrix_ * b; |
202 | } |
203 | |
204 | void spmv(Program *prog, const Ndarray &x, const Ndarray &y); |
205 | |
206 | private: |
207 | EigenMatrix matrix_; |
208 | }; |
209 | |
210 | class CuSparseMatrix : public SparseMatrix { |
211 | public: |
212 | explicit CuSparseMatrix(int rows, int cols, DataType dt) |
213 | : SparseMatrix(rows, cols, dt) { |
214 | #if defined(TI_WITH_CUDA) |
215 | if (!CUSPARSEDriver::get_instance().is_loaded()) { |
216 | bool load_success = CUSPARSEDriver::get_instance().load_cusparse(); |
217 | if (!load_success) { |
218 | TI_ERROR("Failed to load cusparse library!" ); |
219 | } |
220 | } |
221 | #endif |
222 | } |
223 | explicit CuSparseMatrix(cusparseSpMatDescr_t A, |
224 | int rows, |
225 | int cols, |
226 | DataType dt, |
227 | void *csr_row_ptr, |
228 | void *csr_col_ind, |
229 | void *csr_val, |
230 | int nnz) |
231 | : SparseMatrix(rows, cols, dt), |
232 | matrix_(A), |
233 | csr_row_ptr_(csr_row_ptr), |
234 | csr_col_ind_(csr_col_ind), |
235 | csr_val_(csr_val), |
236 | nnz_(nnz) { |
237 | } |
238 | CuSparseMatrix(const CuSparseMatrix &sm) |
239 | : SparseMatrix(sm.rows_, sm.cols_, sm.dtype_), matrix_(sm.matrix_) { |
240 | } |
241 | |
242 | ~CuSparseMatrix() override; |
243 | |
244 | // TODO: Overload +=, -= and *= |
245 | friend std::unique_ptr<SparseMatrix> operator+(const CuSparseMatrix &lhs, |
246 | const CuSparseMatrix &rhs) { |
247 | auto m = lhs.addition(rhs, 1.0, 1.0); |
248 | return m; |
249 | }; |
250 | |
251 | friend std::unique_ptr<SparseMatrix> operator-(const CuSparseMatrix &lhs, |
252 | const CuSparseMatrix &rhs) { |
253 | return lhs.addition(rhs, 1.0, -1.0); |
254 | }; |
255 | |
256 | friend std::unique_ptr<SparseMatrix> operator*(const CuSparseMatrix &sm, |
257 | float scale) { |
258 | return sm.addition(sm, scale, 0.0); |
259 | } |
260 | |
261 | friend std::unique_ptr<SparseMatrix> operator*(float scale, |
262 | const CuSparseMatrix &sm) { |
263 | return sm.addition(sm, scale, 0.0); |
264 | } |
265 | |
266 | std::unique_ptr<SparseMatrix> addition(const CuSparseMatrix &other, |
267 | const float alpha, |
268 | const float beta) const; |
269 | |
270 | std::unique_ptr<SparseMatrix> matmul(const CuSparseMatrix &other) const; |
271 | |
272 | std::unique_ptr<SparseMatrix> gemm(const CuSparseMatrix &other, |
273 | const float alpha, |
274 | const float beta) const; |
275 | |
276 | std::unique_ptr<SparseMatrix> transpose() const; |
277 | |
278 | void build_csr_from_coo(void *coo_row_ptr, |
279 | void *coo_col_ptr, |
280 | void *coo_values_ptr, |
281 | int nnz) override; |
282 | |
283 | void spmv(Program *prog, const Ndarray &x, const Ndarray &y); |
284 | |
285 | const void *get_matrix() const override { |
286 | return &matrix_; |
287 | }; |
288 | |
289 | float get_element(int row, int col) const; |
290 | |
291 | const std::string to_string() const override; |
292 | |
293 | void *get_row_ptr() const { |
294 | return csr_row_ptr_; |
295 | } |
296 | void *get_col_ind() const { |
297 | return csr_col_ind_; |
298 | } |
299 | void *get_val_ptr() const { |
300 | return csr_val_; |
301 | } |
302 | int get_nnz() const { |
303 | return nnz_; |
304 | } |
305 | |
306 | private: |
307 | cusparseSpMatDescr_t matrix_{nullptr}; |
308 | void *csr_row_ptr_{nullptr}; |
309 | void *csr_col_ind_{nullptr}; |
310 | void *csr_val_{nullptr}; |
311 | int nnz_{0}; |
312 | }; |
313 | |
314 | std::unique_ptr<SparseMatrix> make_sparse_matrix( |
315 | int rows, |
316 | int cols, |
317 | DataType dt, |
318 | const std::string &storage_format); |
319 | std::unique_ptr<SparseMatrix> make_cu_sparse_matrix(int rows, |
320 | int cols, |
321 | DataType dt); |
322 | std::unique_ptr<SparseMatrix> make_cu_sparse_matrix(cusparseSpMatDescr_t mat, |
323 | int rows, |
324 | int cols, |
325 | DataType dt, |
326 | void *csr_row_ptr, |
327 | void *csr_col_ind, |
328 | void *csr_val_, |
329 | int nnz); |
330 | |
331 | void make_sparse_matrix_from_ndarray(Program *prog, |
332 | SparseMatrix &sm, |
333 | const Ndarray &ndarray); |
334 | } // namespace taichi::lang |
335 | |