1#pragma once
2
#include <utility>

#include "taichi/common/core.h"
#include "taichi/inc/constants.h"
#include "taichi/ir/type_utils.h"
#include "taichi/program/ndarray.h"
#include "taichi/program/program.h"
#include "taichi/rhi/cuda/cuda_driver.h"

#include "Eigen/Sparse"
11
12namespace taichi::lang {
13
14class SparseMatrix;
15
16class SparseMatrixBuilder {
17 public:
18 SparseMatrixBuilder(int rows,
19 int cols,
20 int max_num_triplets,
21 DataType dtype,
22 const std::string &storage_format,
23 Program *prog);
24
25 void print_triplets_eigen();
26 void print_triplets_cuda();
27
28 intptr_t get_ndarray_data_ptr() const;
29
30 std::unique_ptr<SparseMatrix> build();
31
32 std::unique_ptr<SparseMatrix> build_cuda();
33
34 void clear();
35
36 private:
37 template <typename T, typename G>
38 void build_template(std::unique_ptr<SparseMatrix> &);
39
40 template <typename T, typename G>
41 void print_triplets_template();
42
43 private:
44 uint64 num_triplets_{0};
45 std::unique_ptr<Ndarray> ndarray_data_base_ptr_{nullptr};
46 int rows_{0};
47 int cols_{0};
48 uint64 max_num_triplets_{0};
49 bool built_{false};
50 DataType dtype_{PrimitiveType::f32};
51 std::string storage_format_{"col_major"};
52 Program *prog_{nullptr};
53};
54
55class SparseMatrix {
56 public:
57 SparseMatrix() : rows_(0), cols_(0), dtype_(PrimitiveType::f32){};
58 SparseMatrix(int rows, int cols, DataType dt = PrimitiveType::f32)
59 : rows_{rows}, cols_(cols), dtype_(dt){};
60 SparseMatrix(SparseMatrix &sm)
61 : rows_(sm.rows_), cols_(sm.cols_), dtype_(sm.dtype_) {
62 }
63 SparseMatrix(SparseMatrix &&sm)
64 : rows_(sm.rows_), cols_(sm.cols_), dtype_(sm.dtype_) {
65 }
66 virtual ~SparseMatrix() = default;
67
68 virtual void build_triplets(void *triplets_adr) {
69 TI_NOT_IMPLEMENTED;
70 };
71
72 virtual void build_csr_from_coo(void *coo_row_ptr,
73 void *coo_col_ptr,
74 void *coo_values_ptr,
75 int nnz) {
76 TI_NOT_IMPLEMENTED;
77 }
78 inline const int num_rows() const {
79 return rows_;
80 }
81
82 inline const int num_cols() const {
83 return cols_;
84 }
85
86 virtual const std::string to_string() const {
87 return "";
88 }
89
90 virtual const void *get_matrix() const {
91 return nullptr;
92 }
93
94 inline const DataType get_data_type() const {
95 return dtype_;
96 }
97
98 template <class T>
99 T get_element(int row, int col) {
100 std::cout << "get_element not implemented" << std::endl;
101 return 0;
102 }
103
104 template <class T>
105 void set_element(int row, int col, T value) {
106 std::cout << "set_element not implemented" << std::endl;
107 return;
108 }
109
110 protected:
111 int rows_{0};
112 int cols_{0};
113 DataType dtype_{PrimitiveType::f32};
114};
115
116template <class EigenMatrix>
117class EigenSparseMatrix : public SparseMatrix {
118 public:
119 explicit EigenSparseMatrix(int rows, int cols, DataType dt)
120 : SparseMatrix(rows, cols, dt), matrix_(rows, cols) {
121 }
122 EigenSparseMatrix(EigenSparseMatrix &sm)
123 : SparseMatrix(sm.num_rows(), sm.num_cols(), sm.dtype_),
124 matrix_(sm.matrix_) {
125 }
126 EigenSparseMatrix(EigenSparseMatrix &&sm)
127 : SparseMatrix(sm.num_rows(), sm.num_cols(), sm.dtype_),
128 matrix_(sm.matrix_) {
129 }
130 explicit EigenSparseMatrix(const EigenMatrix &em)
131 : SparseMatrix(em.rows(), em.cols()), matrix_(em) {
132 }
133
134 ~EigenSparseMatrix() override = default;
135
136 void build_triplets(void *triplets_adr) override;
137 const std::string to_string() const override;
138
139 const void *get_matrix() const override {
140 return &matrix_;
141 };
142
143 virtual EigenSparseMatrix &operator+=(const EigenSparseMatrix &other) {
144 this->matrix_ += other.matrix_;
145 return *this;
146 };
147
148 friend EigenSparseMatrix operator+(const EigenSparseMatrix &lhs,
149 const EigenSparseMatrix &rhs) {
150 return EigenSparseMatrix(lhs.matrix_ + rhs.matrix_);
151 };
152
153 virtual EigenSparseMatrix &operator-=(const EigenSparseMatrix &other) {
154 this->matrix_ -= other.matrix_;
155 return *this;
156 }
157
158 friend EigenSparseMatrix operator-(const EigenSparseMatrix &lhs,
159 const EigenSparseMatrix &rhs) {
160 return EigenSparseMatrix(lhs.matrix_ - rhs.matrix_);
161 };
162
163 virtual EigenSparseMatrix &operator*=(float scale) {
164 this->matrix_ *= scale;
165 return *this;
166 }
167
168 friend EigenSparseMatrix operator*(const EigenSparseMatrix &sm, float scale) {
169 return EigenSparseMatrix(sm.matrix_ * scale);
170 }
171
172 friend EigenSparseMatrix operator*(float scale, const EigenSparseMatrix &sm) {
173 return EigenSparseMatrix(sm.matrix_ * scale);
174 }
175
176 friend EigenSparseMatrix operator*(const EigenSparseMatrix &lhs,
177 const EigenSparseMatrix &rhs) {
178 return EigenSparseMatrix(lhs.matrix_.cwiseProduct(rhs.matrix_));
179 }
180
181 EigenSparseMatrix transpose() {
182 return EigenSparseMatrix(matrix_.transpose());
183 }
184
185 EigenSparseMatrix matmul(const EigenSparseMatrix &sm) {
186 return EigenSparseMatrix(matrix_ * sm.matrix_);
187 }
188
189 template <typename T>
190 T get_element(int row, int col) {
191 return matrix_.coeff(row, col);
192 }
193
194 template <typename T>
195 void set_element(int row, int col, T value) {
196 matrix_.coeffRef(row, col) = value;
197 }
198
199 template <class VT>
200 VT mat_vec_mul(const Eigen::Ref<const VT> &b) {
201 return matrix_ * b;
202 }
203
204 void spmv(Program *prog, const Ndarray &x, const Ndarray &y);
205
206 private:
207 EigenMatrix matrix_;
208};
209
210class CuSparseMatrix : public SparseMatrix {
211 public:
212 explicit CuSparseMatrix(int rows, int cols, DataType dt)
213 : SparseMatrix(rows, cols, dt) {
214#if defined(TI_WITH_CUDA)
215 if (!CUSPARSEDriver::get_instance().is_loaded()) {
216 bool load_success = CUSPARSEDriver::get_instance().load_cusparse();
217 if (!load_success) {
218 TI_ERROR("Failed to load cusparse library!");
219 }
220 }
221#endif
222 }
223 explicit CuSparseMatrix(cusparseSpMatDescr_t A,
224 int rows,
225 int cols,
226 DataType dt,
227 void *csr_row_ptr,
228 void *csr_col_ind,
229 void *csr_val,
230 int nnz)
231 : SparseMatrix(rows, cols, dt),
232 matrix_(A),
233 csr_row_ptr_(csr_row_ptr),
234 csr_col_ind_(csr_col_ind),
235 csr_val_(csr_val),
236 nnz_(nnz) {
237 }
238 CuSparseMatrix(const CuSparseMatrix &sm)
239 : SparseMatrix(sm.rows_, sm.cols_, sm.dtype_), matrix_(sm.matrix_) {
240 }
241
242 ~CuSparseMatrix() override;
243
244 // TODO: Overload +=, -= and *=
245 friend std::unique_ptr<SparseMatrix> operator+(const CuSparseMatrix &lhs,
246 const CuSparseMatrix &rhs) {
247 auto m = lhs.addition(rhs, 1.0, 1.0);
248 return m;
249 };
250
251 friend std::unique_ptr<SparseMatrix> operator-(const CuSparseMatrix &lhs,
252 const CuSparseMatrix &rhs) {
253 return lhs.addition(rhs, 1.0, -1.0);
254 };
255
256 friend std::unique_ptr<SparseMatrix> operator*(const CuSparseMatrix &sm,
257 float scale) {
258 return sm.addition(sm, scale, 0.0);
259 }
260
261 friend std::unique_ptr<SparseMatrix> operator*(float scale,
262 const CuSparseMatrix &sm) {
263 return sm.addition(sm, scale, 0.0);
264 }
265
266 std::unique_ptr<SparseMatrix> addition(const CuSparseMatrix &other,
267 const float alpha,
268 const float beta) const;
269
270 std::unique_ptr<SparseMatrix> matmul(const CuSparseMatrix &other) const;
271
272 std::unique_ptr<SparseMatrix> gemm(const CuSparseMatrix &other,
273 const float alpha,
274 const float beta) const;
275
276 std::unique_ptr<SparseMatrix> transpose() const;
277
278 void build_csr_from_coo(void *coo_row_ptr,
279 void *coo_col_ptr,
280 void *coo_values_ptr,
281 int nnz) override;
282
283 void spmv(Program *prog, const Ndarray &x, const Ndarray &y);
284
285 const void *get_matrix() const override {
286 return &matrix_;
287 };
288
289 float get_element(int row, int col) const;
290
291 const std::string to_string() const override;
292
293 void *get_row_ptr() const {
294 return csr_row_ptr_;
295 }
296 void *get_col_ind() const {
297 return csr_col_ind_;
298 }
299 void *get_val_ptr() const {
300 return csr_val_;
301 }
302 int get_nnz() const {
303 return nnz_;
304 }
305
306 private:
307 cusparseSpMatDescr_t matrix_{nullptr};
308 void *csr_row_ptr_{nullptr};
309 void *csr_col_ind_{nullptr};
310 void *csr_val_{nullptr};
311 int nnz_{0};
312};
313
314std::unique_ptr<SparseMatrix> make_sparse_matrix(
315 int rows,
316 int cols,
317 DataType dt,
318 const std::string &storage_format);
319std::unique_ptr<SparseMatrix> make_cu_sparse_matrix(int rows,
320 int cols,
321 DataType dt);
322std::unique_ptr<SparseMatrix> make_cu_sparse_matrix(cusparseSpMatDescr_t mat,
323 int rows,
324 int cols,
325 DataType dt,
326 void *csr_row_ptr,
327 void *csr_col_ind,
328 void *csr_val_,
329 int nnz);
330
331void make_sparse_matrix_from_ndarray(Program *prog,
332 SparseMatrix &sm,
333 const Ndarray &ndarray);
334} // namespace taichi::lang
335