1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmSparse.h" |
9 | |
10 | #include <algorithm> |
11 | #include <array> |
12 | #include <cassert> |
13 | #include <cstdio> |
14 | #include <cstring> |
15 | #include <memory> |
16 | #include <sstream> |
17 | #include <vector> |
18 | |
19 | #include "fbgemm/Utils.h" |
20 | #include "fbgemm/spmmUtils.h" |
21 | |
22 | using namespace std; |
23 | |
24 | namespace fbgemm { |
25 | |
26 | template <typename T> |
27 | FBGEMM_API std::unique_ptr<CSRMatrix<T>> |
28 | fbgemmDenseToCSR(int R, int C, const T* inp, int ld) { |
29 | unique_ptr<CSRMatrix<T>> csr(new CSRMatrix<T>()); |
30 | csr->rowPtr.push_back(0); |
31 | int nnz = 0; |
32 | for (int i = 0; i < R; ++i) { |
33 | for (int j = 0; j < C; ++j) { |
34 | if (inp[i * ld + j] != 0) { |
35 | csr->values.push_back(inp[i * ld + j]); |
36 | csr->colIdx.push_back(j); |
37 | nnz++; |
38 | } |
39 | } |
40 | csr->rowPtr.push_back(nnz); |
41 | } |
42 | return csr; |
43 | } |
44 | |
45 | template <typename T> |
46 | std::unique_ptr<CSRMatrix<T>> fbgemmDenseToCSR(int R, int C, const T* inp) { |
47 | return fbgemmDenseToCSR<T>(R, C, inp, C); |
48 | } |
49 | |
// Explicit instantiations of the dense -> CSR converters for the element
// types this file supports: int8_t and float.
template FBGEMM_API std::unique_ptr<CSRMatrix<int8_t>>
fbgemmDenseToCSR(int R, int C, const int8_t* inp);
template FBGEMM_API std::unique_ptr<CSRMatrix<float>>
fbgemmDenseToCSR(int R, int C, const float* inp);

template FBGEMM_API std::unique_ptr<CSRMatrix<int8_t>>
fbgemmDenseToCSR(int R, int C, const int8_t* inp, int ld);
template FBGEMM_API std::unique_ptr<CSRMatrix<float>>
fbgemmDenseToCSR(int R, int C, const float* inp, int ld);
59 | |
60 | template <typename T, int RB, int CB> |
61 | FBGEMM_API std::unique_ptr<BCSRMatrix<T, RB, CB>> |
62 | fbgemmDenseToBCSR(int R, int C, const T* inp, int ld) { |
63 | unique_ptr<BCSRMatrix<T, RB, CB>> bcsr(new BCSRMatrix<T, RB, CB>(R, C)); |
64 | bcsr->pack(inp, ld); |
65 | return bcsr; |
66 | } |
67 | |
68 | template <typename T, int RB, int CB> |
69 | FBGEMM_API std::unique_ptr<BCSRMatrix<T, RB, CB>> |
70 | fbgemmDenseToBCSR(int R, int C, const T* inp) { |
71 | return fbgemmDenseToBCSR<T, RB, CB>(R, C, inp, C); |
72 | } |
73 | |
// Out-of-class definitions for the static constexpr data members. These are
// needed when the members are ODR-used (e.g. bound to a const reference)
// under pre-C++17 rules, where the in-class initializer alone is not a
// definition.
template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::RB;

template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::CB;

template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::COLTILE;
82 | |
template <typename T, int RB, int CB>
void BCSRMatrix<T, RB, CB>::pack(const DTYPE* src, size_t ld) {
  // Packs the dense R x C matrix `src` (row-major, leading dimension `ld`)
  // into this block-CSR structure. Columns are processed in tiles of width
  // COLTILE; within each tile the matrix is cut into RB x CB blocks, and a
  // block is stored (in full, zero-padded at the matrix edges) only when at
  // least one of its in-bounds elements is nonzero. The sums of stored
  // values per row are accumulated into row_offsets.
  rowBPtr.push_back(0);
  int nnzb = 0; // number of nonzero blocks emitted so far
  int numCOLTILEs = (C + COLTILE - 1) / COLTILE;
  int rowBlocks = (R + RB - 1) / RB;
  for (int jt = 0; jt < numCOLTILEs; ++jt) {
    for (int i = 0; i < rowBlocks; ++i) {
      // Width of the current (possibly partial, rightmost) column tile.
      int curCols = min(C - jt * COLTILE, COLTILE);
      int curColBlocks = (curCols + CB - 1) / CB;
      // Running sum of stored values for each of the RB rows of this block
      // row, restricted to the current column tile.
      std::array<int32_t, RB> rowSum = {0};
      for (int j = 0; j < curColBlocks; ++j) {
        // is the whole block zero?
        bool isCurrentBlockNonZero = false;
        for (int ib = 0; ib < RB; ++ib) {
          // break if already found a non-zero element or
          // out of bounds
          if (isCurrentBlockNonZero || (i * RB + ib) >= R) {
            break;
          }
          for (int jb = 0; jb < CB; ++jb) {
            // within bound?
            if ((jt * COLTILE + j * CB + jb) >= C) {
              continue;
            } else {
              if (src[(i * RB + ib) * ld + jt * COLTILE + j * CB + jb] != 0) {
                isCurrentBlockNonZero = true;
                break;
              }
            }
          }
        }
        if (isCurrentBlockNonZero) {
          // Store the whole RB x CB block; positions falling outside the
          // matrix are filled with zeros so every stored block has the same
          // fixed size.
          for (int ib = 0; ib < RB; ++ib) {
            for (int jb = 0; jb < CB; ++jb) {
              if ((i * RB + ib) >= R || (jt * COLTILE + j * CB + jb) >= C) {
                // zero fill
                values.push_back(0);
              } else {
                DTYPE val =
                    src[(i * RB + ib) * ld + jt * COLTILE + j * CB + jb];
                values.push_back(val);
                rowSum[ib] += static_cast<int32_t>(val);
              }
            }
          }
          colBIdx.push_back(j); // block-column index within the current tile
          nnzb++;
        }
      }
      // One rowBPtr entry per (tile, row-block) pair, in tile-major order.
      rowBPtr.push_back(nnzb);
      // Note: in row_offsets we don't need to subtract the constant term
      // weight_zero_point * C because it's 0 as weight_zero_point is always 0
      // for sparse kernels.
      for (int ib = 0; ib < RB; ++ib) {
        if (jt) {
          // Later tiles accumulate onto the sums from earlier tiles.
          row_offsets[i * RB + ib] += rowSum[ib];
        } else {
          // First tile initializes the running sum for this row.
          row_offsets[i * RB + ib] = rowSum[ib];
        }
      }
    }
  }
}
147 | |
148 | template <typename T, int RB, int CB> |
149 | void BCSRMatrix<T, RB, CB>::pack(const DTYPE* src) { |
150 | pack(src, C); |
151 | } |
152 | |
153 | template <typename T, int RB, int CB> |
154 | void BCSRMatrix<T, RB, CB>::unpack(T* dst, size_t ld) { |
155 | // zero out destination |
156 | memset(dst, 0, R * C * sizeof(T)); |
157 | |
158 | int numCOLTILEs = (C + COLTILE - 1) / COLTILE; |
159 | int rowBlocks = (R + RB - 1) / RB; |
160 | for (int jt = 0; jt < numCOLTILEs; ++jt) { |
161 | for (int i = 0; i < rowBlocks; ++i) { |
162 | // For the current tile, rowBPtr starts from currentTileIdx (i.e., jt) * R |
163 | for (int r = rowBPtr[jt * R + i]; r < rowBPtr[jt * R + i + 1]; ++r) { |
164 | int curColIdx = colBIdx[r]; |
165 | for (int ib = 0; ib < RB; ++ib) { |
166 | for (int jb = 0; jb < CB; ++jb) { |
167 | // Are we within bounds of destination matrix? |
168 | if ((i * RB + ib) < R && (jt * COLTILE + curColIdx * CB + jb) < C) { |
169 | dst[(i * RB + ib) * ld + jt * COLTILE + curColIdx * CB + jb] = |
170 | values[r * RB * CB + ib * CB + jb]; |
171 | } |
172 | } |
173 | } |
174 | } |
175 | } |
176 | } |
177 | } |
178 | |
179 | template <typename T, int RB, int CB> |
180 | void BCSRMatrix<T, RB, CB>::unpack(T* dst) { |
181 | unpack(dst, C); |
182 | } |
183 | |
// Explicit instantiation of the block-CSR matrix used by the int8 sparse
// kernels: 1 x 4 blocks of int8 values.
template struct BCSRMatrix<int8_t, 1, 4>;

template struct CSRMatrix<int8_t>;
template struct CSRMatrix<float>;

// Matching instantiations of the dense -> BCSR conversion helpers.
template FBGEMM_API std::unique_ptr<BCSRMatrix<int8_t, 1, 4>>
fbgemmDenseToBCSR(int R, int C, const int8_t* inp);

template FBGEMM_API std::unique_ptr<BCSRMatrix<int8_t, 1, 4>>
fbgemmDenseToBCSR(int R, int C, const int8_t* inp, int ld);
194 | |
195 | void SparseDenseMM( |
196 | int M, |
197 | int N, |
198 | const int* row_ptr, |
199 | const int* col_idx, |
200 | const float* values, |
201 | const float* B, |
202 | int ldb, |
203 | float* C, |
204 | int ldc, |
205 | bool accum) { |
206 | static const auto iset = fbgemmInstructionSet(); |
207 | // Run time CPU detection |
208 | if (isZmm(iset)) { |
209 | internal::SparseDenseMMAvx512( |
210 | M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum); |
211 | } else if (isYmm(iset)) { |
212 | internal::SparseDenseMMAvx2( |
213 | M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum); |
214 | } else { |
215 | sparseDenseMMRef(M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum); |
216 | } |
217 | } |
218 | |
// Runtime-dispatched sparse (BCSR int8) x dense (uint8) matrix multiply:
// forwards every argument unchanged to the widest available SIMD kernel
// (AVX512, then AVX2, then the scalar reference). Only thread 0 performs
// work — the operation is not parallelized yet. rParams carries the
// requantization parameters consumed by the kernels; accum selects
// accumulate-into vs. overwrite behavior in the kernels.
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
FBGEMM_API void fbgemmSparseDenseInt8MM(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int thread_id,
    int num_threads) {
  // ISA detection runs once; the result is cached for all later calls.
  static const auto iset = fbgemmInstructionSet();
  // No parallelization currently
  // All work is done by thread 0
  if (thread_id > 0) {
    return;
  }

  // Run time CPU detection
  if (isZmm(iset)) {
    internal::SparseDenseInt8MMAvx512<FUSE_RELU, Q_GRAN>(
        N,
        bcsr,
        B,
        ldb,
        C_i32,
        C_u8,
        ldc,
        rParams,
        accum,
        thread_id,
        num_threads);
  } else if (isYmm(iset)) {
    internal::SparseDenseInt8MMAvx2<FUSE_RELU, Q_GRAN>(
        N,
        bcsr,
        B,
        ldb,
        C_i32,
        C_u8,
        ldc,
        rParams,
        accum,
        thread_id,
        num_threads);
  } else {
    // Scalar reference fallback.
    sparseDenseInt8MMRef<FUSE_RELU, Q_GRAN>(
        N,
        bcsr,
        B,
        ldb,
        C_i32,
        C_u8,
        ldc,
        rParams,
        accum,
        thread_id,
        num_threads);
  }
}
281 | |
// Explicitly instantiate fbgemmSparseDenseInt8MM for every supported
// combination of ReLU fusion and quantization granularity, so the template
// definition can remain in this translation unit.
#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                             \
  template FBGEMM_API void fbgemmSparseDenseInt8MM<FUSE_RELU, QGRAN>( \
      int N,                                                          \
      const std::unique_ptr<BCSRMatrix<>>& bcsr,                      \
      const uint8_t* B,                                               \
      int ldb,                                                        \
      int32_t* C_i32,                                                 \
      uint8_t* C_u8,                                                  \
      int ldc,                                                        \
      trRequantizationParams_t& rParams,                              \
      bool accum,                                                     \
      int thread_id,                                                  \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE
300 | |
301 | } // namespace fbgemm |
302 | |