1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/FbgemmSparse.h"
9
10#include <algorithm>
11#include <array>
12#include <cassert>
13#include <cstdio>
14#include <cstring>
15#include <memory>
16#include <sstream>
17#include <vector>
18
19#include "fbgemm/Utils.h"
20#include "fbgemm/spmmUtils.h"
21
22using namespace std;
23
24namespace fbgemm {
25
26template <typename T>
27FBGEMM_API std::unique_ptr<CSRMatrix<T>>
28fbgemmDenseToCSR(int R, int C, const T* inp, int ld) {
29 unique_ptr<CSRMatrix<T>> csr(new CSRMatrix<T>());
30 csr->rowPtr.push_back(0);
31 int nnz = 0;
32 for (int i = 0; i < R; ++i) {
33 for (int j = 0; j < C; ++j) {
34 if (inp[i * ld + j] != 0) {
35 csr->values.push_back(inp[i * ld + j]);
36 csr->colIdx.push_back(j);
37 nnz++;
38 }
39 }
40 csr->rowPtr.push_back(nnz);
41 }
42 return csr;
43}
44
45template <typename T>
46std::unique_ptr<CSRMatrix<T>> fbgemmDenseToCSR(int R, int C, const T* inp) {
47 return fbgemmDenseToCSR<T>(R, C, inp, C);
48}
49
50template FBGEMM_API std::unique_ptr<CSRMatrix<int8_t>>
51fbgemmDenseToCSR(int R, int C, const int8_t* inp);
52template FBGEMM_API std::unique_ptr<CSRMatrix<float>>
53fbgemmDenseToCSR(int R, int C, const float* inp);
54
55template FBGEMM_API std::unique_ptr<CSRMatrix<int8_t>>
56fbgemmDenseToCSR(int R, int C, const int8_t* inp, int ld);
57template FBGEMM_API std::unique_ptr<CSRMatrix<float>>
58fbgemmDenseToCSR(int R, int C, const float* inp, int ld);
59
60template <typename T, int RB, int CB>
61FBGEMM_API std::unique_ptr<BCSRMatrix<T, RB, CB>>
62fbgemmDenseToBCSR(int R, int C, const T* inp, int ld) {
63 unique_ptr<BCSRMatrix<T, RB, CB>> bcsr(new BCSRMatrix<T, RB, CB>(R, C));
64 bcsr->pack(inp, ld);
65 return bcsr;
66}
67
68template <typename T, int RB, int CB>
69FBGEMM_API std::unique_ptr<BCSRMatrix<T, RB, CB>>
70fbgemmDenseToBCSR(int R, int C, const T* inp) {
71 return fbgemmDenseToBCSR<T, RB, CB>(R, C, inp, C);
72}
73
// Out-of-class definitions of BCSRMatrix's static constexpr data members.
// Before C++17 these are needed whenever a member is odr-used. For example,
// pack() passes COLTILE to std::min, which binds its arguments by const
// reference. From C++17 on they are redundant (the members are implicitly
// inline) but harmless.
template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::RB;

template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::CB;

template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::COLTILE;
82
83template <typename T, int RB, int CB>
84void BCSRMatrix<T, RB, CB>::pack(const DTYPE* src, size_t ld) {
85 rowBPtr.push_back(0);
86 int nnzb = 0;
87 int numCOLTILEs = (C + COLTILE - 1) / COLTILE;
88 int rowBlocks = (R + RB - 1) / RB;
89 for (int jt = 0; jt < numCOLTILEs; ++jt) {
90 for (int i = 0; i < rowBlocks; ++i) {
91 int curCols = min(C - jt * COLTILE, COLTILE);
92 int curColBlocks = (curCols + CB - 1) / CB;
93 std::array<int32_t, RB> rowSum = {0};
94 for (int j = 0; j < curColBlocks; ++j) {
95 // is the whole block zero?
96 bool isCurrentBlockNonZero = false;
97 for (int ib = 0; ib < RB; ++ib) {
98 // break if already found a non-zero element or
99 // out of bounds
100 if (isCurrentBlockNonZero || (i * RB + ib) >= R) {
101 break;
102 }
103 for (int jb = 0; jb < CB; ++jb) {
104 // within bound?
105 if ((jt * COLTILE + j * CB + jb) >= C) {
106 continue;
107 } else {
108 if (src[(i * RB + ib) * ld + jt * COLTILE + j * CB + jb] != 0) {
109 isCurrentBlockNonZero = true;
110 break;
111 }
112 }
113 }
114 }
115 if (isCurrentBlockNonZero) {
116 for (int ib = 0; ib < RB; ++ib) {
117 for (int jb = 0; jb < CB; ++jb) {
118 if ((i * RB + ib) >= R || (jt * COLTILE + j * CB + jb) >= C) {
119 // zero fill
120 values.push_back(0);
121 } else {
122 DTYPE val =
123 src[(i * RB + ib) * ld + jt * COLTILE + j * CB + jb];
124 values.push_back(val);
125 rowSum[ib] += static_cast<int32_t>(val);
126 }
127 }
128 }
129 colBIdx.push_back(j);
130 nnzb++;
131 }
132 }
133 rowBPtr.push_back(nnzb);
134 // Note: in row_offsets we don't need to subtract the constant term
135 // weight_zero_point * C because it's 0 as weight_zero_point is always 0
136 // for sparse kernels.
137 for (int ib = 0; ib < RB; ++ib) {
138 if (jt) {
139 row_offsets[i * RB + ib] += rowSum[ib];
140 } else {
141 row_offsets[i * RB + ib] = rowSum[ib];
142 }
143 }
144 }
145 }
146}
147
148template <typename T, int RB, int CB>
149void BCSRMatrix<T, RB, CB>::pack(const DTYPE* src) {
150 pack(src, C);
151}
152
153template <typename T, int RB, int CB>
154void BCSRMatrix<T, RB, CB>::unpack(T* dst, size_t ld) {
155 // zero out destination
156 memset(dst, 0, R * C * sizeof(T));
157
158 int numCOLTILEs = (C + COLTILE - 1) / COLTILE;
159 int rowBlocks = (R + RB - 1) / RB;
160 for (int jt = 0; jt < numCOLTILEs; ++jt) {
161 for (int i = 0; i < rowBlocks; ++i) {
162 // For the current tile, rowBPtr starts from currentTileIdx (i.e., jt) * R
163 for (int r = rowBPtr[jt * R + i]; r < rowBPtr[jt * R + i + 1]; ++r) {
164 int curColIdx = colBIdx[r];
165 for (int ib = 0; ib < RB; ++ib) {
166 for (int jb = 0; jb < CB; ++jb) {
167 // Are we within bounds of destination matrix?
168 if ((i * RB + ib) < R && (jt * COLTILE + curColIdx * CB + jb) < C) {
169 dst[(i * RB + ib) * ld + jt * COLTILE + curColIdx * CB + jb] =
170 values[r * RB * CB + ib * CB + jb];
171 }
172 }
173 }
174 }
175 }
176 }
177}
178
179template <typename T, int RB, int CB>
180void BCSRMatrix<T, RB, CB>::unpack(T* dst) {
181 unpack(dst, C);
182}
183
// Explicit instantiations for the element types / block shapes this
// translation unit provides. Only the 1 x 4 int8_t block shape of
// BCSRMatrix is instantiated here.
template struct BCSRMatrix<int8_t, 1, 4>;

template struct CSRMatrix<int8_t>;
template struct CSRMatrix<float>;

// Dense -> BCSR converters for the instantiated 1 x 4 int8_t block shape.
template FBGEMM_API std::unique_ptr<BCSRMatrix<int8_t, 1, 4>>
fbgemmDenseToBCSR(int R, int C, const int8_t* inp);

template FBGEMM_API std::unique_ptr<BCSRMatrix<int8_t, 1, 4>>
fbgemmDenseToBCSR(int R, int C, const int8_t* inp, int ld);
194
195void SparseDenseMM(
196 int M,
197 int N,
198 const int* row_ptr,
199 const int* col_idx,
200 const float* values,
201 const float* B,
202 int ldb,
203 float* C,
204 int ldc,
205 bool accum) {
206 static const auto iset = fbgemmInstructionSet();
207 // Run time CPU detection
208 if (isZmm(iset)) {
209 internal::SparseDenseMMAvx512(
210 M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
211 } else if (isYmm(iset)) {
212 internal::SparseDenseMMAvx2(
213 M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
214 } else {
215 sparseDenseMMRef(M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
216 }
217}
218
219template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
220FBGEMM_API void fbgemmSparseDenseInt8MM(
221 int N,
222 const std::unique_ptr<BCSRMatrix<>>& bcsr,
223 const uint8_t* B,
224 int ldb,
225 int32_t* C_i32,
226 uint8_t* C_u8,
227 int ldc,
228 trRequantizationParams_t& rParams,
229 bool accum,
230 int thread_id,
231 int num_threads) {
232 static const auto iset = fbgemmInstructionSet();
233 // No parallelization currently
234 // All work is done by thread 0
235 if (thread_id > 0) {
236 return;
237 }
238
239 // Run time CPU detection
240 if (isZmm(iset)) {
241 internal::SparseDenseInt8MMAvx512<FUSE_RELU, Q_GRAN>(
242 N,
243 bcsr,
244 B,
245 ldb,
246 C_i32,
247 C_u8,
248 ldc,
249 rParams,
250 accum,
251 thread_id,
252 num_threads);
253 } else if (isYmm(iset)) {
254 internal::SparseDenseInt8MMAvx2<FUSE_RELU, Q_GRAN>(
255 N,
256 bcsr,
257 B,
258 ldb,
259 C_i32,
260 C_u8,
261 ldc,
262 rParams,
263 accum,
264 thread_id,
265 num_threads);
266 } else {
267 sparseDenseInt8MMRef<FUSE_RELU, Q_GRAN>(
268 N,
269 bcsr,
270 B,
271 ldb,
272 C_i32,
273 C_u8,
274 ldc,
275 rParams,
276 accum,
277 thread_id,
278 num_threads);
279 }
280}
281
282#define CREATE_INSTANCE(FUSE_RELU, QGRAN) \
283 template FBGEMM_API void fbgemmSparseDenseInt8MM<FUSE_RELU, QGRAN>( \
284 int N, \
285 const std::unique_ptr<BCSRMatrix<>>& bcsr, \
286 const uint8_t* B, \
287 int ldb, \
288 int32_t* C_i32, \
289 uint8_t* C_u8, \
290 int ldc, \
291 trRequantizationParams_t& rParams, \
292 bool accum, \
293 int thread_id, \
294 int num_threads);
295CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
296CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
297CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
298CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
299#undef CREATE_INSTANCE
300
301} // namespace fbgemm
302