#include "taichi/ir/type_utils.h"

#include "sparse_solver.h"

#include <unordered_map>

namespace taichi::lang {
#define EIGEN_LLT_SOLVER_INSTANTIATION(dt, type, order)              \
  template class EigenSparseSolver<                                  \
      Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower, \
                              Eigen::order##Ordering<int>>,          \
      Eigen::SparseMatrix<dt>>;
#define EIGEN_LU_SOLVER_INSTANTIATION(dt, type, order)  \
  template class EigenSparseSolver<                     \
      Eigen::Sparse##type<Eigen::SparseMatrix<dt>,      \
                          Eigen::order##Ordering<int>>, \
      Eigen::SparseMatrix<dt>>;
// Explicit instantiation of EigenSparseSolver
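// For reference, a single invocation such as
// EIGEN_LLT_SOLVER_INSTANTIATION(float32, LLT, AMD) expands to roughly:
//
//   template class EigenSparseSolver<
//       Eigen::SimplicialLLT<Eigen::SparseMatrix<float32>, Eigen::Lower,
//                            Eigen::AMDOrdering<int>>,
//       Eigen::SparseMatrix<float32>>;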
EIGEN_LLT_SOLVER_INSTANTIATION(float32, LLT, AMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float32, LLT, COLAMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float32, LDLT, AMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float32, LDLT, COLAMD);
EIGEN_LU_SOLVER_INSTANTIATION(float32, LU, AMD);
EIGEN_LU_SOLVER_INSTANTIATION(float32, LU, COLAMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float64, LLT, AMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float64, LLT, COLAMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float64, LDLT, AMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float64, LDLT, COLAMD);
EIGEN_LU_SOLVER_INSTANTIATION(float64, LU, AMD);
EIGEN_LU_SOLVER_INSTANTIATION(float64, LU, COLAMD);
}  // namespace taichi::lang

// Explicit instantiation of EigenSparseSolver::solve
#define EIGEN_LLT_SOLVE_INSTANTIATION(dt, type, order, df)               \
  using T##dt = Eigen::VectorX##df;                                      \
  using S##dt##type##order =                                             \
      Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower,     \
                              Eigen::order##Ordering<int>>;              \
  template T##dt                                                         \
  EigenSparseSolver<S##dt##type##order, Eigen::SparseMatrix<dt>>::solve( \
      const T##dt &b);
#define EIGEN_LU_SOLVE_INSTANTIATION(dt, type, order, df)                  \
  using LUT##dt = Eigen::VectorX##df;                                      \
  using LUS##dt##type##order =                                             \
      Eigen::Sparse##type<Eigen::SparseMatrix<dt>,                         \
                          Eigen::order##Ordering<int>>;                    \
  template LUT##dt                                                         \
  EigenSparseSolver<LUS##dt##type##order, Eigen::SparseMatrix<dt>>::solve( \
      const LUT##dt &b);

// Explicit instantiation of EigenSparseSolver::solve_rf
#define INSTANTIATE_LLT_SOLVE_RF(dt, type, order, df)                     \
  using llt##dt##type##order =                                            \
      Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower,      \
                              Eigen::order##Ordering<int>>;               \
  template void EigenSparseSolver<llt##dt##type##order,                   \
                                  Eigen::SparseMatrix<dt>>::solve_rf<df,  \
                                                                     dt>( \
      Program * prog, const SparseMatrix &sm, const Ndarray &b,           \
      const Ndarray &x);

#define INSTANTIATE_LU_SOLVE_RF(dt, type, order, df)                     \
  using lu##dt##type##order =                                            \
      Eigen::Sparse##type<Eigen::SparseMatrix<dt>,                       \
                          Eigen::order##Ordering<int>>;                  \
  template void EigenSparseSolver<lu##dt##type##order,                   \
                                  Eigen::SparseMatrix<dt>>::solve_rf<df, \
                                                                    dt>( \
      Program * prog, const SparseMatrix &sm, const Ndarray &b,          \
      const Ndarray &x);

#define MAKE_EIGEN_SOLVER(dt, type, order) \
  std::make_unique<EigenSparseSolver##dt##type##order>()

#define MAKE_SOLVER(dt, type, order)                              \
  {                                                               \
    {#dt, #type, #order}, []() -> std::unique_ptr<SparseSolver> { \
      return MAKE_EIGEN_SOLVER(dt, type, order);                  \
    }                                                             \
  }
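
// MAKE_SOLVER(float32, LLT, AMD), for example, yields a map entry of the form
//   {{"float32", "LLT", "AMD"},
//    []() -> std::unique_ptr<SparseSolver> {
//      return std::make_unique<EigenSparseSolverfloat32LLTAMD>();
//    }}
// where EigenSparseSolverfloat32LLTAMD is presumably the alias declared in
// sparse_solver.h. These entries populate the factory table used by
// make_sparse_solver() below.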

using Triplets = std::tuple<std::string, std::string, std::string>;
namespace {
// Hash functor for the (data type, solver type, ordering) key of the solver
// factory table below; the three string hashes are combined with XOR.
struct key_hash {
  std::size_t operator()(const Triplets &k) const {
    auto h1 = std::hash<std::string>{}(std::get<0>(k));
    auto h2 = std::hash<std::string>{}(std::get<1>(k));
    auto h3 = std::hash<std::string>{}(std::get<2>(k));
    return h1 ^ h2 ^ h3;
  }
};
}  // namespace

namespace taichi::lang {

// Casts the type-erased matrix handle stored in a SparseMatrix back to the
// concrete Eigen matrix type and binds it to `mat` in the enclosing scope.
#define GET_EM(sm) \
  const EigenMatrix *mat = (const EigenMatrix *)(sm.get_matrix());

template <class EigenSolver, class EigenMatrix>
bool EigenSparseSolver<EigenSolver, EigenMatrix>::compute(
    const SparseMatrix &sm) {
  if (!is_initialized_) {
    SparseSolver::init_solver(sm.num_rows(), sm.num_cols(), sm.get_data_type());
  }
  GET_EM(sm);
  solver_.compute(*mat);
  return solver_.info() == Eigen::Success;
}
template <class EigenSolver, class EigenMatrix>
void EigenSparseSolver<EigenSolver, EigenMatrix>::analyze_pattern(
    const SparseMatrix &sm) {
  if (!is_initialized_) {
    SparseSolver::init_solver(sm.num_rows(), sm.num_cols(), sm.get_data_type());
  }
  GET_EM(sm);
  solver_.analyzePattern(*mat);
}

template <class EigenSolver, class EigenMatrix>
void EigenSparseSolver<EigenSolver, EigenMatrix>::factorize(
    const SparseMatrix &sm) {
  GET_EM(sm);
  solver_.factorize(*mat);
}

template <class EigenSolver, class EigenMatrix>
template <typename T>
T EigenSparseSolver<EigenSolver, EigenMatrix>::solve(const T &b) {
  return solver_.solve(b);
}

EIGEN_LLT_SOLVE_INSTANTIATION(float32, LLT, AMD, f);
EIGEN_LLT_SOLVE_INSTANTIATION(float32, LLT, COLAMD, f);
EIGEN_LLT_SOLVE_INSTANTIATION(float32, LDLT, AMD, f);
EIGEN_LLT_SOLVE_INSTANTIATION(float32, LDLT, COLAMD, f);
EIGEN_LU_SOLVE_INSTANTIATION(float32, LU, AMD, f);
EIGEN_LU_SOLVE_INSTANTIATION(float32, LU, COLAMD, f);
EIGEN_LLT_SOLVE_INSTANTIATION(float64, LLT, AMD, d);
EIGEN_LLT_SOLVE_INSTANTIATION(float64, LLT, COLAMD, d);
EIGEN_LLT_SOLVE_INSTANTIATION(float64, LDLT, AMD, d);
EIGEN_LLT_SOLVE_INSTANTIATION(float64, LDLT, COLAMD, d);
EIGEN_LU_SOLVE_INSTANTIATION(float64, LU, AMD, d);
EIGEN_LU_SOLVE_INSTANTIATION(float64, LU, COLAMD, d);

template <class EigenSolver, class EigenMatrix>
bool EigenSparseSolver<EigenSolver, EigenMatrix>::info() {
  return solver_.info() == Eigen::Success;
}

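// solve_rf() solves A x = b where b and x live in Taichi Ndarrays. The raw
// data pointers behind the ndarrays are reinterpreted as Eigen vectors via
// Eigen::Map, so the solve writes directly into x's buffer without an extra
// copy. T is the Eigen vector type (e.g. Eigen::VectorXf) and V is the
// matching scalar type.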
template <class EigenSolver, class EigenMatrix>
template <typename T, typename V>
void EigenSparseSolver<EigenSolver, EigenMatrix>::solve_rf(
    Program *prog,
    const SparseMatrix &sm,
    const Ndarray &b,
    const Ndarray &x) {
  size_t db = prog->get_ndarray_data_ptr_as_int(&b);
  size_t dX = prog->get_ndarray_data_ptr_as_int(&x);
  Eigen::Map<T>((V *)dX, rows_) = solver_.solve(Eigen::Map<T>((V *)db, cols_));
}

INSTANTIATE_LLT_SOLVE_RF(float32, LLT, COLAMD, Eigen::VectorXf)
INSTANTIATE_LLT_SOLVE_RF(float32, LDLT, COLAMD, Eigen::VectorXf)
INSTANTIATE_LLT_SOLVE_RF(float32, LLT, AMD, Eigen::VectorXf)
INSTANTIATE_LLT_SOLVE_RF(float32, LDLT, AMD, Eigen::VectorXf)
INSTANTIATE_LU_SOLVE_RF(float32, LU, AMD, Eigen::VectorXf)
INSTANTIATE_LU_SOLVE_RF(float32, LU, COLAMD, Eigen::VectorXf)
INSTANTIATE_LLT_SOLVE_RF(float64, LLT, COLAMD, Eigen::VectorXd)
INSTANTIATE_LLT_SOLVE_RF(float64, LDLT, COLAMD, Eigen::VectorXd)
INSTANTIATE_LLT_SOLVE_RF(float64, LLT, AMD, Eigen::VectorXd)
INSTANTIATE_LLT_SOLVE_RF(float64, LDLT, AMD, Eigen::VectorXd)
INSTANTIATE_LU_SOLVE_RF(float64, LU, AMD, Eigen::VectorXd)
INSTANTIATE_LU_SOLVE_RF(float64, LU, COLAMD, Eigen::VectorXd)

CuSparseSolver::CuSparseSolver() {
  init_solver();
}

void CuSparseSolver::init_solver() {
#if defined(TI_WITH_CUDA)
  if (!CUSPARSEDriver::get_instance().is_loaded()) {
    bool load_success = CUSPARSEDriver::get_instance().load_cusparse();
    if (!load_success) {
      TI_ERROR("Failed to load cusparse library!");
    }
  }
  if (!CUSOLVERDriver::get_instance().is_loaded()) {
    bool load_success = CUSOLVERDriver::get_instance().load_cusolver();
    if (!load_success) {
      TI_ERROR("Failed to load cusolver library!");
    }
  }
#endif
}
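
// reorder() applies a symmetric approximate-minimum-degree (symamd)
// permutation Q to A to reduce fill-in, i.e. it builds B = Q * A * Q^T in CSR
// form. The permutation is computed and applied on the host, then both Q and
// B are uploaded to the device for the subsequent factorization. Outline:
//   1. copy A's CSR arrays to the host;
//   2. compute h_Q_ = symamd(A)            (csSpXcsrsymamdHost);
//   3. permute A's CSR pattern by Q        (csSpXcsrpermHost), which also
//      yields h_mapBfromA_ so that valB[j] = valA[mapBfromA[j]];
//   4. upload Q and B (row ptr / col ind / val) to the GPU.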
void CuSparseSolver::reorder(const CuSparseMatrix &A) {
#if defined(TI_WITH_CUDA)
  size_t rowsA = A.num_rows();
  size_t colsA = A.num_cols();
  size_t nnzA = A.get_nnz();
  void *d_csrRowPtrA = A.get_row_ptr();
  void *d_csrColIndA = A.get_col_ind();
  void *d_csrValA = A.get_val_ptr();
  CUSOLVERDriver::get_instance().csSpCreate(&cusolver_handle_);
  CUSPARSEDriver::get_instance().cpCreate(&cusparse_handel_);
  CUSPARSEDriver::get_instance().cpCreateMatDescr(&descr_);
  CUSPARSEDriver::get_instance().cpSetMatType(descr_,
                                              CUSPARSE_MATRIX_TYPE_GENERAL);
  CUSPARSEDriver::get_instance().cpSetMatIndexBase(descr_,
                                                   CUSPARSE_INDEX_BASE_ZERO);
  float *h_csrValA = nullptr;
  h_Q_ = (int *)malloc(sizeof(int) * colsA);
  h_csrRowPtrB_ = (int *)malloc(sizeof(int) * (rowsA + 1));
  h_csrColIndB_ = (int *)malloc(sizeof(int) * nnzA);
  h_csrValB_ = (float *)malloc(sizeof(float) * nnzA);
  h_csrValA = (float *)malloc(sizeof(float) * nnzA);
  h_mapBfromA_ = (int *)malloc(sizeof(int) * nnzA);
  assert(nullptr != h_Q_);
  assert(nullptr != h_csrRowPtrB_);
  assert(nullptr != h_csrColIndB_);
  assert(nullptr != h_csrValB_);
  assert(nullptr != h_mapBfromA_);

  CUDADriver::get_instance().memcpy_device_to_host(h_csrRowPtrB_, d_csrRowPtrA,
                                                   sizeof(int) * (rowsA + 1));
  CUDADriver::get_instance().memcpy_device_to_host(h_csrColIndB_, d_csrColIndA,
                                                   sizeof(int) * nnzA);
  CUDADriver::get_instance().memcpy_device_to_host(h_csrValA, d_csrValA,
                                                   sizeof(float) * nnzA);

  // compute h_Q_ (the symamd permutation) on the host
  CUSOLVERDriver::get_instance().csSpXcsrsymamdHost(cusolver_handle_, rowsA,
                                                    nnzA, descr_, h_csrRowPtrB_,
                                                    h_csrColIndB_, h_Q_);
  CUDADriver::get_instance().malloc((void **)&d_Q_, sizeof(int) * colsA);
  CUDADriver::get_instance().memcpy_host_to_device((void *)d_Q_, (void *)h_Q_,
                                                   sizeof(int) * (colsA));
  size_t size_perm = 0;
  CUSOLVERDriver::get_instance().csSpXcsrperm_bufferSizeHost(
      cusolver_handle_, rowsA, colsA, nnzA, descr_, h_csrRowPtrB_,
      h_csrColIndB_, h_Q_, h_Q_, &size_perm);
  void *buffer_cpu = (void *)malloc(sizeof(char) * size_perm);
  assert(nullptr != buffer_cpu);
  for (int j = 0; j < nnzA; j++) {
    h_mapBfromA_[j] = j;
  }
  CUSOLVERDriver::get_instance().csSpXcsrpermHost(
      cusolver_handle_, rowsA, colsA, nnzA, descr_, h_csrRowPtrB_,
      h_csrColIndB_, h_Q_, h_Q_, h_mapBfromA_, buffer_cpu);
  // B = A( mapBfromA )
  for (int j = 0; j < nnzA; j++) {
    h_csrValB_[j] = h_csrValA[h_mapBfromA_[j]];
  }
  CUDADriver::get_instance().malloc((void **)&d_csrRowPtrB_,
                                    sizeof(int) * (rowsA + 1));
  CUDADriver::get_instance().malloc((void **)&d_csrColIndB_,
                                    sizeof(int) * nnzA);
  CUDADriver::get_instance().malloc((void **)&d_csrValB_, sizeof(float) * nnzA);
  CUDADriver::get_instance().memcpy_host_to_device(
      (void *)d_csrRowPtrB_, (void *)h_csrRowPtrB_, sizeof(int) * (rowsA + 1));
  CUDADriver::get_instance().memcpy_host_to_device(
      (void *)d_csrColIndB_, (void *)h_csrColIndB_, sizeof(int) * nnzA);
  CUDADriver::get_instance().memcpy_host_to_device(
      (void *)d_csrValB_, (void *)h_csrValB_, sizeof(float) * nnzA);
  free(h_csrValA);
  free(buffer_cpu);
#endif
}

// Reference:
// https://github.com/NVIDIA/cuda-samples/blob/master/Samples/4_CUDA_Libraries/cuSolverSp_LowlevelCholesky/cuSolverSp_LowlevelCholesky.cpp
void CuSparseSolver::analyze_pattern(const SparseMatrix &sm) {
  switch (solver_type_) {
    case SolverType::Cholesky:
      analyze_pattern_cholesky(sm);
      break;
    case SolverType::LU:
      analyze_pattern_lu(sm);
      break;
    default:
      TI_NOT_IMPLEMENTED
  }
}
void CuSparseSolver::analyze_pattern_cholesky(const SparseMatrix &sm) {
#if defined(TI_WITH_CUDA)
  // Retrieve the info of the sparse matrix
  SparseMatrix &sm_no_cv = const_cast<SparseMatrix &>(sm);
  CuSparseMatrix &A = static_cast<CuSparseMatrix &>(sm_no_cv);

  // step 1: reorder the sparse matrix
  reorder(A);

  // step 2: create opaque info structure
  CUSOLVERDriver::get_instance().csSpCreateCsrcholInfo(&info_);

  // step 3: analyze chol(A) to know structure of L
  size_t rowsA = A.num_rows();
  size_t nnzA = A.get_nnz();
  CUSOLVERDriver::get_instance().csSpXcsrcholAnalysis(
      cusolver_handle_, rowsA, nnzA, descr_, d_csrRowPtrB_, d_csrColIndB_,
      info_);
  is_analyzed_ = true;
#else
  TI_NOT_IMPLEMENTED
#endif
}
void CuSparseSolver::analyze_pattern_lu(const SparseMatrix &sm) {
#if defined(TI_WITH_CUDA)
  // Retrieve the info of the sparse matrix
  SparseMatrix &sm_no_cv = const_cast<SparseMatrix &>(sm);
  CuSparseMatrix &A = static_cast<CuSparseMatrix &>(sm_no_cv);

  // step 1: reorder the sparse matrix
  reorder(A);

  // step 2: create opaque info structure
  CUSOLVERDriver::get_instance().csSpCreateCsrluInfoHost(&lu_info_);

  // step 3: analyze LU(B) to know structure of Q and R, and upper bound for
  // nnz(L+U)
  size_t rowsA = A.num_rows();
  size_t nnzA = A.get_nnz();
  CUSOLVERDriver::get_instance().csSpXcsrluAnalysisHost(
      cusolver_handle_, rowsA, nnzA, descr_, h_csrRowPtrB_, h_csrColIndB_,
      lu_info_);
  is_analyzed_ = true;
#else
  TI_NOT_IMPLEMENTED
#endif
}
void CuSparseSolver::factorize(const SparseMatrix &sm) {
  switch (solver_type_) {
    case SolverType::Cholesky:
      factorize_cholesky(sm);
      break;
    case SolverType::LU:
      factorize_lu(sm);
      break;
    default:
      TI_NOT_IMPLEMENTED
  }
}
void CuSparseSolver::factorize_cholesky(const SparseMatrix &sm) {
#if defined(TI_WITH_CUDA)
  // Retrieve the info of the sparse matrix
  SparseMatrix *sm_no_cv = const_cast<SparseMatrix *>(&sm);
  CuSparseMatrix *A = static_cast<CuSparseMatrix *>(sm_no_cv);
  size_t rowsA = A->num_rows();
  size_t nnzA = A->get_nnz();

  size_t size_internal = 0;
  size_t size_chol = 0;  // size of the working space for chol(A)
  // step 1: workspace for chol(A)
  CUSOLVERDriver::get_instance().csSpScsrcholBufferInfo(
      cusolver_handle_, rowsA, nnzA, descr_, d_csrValB_, d_csrRowPtrB_,
      d_csrColIndB_, info_, &size_internal, &size_chol);

  if (size_chol > 0)
    CUDADriver::get_instance().malloc(&gpu_buffer_, sizeof(char) * size_chol);

  // step 2: compute A = L*L^T
  CUSOLVERDriver::get_instance().csSpScsrcholFactor(
      cusolver_handle_, rowsA, nnzA, descr_, d_csrValB_, d_csrRowPtrB_,
      d_csrColIndB_, info_, gpu_buffer_);
  // step 3: check if the matrix is singular
  const float tol = 1.e-14;
  int singularity = 0;
  CUSOLVERDriver::get_instance().csSpScsrcholZeroPivot(cusolver_handle_, info_,
                                                       tol, &singularity);
  // singularity == -1 means no pivot smaller than tol was found, i.e. the
  // factorization succeeded.
  TI_ASSERT(singularity == -1);
  is_factorized_ = true;
#else
  TI_NOT_IMPLEMENTED
#endif
}
void CuSparseSolver::factorize_lu(const SparseMatrix &sm) {
#if defined(TI_WITH_CUDA)
  // Retrieve the info of the sparse matrix
  SparseMatrix *sm_no_cv = const_cast<SparseMatrix *>(&sm);
  CuSparseMatrix *A = static_cast<CuSparseMatrix *>(sm_no_cv);
  size_t rowsA = A->num_rows();
  size_t nnzA = A->get_nnz();
  // step 4: workspace for LU(B)
  size_t size_lu = 0;
  size_t buffer_size = 0;
  CUSOLVERDriver::get_instance().csSpScsrluBufferInfoHost(
      cusolver_handle_, rowsA, nnzA, descr_, h_csrValB_, h_csrRowPtrB_,
      h_csrColIndB_, lu_info_, &buffer_size, &size_lu);

  if (cpu_buffer_)
    free(cpu_buffer_);
  cpu_buffer_ = (void *)malloc(sizeof(char) * size_lu);
  assert(nullptr != cpu_buffer_);

  // step 5: compute Ppivot * B = L * U
  CUSOLVERDriver::get_instance().csSpScsrluFactorHost(
      cusolver_handle_, rowsA, nnzA, descr_, h_csrValB_, h_csrRowPtrB_,
      h_csrColIndB_, lu_info_, 1.0f, cpu_buffer_);

  // step 6: check singularity by tol
  int singularity = 0;
  const float tol = 1.e-6;
  CUSOLVERDriver::get_instance().csSpScsrluZeroPivotHost(
      cusolver_handle_, lu_info_, tol, &singularity);
  // singularity == -1 means no pivot smaller than tol was found.
  TI_ASSERT(singularity == -1);
  is_factorized_ = true;
#else
  TI_NOT_IMPLEMENTED
#endif
}
void CuSparseSolver::solve_rf(Program *prog,
                              const SparseMatrix &sm,
                              const Ndarray &b,
                              const Ndarray &x) {
  switch (solver_type_) {
    case SolverType::Cholesky:
      solve_cholesky(prog, sm, b, x);
      break;
    case SolverType::LU:
      solve_lu(prog, sm, b, x);
      break;
    default:
      TI_NOT_IMPLEMENTED
  }
}

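// With B = Q * A * Q^T already factorized, A x = b is solved as
//   B * (Q x) = Q b
// i.e. gather b through Q, run the Cholesky back-substitution on the permuted
// system, then scatter the permuted solution back through Q into x. All three
// steps below stay on the device.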
void CuSparseSolver::solve_cholesky(Program *prog,
                                    const SparseMatrix &sm,
                                    const Ndarray &b,
                                    const Ndarray &x) {
#if defined(TI_WITH_CUDA)
  if (!is_analyzed_) {
    analyze_pattern(sm);
  }
  if (!is_factorized_) {
    factorize(sm);
  }
  // Retrieve the info of the sparse matrix
  SparseMatrix *sm_no_cv = const_cast<SparseMatrix *>(&sm);
  CuSparseMatrix *A = static_cast<CuSparseMatrix *>(sm_no_cv);
  size_t rowsA = A->num_rows();
  size_t colsA = A->num_cols();
  size_t d_b = prog->get_ndarray_data_ptr_as_int(&b);
  size_t d_x = prog->get_ndarray_data_ptr_as_int(&x);

  // step 1: d_Qb = Q * b
  void *d_Qb = nullptr;
  CUDADriver::get_instance().malloc(&d_Qb, sizeof(float) * rowsA);
  cusparseDnVecDescr_t vec_b;
  cusparseSpVecDescr_t vec_Qb;
  CUSPARSEDriver::get_instance().cpCreateDnVec(&vec_b, (int)rowsA, (void *)d_b,
                                               CUDA_R_32F);
  CUSPARSEDriver::get_instance().cpCreateSpVec(
      &vec_Qb, (int)rowsA, (int)rowsA, d_Q_, d_Qb, CUSPARSE_INDEX_32I,
      CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F);
  CUSPARSEDriver::get_instance().cpGather(cusparse_handel_, vec_b, vec_Qb);

  // step 2: solve B*z = Q*b using the Cholesky factorization
  void *d_z = nullptr;
  CUDADriver::get_instance().malloc(&d_z, sizeof(float) * colsA);
  CUSOLVERDriver::get_instance().csSpScsrcholSolve(
      cusolver_handle_, rowsA, (void *)d_Qb, (void *)d_z, info_, gpu_buffer_);

  // step 3: Q*x = z
  cusparseSpVecDescr_t vecX;
  cusparseDnVecDescr_t vecY;
  CUSPARSEDriver::get_instance().cpCreateSpVec(
      &vecX, (int)colsA, (int)colsA, d_Q_, d_z, CUSPARSE_INDEX_32I,
      CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F);
  CUSPARSEDriver::get_instance().cpCreateDnVec(&vecY, (int)colsA, (void *)d_x,
                                               CUDA_R_32F);
  CUSPARSEDriver::get_instance().cpScatter(cusparse_handel_, vecX, vecY);

  if (d_Qb != nullptr)
    CUDADriver::get_instance().mem_free(d_Qb);
  if (d_z != nullptr)
    CUDADriver::get_instance().mem_free(d_z);
  CUSPARSEDriver::get_instance().cpDestroySpVec(vec_Qb);
  CUSPARSEDriver::get_instance().cpDestroyDnVec(vec_b);
  CUSPARSEDriver::get_instance().cpDestroySpVec(vecX);
  CUSPARSEDriver::get_instance().cpDestroyDnVec(vecY);
#else
  TI_NOT_IMPLEMENTED
#endif
}

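// The LU path mirrors the Cholesky one, but cuSOLVER's csrlu* routines are
// host-side: b and x are copied to the host, permuted by Q, solved with
// csSpScsrluSolveHost, and the result is copied back into x's device buffer.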
void CuSparseSolver::solve_lu(Program *prog,
                              const SparseMatrix &sm,
                              const Ndarray &b,
                              const Ndarray &x) {
#if defined(TI_WITH_CUDA)
  if (!is_analyzed_) {
    analyze_pattern(sm);
  }
  if (!is_factorized_) {
    factorize(sm);
  }

  // Retrieve the info of the sparse matrix
  SparseMatrix *sm_no_cv = const_cast<SparseMatrix *>(&sm);
  CuSparseMatrix *A = static_cast<CuSparseMatrix *>(sm_no_cv);
  size_t rowsA = A->num_rows();
  size_t colsA = A->num_cols();

  // step 7: solve L*U*x = b
  size_t d_b = prog->get_ndarray_data_ptr_as_int(&b);
  size_t d_x = prog->get_ndarray_data_ptr_as_int(&x);
  float *h_b = (float *)malloc(sizeof(float) * rowsA);
  float *h_b_hat = (float *)malloc(sizeof(float) * rowsA);
  float *h_x = (float *)malloc(sizeof(float) * colsA);
  float *h_x_hat = (float *)malloc(sizeof(float) * colsA);
  assert(nullptr != h_b);
  assert(nullptr != h_b_hat);
  assert(nullptr != h_x);
  assert(nullptr != h_x_hat);
  CUDADriver::get_instance().memcpy_device_to_host((void *)h_b, (void *)d_b,
                                                   sizeof(float) * rowsA);
  CUDADriver::get_instance().memcpy_device_to_host((void *)h_x, (void *)d_x,
                                                   sizeof(float) * colsA);
  for (int j = 0; j < rowsA; j++) {
    h_b_hat[j] = h_b[h_Q_[j]];
  }
  CUSOLVERDriver::get_instance().csSpScsrluSolveHost(
      cusolver_handle_, rowsA, h_b_hat, h_x_hat, lu_info_, cpu_buffer_);
  for (int j = 0; j < colsA; j++) {
    h_x[h_Q_[j]] = h_x_hat[j];
  }
  CUDADriver::get_instance().memcpy_host_to_device((void *)d_x, (void *)h_x,
                                                   sizeof(float) * colsA);

  free(h_b);
  free(h_b_hat);
  free(h_x);
  free(h_x_hat);
#else
  TI_NOT_IMPLEMENTED
#endif
}

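// Factory for the Eigen-based solvers. A hypothetical call site (assuming
// PrimitiveType::f32 is the f32 data type handle) might look like:
//
//   auto solver = make_sparse_solver(PrimitiveType::f32, "LLT", "AMD");
//   solver->analyze_pattern(A);
//   solver->factorize(A);
//
// Note that "LU" falls back to Eigen::SparseLU with its default ordering, so
// the `ordering` argument is effectively ignored for LU.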
std::unique_ptr<SparseSolver> make_sparse_solver(DataType dt,
                                                 const std::string &solver_type,
                                                 const std::string &ordering) {
  using key_type = Triplets;
  using func_type = std::unique_ptr<SparseSolver> (*)();
  static const std::unordered_map<key_type, func_type, key_hash>
      solver_factory = {
          MAKE_SOLVER(float32, LLT, AMD), MAKE_SOLVER(float32, LLT, COLAMD),
          MAKE_SOLVER(float32, LDLT, AMD), MAKE_SOLVER(float32, LDLT, COLAMD),
          MAKE_SOLVER(float64, LLT, AMD), MAKE_SOLVER(float64, LLT, COLAMD),
          MAKE_SOLVER(float64, LDLT, AMD), MAKE_SOLVER(float64, LDLT, COLAMD)};
  static const std::unordered_map<std::string, std::string> dt_map = {
      {"f32", "float32"}, {"f64", "float64"}};
  auto it = dt_map.find(taichi::lang::data_type_name(dt));
  if (it == dt_map.end())
    TI_ERROR("Unsupported sparse solver data type: {}",
             taichi::lang::data_type_name(dt));

  Triplets solver_key = std::make_tuple(it->second, solver_type, ordering);
  if (solver_factory.find(solver_key) != solver_factory.end()) {
    auto solver_func = solver_factory.at(solver_key);
    return solver_func();
  } else if (solver_type == "LU") {
    if (it->first == "f32") {
      using EigenMatrix = Eigen::SparseMatrix<float32>;
      using LU = Eigen::SparseLU<EigenMatrix>;
      return std::make_unique<EigenSparseSolver<LU, EigenMatrix>>();
    } else if (it->first == "f64") {
      using EigenMatrix = Eigen::SparseMatrix<float64>;
      using LU = Eigen::SparseLU<EigenMatrix>;
      return std::make_unique<EigenSparseSolver<LU, EigenMatrix>>();
    } else {
      TI_ERROR("Unsupported sparse solver data type: {}", it->second);
    }
  } else {
    TI_ERROR("Unsupported sparse solver type: {}", solver_type);
  }
}

CuSparseSolver::~CuSparseSolver() {
#if defined(TI_WITH_CUDA)
  if (h_Q_ != nullptr)
    free(h_Q_);
  if (h_csrRowPtrB_ != nullptr)
    free(h_csrRowPtrB_);
  if (h_csrColIndB_ != nullptr)
    free(h_csrColIndB_);
  if (h_csrValB_ != nullptr)
    free(h_csrValB_);
  if (h_mapBfromA_ != nullptr)
    free(h_mapBfromA_);
  if (cpu_buffer_ != nullptr)
    free(cpu_buffer_);
  if (info_ != nullptr)
    CUSOLVERDriver::get_instance().csSpDestroyCsrcholInfo(info_);
  if (lu_info_ != nullptr)
    CUSOLVERDriver::get_instance().csSpDestroyCsrluInfoHost(lu_info_);
  if (cusolver_handle_ != nullptr)
    CUSOLVERDriver::get_instance().csSpDestory(cusolver_handle_);
  if (cusparse_handel_ != nullptr)
    CUSPARSEDriver::get_instance().cpDestroy(cusparse_handel_);
  if (descr_ != nullptr)
    CUSPARSEDriver::get_instance().cpDestroyMatDescr(descr_);
  if (gpu_buffer_ != nullptr)
    CUDADriver::get_instance().mem_free(gpu_buffer_);
  if (d_Q_ != nullptr)
    CUDADriver::get_instance().mem_free(d_Q_);
  if (d_csrRowPtrB_ != nullptr)
    CUDADriver::get_instance().mem_free(d_csrRowPtrB_);
  if (d_csrColIndB_ != nullptr)
    CUDADriver::get_instance().mem_free(d_csrColIndB_);
  if (d_csrValB_ != nullptr)
    CUDADriver::get_instance().mem_free(d_csrValB_);
#endif
}
std::unique_ptr<SparseSolver> make_cusparse_solver(
    DataType dt,
    const std::string &solver_type,
    const std::string &ordering) {
  if (solver_type == "LLT" || solver_type == "LDLT") {
    return std::make_unique<CuSparseSolver>(
        CuSparseSolver::SolverType::Cholesky);
  } else if (solver_type == "LU") {
    return std::make_unique<CuSparseSolver>(CuSparseSolver::SolverType::LU);
  } else {
    TI_ERROR("Unsupported sparse solver type: {}", solver_type);
  }
}
}  // namespace taichi::lang