1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/spmmUtils.h"
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
12
13using namespace std;
14
15namespace fbgemm {
16
void sparseDenseMMRef(
    int M,
    int N,
    const int* row_ptr,
    const int* col_idx,
    const float* values,
    const float* B,
    int ldb,
    float* C,
    int ldc,
    bool accum) {
  // Reference CSR x dense matrix multiply: accum ? C += A * B : C = A * B.
  // A is given in CSR form:
  //   - row_ptr has M + 1 entries,
  //   - col_idx and values each have one entry per non-zero (nnz).
  for (int row = 0; row < M; ++row) {
    float* c_row = C + row * ldc;
    if (!accum) {
      // Fresh output: clear this row of C before accumulating into it.
      for (int col = 0; col < N; ++col) {
        c_row[col] = 0;
      }
    }
    // Walk the non-zeros of A's current row; each one scales a row of B.
    for (int nz = row_ptr[row]; nz < row_ptr[row + 1]; ++nz) {
      const float* b_row = B + col_idx[nz] * ldb;
      const float a_val = values[nz];
      for (int col = 0; col < N; ++col) {
        c_row[col] += a_val * b_row[col];
      }
    }
  }
}
47
48template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
49FBGEMM_API void trRequantizeRef(
50 uint8_t* out,
51 const int32_t* inp,
52 const block_type_t& block,
53 int ld_out,
54 int ld_in,
55 const trRequantizationParams_t& r) {
56 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
57 for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
58 int32_t raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)];
59 if (r.act_zero_point) {
60 raw -= r.act_zero_point * r.weight_row_offsets[i];
61 }
62 int weight_zeropoint_idx;
63 if (Q_GRAN == QuantizationGranularity::TENSOR) {
64 weight_zeropoint_idx = 0;
65 } else {
66 // Q_GRAN == QuantizationGranularity::OUT_CHANNEL
67 weight_zeropoint_idx = i;
68 }
69 if (r.act_col_offsets) {
70 raw -= r.act_col_offsets[j - block.col_start] *
71 r.weight_zero_points[weight_zeropoint_idx];
72 }
73 float raw_f = raw;
74 if (r.bias) {
75 raw_f += r.bias[i] / r.act_times_w_scale[weight_zeropoint_idx];
76 }
77
78 float ab = raw_f * r.act_times_w_scale[weight_zeropoint_idx] / r.C_scale;
79 int rounded = std::rintf(ab) + r.C_zero_point;
80 out[i * ld_out + j] = std::max(
81 FUSE_RELU ? static_cast<int>(r.C_zero_point) : 0,
82 std::min(255, rounded));
83 }
84 }
85}
86
// Explicit instantiations of trRequantizeRef for every (FUSE_RELU,
// granularity) combination FBGEMM uses, so the template definition can stay
// in this translation unit instead of the header.
#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                  \
  template FBGEMM_API void trRequantizeRef<FUSE_RELU, QGRAN>( \
      uint8_t * out,                                       \
      const int32_t* inp,                                  \
      const block_type_t& block,                           \
      int ld_out,                                          \
      int ld_in,                                           \
      const trRequantizationParams_t& r);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE
100
std::vector<std::vector<int>> getSparseMatrixShapes() {
  // Returns the {M, N, K} shapes used by the sparse benchmark/test drivers.
  // A base set is always returned; model-specific sets (RoBERTa, LSTM, RNNT)
  // are appended when the matching environment variable is set to "1".
  // clang-format off
  // {M, N, K}
  std::vector<std::vector<int>> shapes = {
    {1,128,160},
    {1,16,128},
    {1,256,160},
    {168,15,197},
    {168,8,197},
    {176,15,197},
    {176,8,197},
    {21,1,1027},
    {21,120,512},
    {21,125,300},
    {21,128,120},
    {21,128,176},
    {21,16,128},
    {21,256,5018},
    {21,256,512},
    {21,2955,512},
    {21,5018,256},
    {21,512,128},
    {21,512,2125},
    {21,512,256},
    {21,512,3851},
    {21,512,4085},
    {21,8,512},
    {22,1,1027},
    {22,120,512},
    {22,125,300},
    {22,128,120},
    {22,128,176},
    {22,16,128},
    {22,256,5018},
    {22,256,512},
    {22,2955,512},
    {22,5018,256},
    {22,512,128},
    {22,512,2125},
    {22,512,256},
    {22,512,3851},
    {22,512,4085},
    {22,8,512},
    {128,128,128},
    {256,256,256},
    {512,512,512},
  };

  // Appends `extra` to `shapes` when env var `env_var` equals "1", and
  // prints whether the `family` set was included either way.
  const auto appendIfEnabled = [&shapes](
                                   const char* family,
                                   const char* env_var,
                                   const std::vector<std::vector<int>>& extra) {
    const char* flag = std::getenv(env_var);
    if (flag && std::strcmp(flag, "1") == 0) {
      shapes.insert(shapes.end(), extra.begin(), extra.end());
      std::cout << family << " shapes included." << std::endl;
    } else {
      std::cout << family << " shapes not included. "
                << "To include, add \"" << env_var
                << "=1\" as an env variable." << std::endl;
    }
  };

  // RoBERTa shapes
  appendIfEnabled("RoBERTa", "INCLUDE_ROBERTA", {
    // average input length = 25
    {25, 2304, 768},
    {25, 768, 768},
    {25, 3072, 768},
    {25, 768, 3072},
    {25, 3072, 1024},
    {25, 1024, 1024},
    {25, 4096, 1024},
    {25, 1024, 4096},
    // high input length = 51
    {51, 2304, 768},
    {51, 768, 768},
    {51, 3072, 768},
    {51, 768, 3072},
    {51, 3072, 1024},
    {51, 1024, 1024},
    {51, 4096, 1024},
    {51, 1024, 4096},
  });

  // LSTM shapes
  appendIfEnabled("LSTM", "INCLUDE_LSTM", {
    { 1, 2560, 640},
    {16, 2560, 640},
    {18, 2560, 640},
    { 1, 2560, 720},
    {16, 2560, 720},
    {18, 2560, 720},
  });

  // RNNT shapes
  appendIfEnabled("RNNT", "INCLUDE_RNNT", {
    {1, 4096, 640},
    {1, 640, 1024},
    {5, 4096, 640},
    {20, 4096, 640},
    {4, 4096, 1024},
    {3, 4096, 1024},
    {1, 4096, 1024},
    {2, 4096, 1024},
    {5, 1024, 640},
    {5, 4096, 1280},
    {20, 4096, 880},
    {10, 4096, 640},
    {10, 4096, 1280},
    {5, 4096, 1024},
    {1, 1024, 640},
    {6, 4096, 1024},
    {1, 640, 256},
    {1, 1024, 256},
    {7, 4096, 1024},
    {8, 4096, 1024},
    {9, 4096, 1024},
    {7, 4096, 640},
    {4, 4096, 640},
    {28, 4096, 640},
    {16, 4096, 640},
    {10, 4096, 1024},
    {8, 4096, 640},
    {8, 4096, 1280},
    {7, 1024, 640},
    {7, 4096, 1280},
    {4, 1024, 640},
    {4, 4096, 1280},
    {28, 4096, 880},
    {16, 4096, 880},
    {14, 4096, 640},
    {14, 4096, 1280},
    {1, 256, 5000},
    {2, 256, 4500},
    {64, 256, 4500},
  });
  // clang-format on
  return shapes;
}
253
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
void sparseDenseInt8MMRef(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_i8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int /*thread_id*/,
    int /*num_threads*/) {
  // Calculates accum ? C += A * B : C = A * B, where A is an int8
  // block-compressed sparse (BCSR) matrix and B is dense uint8, accumulating
  // into C_i32 and then requantizing the whole result down to uint8 in C_i8.
  constexpr int rowBlockSize = BCSRMatrix<>::RB;
  constexpr int colBlockSize = BCSRMatrix<>::CB;
  constexpr int colTileSize = BCSRMatrix<>::COLTILE;
  int M = bcsr->R;
  int K = bcsr->C;
  // K is split into tiles of COLTILE columns; bcsr->rowBPtr holds one
  // M-entry row-pointer array per tile (offset kt * M below).
  int kTiles = (K + colTileSize - 1) / colTileSize;
  assert(
      M % rowBlockSize == 0 &&
      "Number of rows is not a multiple of rowBlockSize size");

  for (int j = 0; j < N; ++j) {
    for (int kt = 0; kt < kTiles; ++kt) {
      int* rowBPtr_start = bcsr->rowBPtr.data() + kt * M;
      // NOTE(review): this loop both bounds i by M / rowBlockSize AND steps
      // by rowBlockSize, and the zero-init below touches only row i * ldc
      // while the accumulation writes rows (i * rowBlockSize + i_b). These
      // are only mutually consistent when rowBlockSize == 1; for RB > 1 it
      // looks like row blocks would be skipped and left uninitialized —
      // confirm against BCSRMatrix<>::RB and the vectorized kernel.
      for (int i = 0; i < M / rowBlockSize; i += rowBlockSize) {
        // only initialize to 0 for the first ktile
        if (!accum && !kt) {
          C_i32[i * ldc + j] = 0;
        }
        for (int r = rowBPtr_start[i]; r < rowBPtr_start[i + 1]; ++r) {
          int acbr_block = bcsr->colBIdx[r];
          // Each stored block is rowBlockSize x colBlockSize int8 values,
          // laid out row-major at index r.
          const int8_t* blockValues =
              bcsr->values.data() + r * rowBlockSize * colBlockSize;
          for (int i_b = 0; i_b < rowBlockSize; ++i_b) {
            for (int k_b = 0; k_b < colBlockSize; ++k_b) {
              // B row index combines the block's column index within the
              // tile with the tile offset kt * colTileSize.
              C_i32[(i * rowBlockSize + i_b) * ldc + j] +=
                  static_cast<int32_t>(blockValues[i_b * colBlockSize + k_b]) *
                  static_cast<int32_t>(
                      B[(acbr_block * colBlockSize + k_b + kt * colTileSize) *
                            ldb +
                        j]);
            }
          }
        }
      }
    }
  }
  // Requantize the full M x N int32 result to uint8 in one pass.
  block_type_t block{0, M, 0, N};
  trRequantizeRef<FUSE_RELU, Q_GRAN>(C_i8, C_i32, block, ldc, ldc, rParams);
}
307
// Explicit instantiations of sparseDenseInt8MMRef for every (FUSE_RELU,
// granularity) combination FBGEMM uses, mirroring the trRequantizeRef
// instantiations above.
#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                  \
  template void sparseDenseInt8MMRef<FUSE_RELU, QGRAN>(    \
      int N,                                               \
      const std::unique_ptr<BCSRMatrix<>>& bcsr,           \
      const uint8_t* B,                                    \
      int ldb,                                             \
      int32_t* C_i32,                                      \
      uint8_t* C_u8,                                       \
      int ldc,                                             \
      trRequantizationParams_t& rParams,                   \
      bool accum,                                          \
      int thread_id,                                       \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE
326
327} // namespace fbgemm
328