1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
#define FBGEMM_EXPORTS
#include "fbgemm/spmmUtils.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
12 | |
13 | using namespace std; |
14 | |
15 | namespace fbgemm { |
16 | |
void sparseDenseMMRef(
    int M,
    int N,
    const int* row_ptr,
    const int* col_idx,
    const float* values,
    const float* B,
    int ldb,
    float* C,
    int ldc,
    bool accum) {
  // Reference sparse(A, CSR) x dense(B) matmul:
  // calculates accum ? C += A * B : C = A * B.
  //
  // CSR layout of A (M rows):
  //   row_ptr  - M + 1 entries; row i owns nonzeros [row_ptr[i], row_ptr[i+1])
  //   col_idx  - column index of each nonzero (nnz entries)
  //   values   - value of each nonzero (nnz entries)
  for (int row = 0; row < M; ++row) {
    float* cRow = C + row * ldc;
    // Without accumulation, start the output row from zero.
    if (!accum) {
      for (int col = 0; col < N; ++col) {
        cRow[col] = 0;
      }
    }
    const int nnzBegin = row_ptr[row];
    const int nnzEnd = row_ptr[row + 1];
    for (int nz = nnzBegin; nz < nnzEnd; ++nz) {
      // Each nonzero A(row, k) scales one full row of B into C's row.
      const float aVal = values[nz];
      const float* bRow = B + col_idx[nz] * ldb;
      for (int col = 0; col < N; ++col) {
        cRow[col] += aVal * bRow[col];
      }
    }
  }
}
47 | |
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
FBGEMM_API void trRequantizeRef(
    uint8_t* out,
    const int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const trRequantizationParams_t& r) {
  // Reference requantization: rescales the int32 accumulators in `inp`
  // (a block.row_size x block.col_size tile with leading dimension ld_in,
  // stored relative to the block origin) into uint8 values written to `out`
  // at absolute (row_start, col_start) offsets with leading dimension ld_out.
  //
  // With OUT_CHANNEL granularity, per-channel parameters
  // (weight_zero_points, act_times_w_scale) are indexed by the absolute
  // row index `i`; with TENSOR granularity a single entry (index 0) is used.
  for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
    for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
      int32_t raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)];
      // Subtract the activation-zero-point correction term; skipped when the
      // activation zero point is 0 because the product would be 0 anyway.
      if (r.act_zero_point) {
        raw -= r.act_zero_point * r.weight_row_offsets[i];
      }
      int weight_zeropoint_idx;
      if (Q_GRAN == QuantizationGranularity::TENSOR) {
        weight_zeropoint_idx = 0;
      } else {
        // Q_GRAN == QuantizationGranularity::OUT_CHANNEL
        weight_zeropoint_idx = i;
      }
      // Subtract the weight-zero-point correction term when per-column
      // activation offsets were precomputed (nullptr means not needed).
      if (r.act_col_offsets) {
        raw -= r.act_col_offsets[j - block.col_start] *
            r.weight_zero_points[weight_zeropoint_idx];
      }
      float raw_f = raw;
      // Bias appears to be stored in real-valued units: dividing by
      // act_times_w_scale moves it into the accumulator scale so the common
      // rescale below applies to it as well — NOTE(review): confirm against
      // trRequantizationParams_t's documentation.
      if (r.bias) {
        raw_f += r.bias[i] / r.act_times_w_scale[weight_zeropoint_idx];
      }

      // Rescale to the output quantization scale, round to nearest, shift by
      // the output zero point, and saturate to [0, 255]. With FUSE_RELU the
      // lower clamp is C_zero_point (the quantized encoding of 0) instead.
      float ab = raw_f * r.act_times_w_scale[weight_zeropoint_idx] / r.C_scale;
      int rounded = std::rintf(ab) + r.C_zero_point;
      out[i * ld_out + j] = std::max(
          FUSE_RELU ? static_cast<int>(r.C_zero_point) : 0,
          std::min(255, rounded));
    }
  }
}
86 | |
// Explicit instantiations of trRequantizeRef for every supported
// (FUSE_RELU x quantization granularity) combination, so the template
// definition can stay in this .cc file and other TUs link against these.
#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                     \
  template FBGEMM_API void trRequantizeRef<FUSE_RELU, QGRAN>( \
      uint8_t * out,                                          \
      const int32_t* inp,                                     \
      const block_type_t& block,                              \
      int ld_out,                                             \
      int ld_in,                                              \
      const trRequantizationParams_t& r);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE
100 | |
// Appends `extra_shapes` to `shapes` iff the environment variable `env_name`
// is set to "1"; prints a one-line status message either way.
static void appendShapesIfRequested(
    std::vector<std::vector<int>>& shapes,
    const char* env_name,
    const char* group_name,
    const std::vector<std::vector<int>>& extra_shapes) {
  const char* env_val = std::getenv(env_name);
  if (env_val && std::strcmp(env_val, "1") == 0) {
    shapes.insert(shapes.end(), extra_shapes.begin(), extra_shapes.end());
    std::cout << group_name << " shapes included." << std::endl;
  } else {
    std::cout << group_name << " shapes not included. "
              << "To include, add \"" << env_name
              << "=1\" as an env variable." << std::endl;
  }
}

// Returns benchmark {M, N, K} shapes for sparse-dense matmul.
// The base set is always returned; RoBERTa, LSTM, and RNNT model shapes are
// appended when the INCLUDE_ROBERTA / INCLUDE_LSTM / INCLUDE_RNNT environment
// variables are set to "1".
std::vector<std::vector<int>> getSparseMatrixShapes() {
  // clang-format off
  // {M, N, K}
  std::vector<std::vector<int>> shapes = {
    {1,128,160},
    {1,16,128},
    {1,256,160},
    {168,15,197},
    {168,8,197},
    {176,15,197},
    {176,8,197},
    {21,1,1027},
    {21,120,512},
    {21,125,300},
    {21,128,120},
    {21,128,176},
    {21,16,128},
    {21,256,5018},
    {21,256,512},
    {21,2955,512},
    {21,5018,256},
    {21,512,128},
    {21,512,2125},
    {21,512,256},
    {21,512,3851},
    {21,512,4085},
    {21,8,512},
    {22,1,1027},
    {22,120,512},
    {22,125,300},
    {22,128,120},
    {22,128,176},
    {22,16,128},
    {22,256,5018},
    {22,256,512},
    {22,2955,512},
    {22,5018,256},
    {22,512,128},
    {22,512,2125},
    {22,512,256},
    {22,512,3851},
    {22,512,4085},
    {22,8,512},
    {128,128,128},
    {256,256,256},
    {512,512,512},
  };

  // RoBERTa shapes
  appendShapesIfRequested(
      shapes,
      "INCLUDE_ROBERTA",
      "RoBERTa",
      {
          // average input length = 25
          {25, 2304, 768},
          {25, 768, 768},
          {25, 3072, 768},
          {25, 768, 3072},
          {25, 3072, 1024},
          {25, 1024, 1024},
          {25, 4096, 1024},
          {25, 1024, 4096},
          // high input length = 51
          {51, 2304, 768},
          {51, 768, 768},
          {51, 3072, 768},
          {51, 768, 3072},
          {51, 3072, 1024},
          {51, 1024, 1024},
          {51, 4096, 1024},
          {51, 1024, 4096},
      });

  // LSTM shapes
  appendShapesIfRequested(
      shapes,
      "INCLUDE_LSTM",
      "LSTM",
      {
          { 1, 2560, 640},
          {16, 2560, 640},
          {18, 2560, 640},
          { 1, 2560, 720},
          {16, 2560, 720},
          {18, 2560, 720},
      });

  // RNNT shapes
  appendShapesIfRequested(
      shapes,
      "INCLUDE_RNNT",
      "RNNT",
      {
          {1, 4096, 640},
          {1, 640, 1024},
          {5, 4096, 640},
          {20, 4096, 640},
          {4, 4096, 1024},
          {3, 4096, 1024},
          {1, 4096, 1024},
          {2, 4096, 1024},
          {5, 1024, 640},
          {5, 4096, 1280},
          {20, 4096, 880},
          {10, 4096, 640},
          {10, 4096, 1280},
          {5, 4096, 1024},
          {1, 1024, 640},
          {6, 4096, 1024},
          {1, 640, 256},
          {1, 1024, 256},
          {7, 4096, 1024},
          {8, 4096, 1024},
          {9, 4096, 1024},
          {7, 4096, 640},
          {4, 4096, 640},
          {28, 4096, 640},
          {16, 4096, 640},
          {10, 4096, 1024},
          {8, 4096, 640},
          {8, 4096, 1280},
          {7, 1024, 640},
          {7, 4096, 1280},
          {4, 1024, 640},
          {4, 4096, 1280},
          {28, 4096, 880},
          {16, 4096, 880},
          {14, 4096, 640},
          {14, 4096, 1280},
          {1, 256, 5000},
          {2, 256, 4500},
          {64, 256, 4500},
      });
  // clang-format on
  return shapes;
}
253 | |
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
void sparseDenseInt8MMRef(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_i8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int /*thread_id*/,
    int /*num_threads*/) {
  // Reference int8 sparse(A) x dense(B) matmul with fused requantization:
  // calculates accum ? C += A * B : C = A * B into the int32 scratch C_i32,
  // then requantizes the full M x N result into the uint8 output C_i8.
  //
  // A is block-compressed (BCSR): nonzeros are stored as RB x CB value
  // blocks, and the K dimension is additionally split into COLTILE-wide
  // tiles, each tile with its own row-pointer array inside bcsr->rowBPtr.
  constexpr int rowBlockSize = BCSRMatrix<>::RB;
  constexpr int colBlockSize = BCSRMatrix<>::CB;
  constexpr int colTileSize = BCSRMatrix<>::COLTILE;
  int M = bcsr->R;
  int K = bcsr->C;
  // Number of column tiles: ceil(K / COLTILE).
  int kTiles = (K + colTileSize - 1) / colTileSize;
  assert(
      M % rowBlockSize == 0 &&
      "Number of rows is not a multiple of rowBlockSize size");

  for (int j = 0; j < N; ++j) {
    for (int kt = 0; kt < kTiles; ++kt) {
      // Row pointers of the kt-th column tile.
      int* rowBPtr_start = bcsr->rowBPtr.data() + kt * M;
      // NOTE(review): the bound `i < M / rowBlockSize` combined with the
      // step `i += rowBlockSize` and the single-row zero-init below are
      // mutually consistent only when rowBlockSize == 1 (the BCSRMatrix<>
      // default); for RB > 1 this would skip row blocks — confirm against
      // the BCSRMatrix<> definition.
      for (int i = 0; i < M / rowBlockSize; i += rowBlockSize) {
        // only initialize to 0 for the first ktile
        if (!accum && !kt) {
          C_i32[i * ldc + j] = 0;
        }
        for (int r = rowBPtr_start[i]; r < rowBPtr_start[i + 1]; ++r) {
          // Column-block index of this nonzero block within the tile.
          int acbr_block = bcsr->colBIdx[r];
          // Values of the r-th RB x CB block, stored row-major.
          const int8_t* blockValues =
              bcsr->values.data() + r * rowBlockSize * colBlockSize;
          for (int i_b = 0; i_b < rowBlockSize; ++i_b) {
            for (int k_b = 0; k_b < colBlockSize; ++k_b) {
              // Accumulate A(i*RB + i_b, kt*COLTILE + acbr_block*CB + k_b)
              // times the matching B element into the int32 scratch.
              C_i32[(i * rowBlockSize + i_b) * ldc + j] +=
                  static_cast<int32_t>(blockValues[i_b * colBlockSize + k_b]) *
                  static_cast<int32_t>(
                      B[(acbr_block * colBlockSize + k_b + kt * colTileSize) *
                            ldb +
                        j]);
            }
          }
        }
      }
    }
  }
  // Requantize the whole M x N int32 result into uint8 in one pass.
  block_type_t block{0, M, 0, N};
  trRequantizeRef<FUSE_RELU, Q_GRAN>(C_i8, C_i32, block, ldc, ldc, rParams);
}
307 | |
// Explicit instantiations of sparseDenseInt8MMRef for every supported
// (FUSE_RELU x quantization granularity) combination, matching the
// trRequantizeRef instantiations above.
#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                  \
  template void sparseDenseInt8MMRef<FUSE_RELU, QGRAN>(    \
      int N,                                               \
      const std::unique_ptr<BCSRMatrix<>>& bcsr,           \
      const uint8_t* B,                                    \
      int ldb,                                             \
      int32_t* C_i32,                                      \
      uint8_t* C_u8,                                       \
      int ldc,                                             \
      trRequantizationParams_t& rParams,                   \
      bool accum,                                          \
      int thread_id,                                       \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE
326 | |
327 | } // namespace fbgemm |
328 | |