1 | #pragma once |
2 | |
3 | #include "sparse_matrix.h" |
4 | |
5 | #include "taichi/ir/type.h" |
6 | #include "taichi/rhi/cuda/cuda_driver.h" |
7 | #include "taichi/program/program.h" |
8 | |
// Declares the alias EigenSparseSolver<dt><type><order> (for example
// EigenSparseSolverfloat32LLTAMD) for an EigenSparseSolver that wraps
// Eigen::Simplicial<type> (SimplicialLLT or SimplicialLDLT at the call sites)
// on the lower triangle (Eigen::Lower) of an Eigen::SparseMatrix<dt>, using
// Eigen::<order>Ordering<int> as the permutation ordering.
//
// The expansion deliberately has no trailing ';': the call site supplies it,
// so `DECLARE_EIGEN_LLT_SOLVER(float32, LLT, AMD);` yields exactly one
// declaration instead of a declaration followed by a stray empty declaration.
#define DECLARE_EIGEN_LLT_SOLVER(dt, type, order)                      \
  using EigenSparseSolver##dt##type##order = EigenSparseSolver<        \
      Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower,   \
                              Eigen::order##Ordering<int>>,            \
      Eigen::SparseMatrix<dt>>
15 | |
// Declares the alias EigenSparseSolver<dt><type><order> (for example
// EigenSparseSolverfloat64LUCOLAMD) for an EigenSparseSolver that wraps
// Eigen::Sparse<type> (SparseLU at the call sites) on an
// Eigen::SparseMatrix<dt>, using Eigen::<order>Ordering<int> as the
// permutation ordering.
//
// The expansion deliberately has no trailing ';': the call site supplies it,
// so `DECLARE_EIGEN_LU_SOLVER(float32, LU, AMD);` yields exactly one
// declaration instead of a declaration followed by a stray empty declaration.
#define DECLARE_EIGEN_LU_SOLVER(dt, type, order)                       \
  using EigenSparseSolver##dt##type##order = EigenSparseSolver<        \
      Eigen::Sparse##type<Eigen::SparseMatrix<dt>,                     \
                          Eigen::order##Ordering<int>>,                \
      Eigen::SparseMatrix<dt>>
21 | |
22 | namespace taichi::lang { |
23 | |
24 | class SparseSolver { |
25 | protected: |
26 | int rows_{0}; |
27 | int cols_{0}; |
28 | DataType dtype_{PrimitiveType::f32}; |
29 | bool is_initialized_{false}; |
30 | |
31 | public: |
32 | virtual ~SparseSolver() = default; |
33 | void init_solver(const int rows, const int cols, const DataType dtype) { |
34 | rows_ = rows; |
35 | cols_ = cols; |
36 | dtype_ = dtype; |
37 | } |
38 | virtual bool compute(const SparseMatrix &sm) = 0; |
39 | virtual void analyze_pattern(const SparseMatrix &sm) = 0; |
40 | virtual void factorize(const SparseMatrix &sm) = 0; |
41 | virtual bool info() = 0; |
42 | }; |
43 | |
44 | template <class EigenSolver, class EigenMatrix> |
45 | class EigenSparseSolver : public SparseSolver { |
46 | private: |
47 | EigenSolver solver_; |
48 | |
49 | public: |
50 | ~EigenSparseSolver() override = default; |
51 | bool compute(const SparseMatrix &sm) override; |
52 | void analyze_pattern(const SparseMatrix &sm) override; |
53 | void factorize(const SparseMatrix &sm) override; |
54 | template <typename T> |
55 | T solve(const T &b); |
56 | |
57 | template <typename T, typename V> |
58 | void solve_rf(Program *prog, |
59 | const SparseMatrix &sm, |
60 | const Ndarray &b, |
61 | const Ndarray &x); |
62 | bool info() override; |
63 | }; |
64 | |
65 | DECLARE_EIGEN_LLT_SOLVER(float32, LLT, AMD); |
66 | DECLARE_EIGEN_LLT_SOLVER(float32, LLT, COLAMD); |
67 | DECLARE_EIGEN_LLT_SOLVER(float32, LDLT, AMD); |
68 | DECLARE_EIGEN_LLT_SOLVER(float32, LDLT, COLAMD); |
69 | DECLARE_EIGEN_LU_SOLVER(float32, LU, AMD); |
70 | DECLARE_EIGEN_LU_SOLVER(float32, LU, COLAMD); |
71 | DECLARE_EIGEN_LLT_SOLVER(float64, LLT, AMD); |
72 | DECLARE_EIGEN_LLT_SOLVER(float64, LLT, COLAMD); |
73 | DECLARE_EIGEN_LLT_SOLVER(float64, LDLT, AMD); |
74 | DECLARE_EIGEN_LLT_SOLVER(float64, LDLT, COLAMD); |
75 | DECLARE_EIGEN_LU_SOLVER(float64, LU, AMD); |
76 | DECLARE_EIGEN_LU_SOLVER(float64, LU, COLAMD); |
77 | |
78 | class CuSparseSolver : public SparseSolver { |
79 | public: |
80 | enum class SolverType { Cholesky, LU }; |
81 | |
82 | private: |
83 | SolverType solver_type_{SolverType::Cholesky}; |
84 | csrcholInfo_t info_{nullptr}; |
85 | csrluInfoHost_t lu_info_{nullptr}; |
86 | cusolverSpHandle_t cusolver_handle_{nullptr}; |
87 | cusparseHandle_t cusparse_handel_{nullptr}; |
88 | cusparseMatDescr_t descr_{nullptr}; |
89 | void *gpu_buffer_{nullptr}; |
90 | void *cpu_buffer_{nullptr}; |
91 | bool is_analyzed_{false}; |
92 | bool is_factorized_{false}; |
93 | |
94 | int *h_Q_{ |
95 | nullptr}; /* <int> n, B = Q*A*Q' or B = A(Q,Q) by MATLAB notation */ |
96 | int *d_Q_{nullptr}; |
97 | int *h_csrRowPtrB_{nullptr}; /* <int> n+1 */ |
98 | int *h_csrColIndB_{nullptr}; /* <int> nnzA */ |
99 | float *h_csrValB_{nullptr}; /* <float> nnzA */ |
100 | int *h_mapBfromA_{nullptr}; /* <int> nnzA */ |
101 | int *d_csrRowPtrB_{nullptr}; /* <int> n+1 */ |
102 | int *d_csrColIndB_{nullptr}; /* <int> nnzA */ |
103 | float *d_csrValB_{nullptr}; /* <float> nnzA */ |
104 | public: |
105 | CuSparseSolver(); |
106 | explicit CuSparseSolver(SolverType solver_type) : solver_type_(solver_type) { |
107 | init_solver(); |
108 | } |
109 | ~CuSparseSolver() override; |
110 | bool compute(const SparseMatrix &sm) override { |
111 | TI_NOT_IMPLEMENTED; |
112 | }; |
113 | void analyze_pattern(const SparseMatrix &sm) override; |
114 | |
115 | void factorize(const SparseMatrix &sm) override; |
116 | void solve_rf(Program *prog, |
117 | const SparseMatrix &sm, |
118 | const Ndarray &b, |
119 | const Ndarray &x); |
120 | |
121 | bool info() override { |
122 | TI_NOT_IMPLEMENTED; |
123 | }; |
124 | |
125 | private: |
126 | void init_solver(); |
127 | void reorder(const CuSparseMatrix &sm); |
128 | void analyze_pattern_cholesky(const SparseMatrix &sm); |
129 | void analyze_pattern_lu(const SparseMatrix &sm); |
130 | void factorize_cholesky(const SparseMatrix &sm); |
131 | void factorize_lu(const SparseMatrix &sm); |
132 | void solve_cholesky(Program *prog, |
133 | const SparseMatrix &sm, |
134 | const Ndarray &b, |
135 | const Ndarray &x); |
136 | void solve_lu(Program *prog, |
137 | const SparseMatrix &sm, |
138 | const Ndarray &b, |
139 | const Ndarray &x); |
140 | }; |
141 | |
142 | std::unique_ptr<SparseSolver> make_sparse_solver(DataType dt, |
143 | const std::string &solver_type, |
144 | const std::string &ordering); |
145 | |
146 | std::unique_ptr<SparseSolver> make_cusparse_solver( |
147 | DataType dt, |
148 | const std::string &solver_type, |
149 | const std::string &ordering); |
150 | } // namespace taichi::lang |
151 | |