1#pragma once
2
3#include "sparse_matrix.h"
4
5#include "taichi/ir/type.h"
6#include "taichi/rhi/cuda/cuda_driver.h"
7#include "taichi/program/program.h"
8
// Declares the alias EigenSparseSolver<dt><type><order>
// (e.g. EigenSparseSolverfloat32LLTAMD) for an EigenSparseSolver wrapping
// Eigen::Simplicial<type> (type is LLT or LDLT) over Eigen::SparseMatrix<dt>,
// using the Eigen::Lower triangle and Eigen::<order>Ordering<int>.
// The expansion already terminates the declaration with ';'.
#define DECLARE_EIGEN_LLT_SOLVER(dt, type, order)                     \
  using EigenSparseSolver##dt##type##order = EigenSparseSolver<      \
      Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower, \
                              Eigen::order##Ordering<int>>,          \
      Eigen::SparseMatrix<dt>>;
15
// Declares the alias EigenSparseSolver<dt><type><order>
// (e.g. EigenSparseSolverfloat64LUCOLAMD) for an EigenSparseSolver wrapping
// Eigen::Sparse<type> (type is LU) over Eigen::SparseMatrix<dt> with
// Eigen::<order>Ordering<int>.
// The expansion already terminates the declaration with ';'.
#define DECLARE_EIGEN_LU_SOLVER(dt, type, order)                     \
  using EigenSparseSolver##dt##type##order = EigenSparseSolver<     \
      Eigen::Sparse##type<Eigen::SparseMatrix<dt>,                  \
                          Eigen::order##Ordering<int>>,             \
      Eigen::SparseMatrix<dt>>;
21
22namespace taichi::lang {
23
24class SparseSolver {
25 protected:
26 int rows_{0};
27 int cols_{0};
28 DataType dtype_{PrimitiveType::f32};
29 bool is_initialized_{false};
30
31 public:
32 virtual ~SparseSolver() = default;
33 void init_solver(const int rows, const int cols, const DataType dtype) {
34 rows_ = rows;
35 cols_ = cols;
36 dtype_ = dtype;
37 }
38 virtual bool compute(const SparseMatrix &sm) = 0;
39 virtual void analyze_pattern(const SparseMatrix &sm) = 0;
40 virtual void factorize(const SparseMatrix &sm) = 0;
41 virtual bool info() = 0;
42};
43
44template <class EigenSolver, class EigenMatrix>
45class EigenSparseSolver : public SparseSolver {
46 private:
47 EigenSolver solver_;
48
49 public:
50 ~EigenSparseSolver() override = default;
51 bool compute(const SparseMatrix &sm) override;
52 void analyze_pattern(const SparseMatrix &sm) override;
53 void factorize(const SparseMatrix &sm) override;
54 template <typename T>
55 T solve(const T &b);
56
57 template <typename T, typename V>
58 void solve_rf(Program *prog,
59 const SparseMatrix &sm,
60 const Ndarray &b,
61 const Ndarray &x);
62 bool info() override;
63};
64
65DECLARE_EIGEN_LLT_SOLVER(float32, LLT, AMD);
66DECLARE_EIGEN_LLT_SOLVER(float32, LLT, COLAMD);
67DECLARE_EIGEN_LLT_SOLVER(float32, LDLT, AMD);
68DECLARE_EIGEN_LLT_SOLVER(float32, LDLT, COLAMD);
69DECLARE_EIGEN_LU_SOLVER(float32, LU, AMD);
70DECLARE_EIGEN_LU_SOLVER(float32, LU, COLAMD);
71DECLARE_EIGEN_LLT_SOLVER(float64, LLT, AMD);
72DECLARE_EIGEN_LLT_SOLVER(float64, LLT, COLAMD);
73DECLARE_EIGEN_LLT_SOLVER(float64, LDLT, AMD);
74DECLARE_EIGEN_LLT_SOLVER(float64, LDLT, COLAMD);
75DECLARE_EIGEN_LU_SOLVER(float64, LU, AMD);
76DECLARE_EIGEN_LU_SOLVER(float64, LU, COLAMD);
77
78class CuSparseSolver : public SparseSolver {
79 public:
80 enum class SolverType { Cholesky, LU };
81
82 private:
83 SolverType solver_type_{SolverType::Cholesky};
84 csrcholInfo_t info_{nullptr};
85 csrluInfoHost_t lu_info_{nullptr};
86 cusolverSpHandle_t cusolver_handle_{nullptr};
87 cusparseHandle_t cusparse_handel_{nullptr};
88 cusparseMatDescr_t descr_{nullptr};
89 void *gpu_buffer_{nullptr};
90 void *cpu_buffer_{nullptr};
91 bool is_analyzed_{false};
92 bool is_factorized_{false};
93
94 int *h_Q_{
95 nullptr}; /* <int> n, B = Q*A*Q' or B = A(Q,Q) by MATLAB notation */
96 int *d_Q_{nullptr};
97 int *h_csrRowPtrB_{nullptr}; /* <int> n+1 */
98 int *h_csrColIndB_{nullptr}; /* <int> nnzA */
99 float *h_csrValB_{nullptr}; /* <float> nnzA */
100 int *h_mapBfromA_{nullptr}; /* <int> nnzA */
101 int *d_csrRowPtrB_{nullptr}; /* <int> n+1 */
102 int *d_csrColIndB_{nullptr}; /* <int> nnzA */
103 float *d_csrValB_{nullptr}; /* <float> nnzA */
104 public:
105 CuSparseSolver();
106 explicit CuSparseSolver(SolverType solver_type) : solver_type_(solver_type) {
107 init_solver();
108 }
109 ~CuSparseSolver() override;
110 bool compute(const SparseMatrix &sm) override {
111 TI_NOT_IMPLEMENTED;
112 };
113 void analyze_pattern(const SparseMatrix &sm) override;
114
115 void factorize(const SparseMatrix &sm) override;
116 void solve_rf(Program *prog,
117 const SparseMatrix &sm,
118 const Ndarray &b,
119 const Ndarray &x);
120
121 bool info() override {
122 TI_NOT_IMPLEMENTED;
123 };
124
125 private:
126 void init_solver();
127 void reorder(const CuSparseMatrix &sm);
128 void analyze_pattern_cholesky(const SparseMatrix &sm);
129 void analyze_pattern_lu(const SparseMatrix &sm);
130 void factorize_cholesky(const SparseMatrix &sm);
131 void factorize_lu(const SparseMatrix &sm);
132 void solve_cholesky(Program *prog,
133 const SparseMatrix &sm,
134 const Ndarray &b,
135 const Ndarray &x);
136 void solve_lu(Program *prog,
137 const SparseMatrix &sm,
138 const Ndarray &b,
139 const Ndarray &x);
140};
141
142std::unique_ptr<SparseSolver> make_sparse_solver(DataType dt,
143 const std::string &solver_type,
144 const std::string &ordering);
145
146std::unique_ptr<SparseSolver> make_cusparse_solver(
147 DataType dt,
148 const std::string &solver_type,
149 const std::string &ordering);
150} // namespace taichi::lang
151