1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmSparse.h" |
9 | |
10 | #include "fbgemm/Utils.h" |
11 | #include "fbgemm/spmmUtilsAvx2.h" |
12 | |
13 | #if defined(__x86_64__) || defined(__i386__) || \ |
14 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
15 | #include <immintrin.h> |
16 | #endif |
#include <algorithm> // for std::min
18 | #include <cassert> |
19 | |
20 | namespace fbgemm { |
21 | |
22 | namespace internal { |
23 | |
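// Requantizes four int32 accumulator vectors (x[0..3]: 64 values, one full
// 64-byte output line of row `rowIdx`) down to uint8. Per output value,
// conceptually:
//   C_u8 = saturate(round((acc - act_zero_point * weight_row_offsets[rowIdx])
//                         * act_times_w_scale / C_scale
//                         + bias[rowIdx] / C_scale) + C_zero_point)
// with act_times_w_scale selected per tensor or per output channel,
// depending on Q_GRAN.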
24 | template < |
25 | bool FUSE_RELU, |
26 | bool ACT_ZP_0, // is activation zero point 0? |
27 | bool HAS_BIAS, |
28 | QuantizationGranularity Q_GRAN> |
29 | static inline __m512i |
30 | requantizeForMM(__m512i x[], int rowIdx, trRequantizationParams_t& rParams) { |
31 | __m512i C_zero_point_epi8_v = _mm512_set1_epi8(rParams.C_zero_point); |
32 | __m512i C_zero_point_epi16_v = _mm512_set1_epi16(rParams.C_zero_point); |
33 | // clang-format off |
34 | __m512i permute_mask_v = _mm512_set_epi32( |
35 | 0x0F, 0x0B, 0x07, 0x03, |
36 | 0x0E, 0x0A, 0x06, 0x02, |
37 | 0x0D, 0x09, 0x05, 0x01, |
38 | 0x0C, 0x08, 0x04, 0x00); |
39 | // clang-format on |
40 | int32_t row_offset = 0; |
41 | if (!ACT_ZP_0) { |
42 | row_offset = rParams.act_zero_point * rParams.weight_row_offsets[rowIdx]; |
43 | } |
44 | __m512i row_offset_v = _mm512_set1_epi32(row_offset); |
45 | |
46 | int weight_zeropoint_idx = 0; |
47 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
48 | weight_zeropoint_idx = rowIdx; |
49 | } |
50 | |
51 | __m512 bias_v; |
52 | if (HAS_BIAS) { |
53 | float bias = |
54 | rParams.bias[rowIdx] / rParams.act_times_w_scale[weight_zeropoint_idx]; |
55 | bias_v = _mm512_set1_ps(bias); |
56 | } |
57 | |
58 | __m512 act_times_w_div_c_v; |
59 | if (Q_GRAN == QuantizationGranularity::TENSOR) { |
60 | act_times_w_div_c_v = |
61 | _mm512_set1_ps(rParams.act_times_w_scale[0] / rParams.C_scale); |
62 | } else { |
63 | act_times_w_div_c_v = _mm512_set1_ps( |
64 | rParams.act_times_w_scale[weight_zeropoint_idx] / rParams.C_scale); |
65 | } |
66 | if (!ACT_ZP_0) { |
67 | x[0] = _mm512_sub_epi32(x[0], row_offset_v); |
68 | x[1] = _mm512_sub_epi32(x[1], row_offset_v); |
69 | x[2] = _mm512_sub_epi32(x[2], row_offset_v); |
70 | x[3] = _mm512_sub_epi32(x[3], row_offset_v); |
71 | } |
72 | |
73 | __m512 xf_v, yf_v, zf_v, wf_v; |
74 | if (HAS_BIAS) { |
75 | xf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x[0]), bias_v); |
76 | yf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x[1]), bias_v); |
77 | zf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x[2]), bias_v); |
78 | wf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x[3]), bias_v); |
79 | } else { |
80 | xf_v = _mm512_cvtepi32_ps(x[0]); |
81 | yf_v = _mm512_cvtepi32_ps(x[1]); |
82 | zf_v = _mm512_cvtepi32_ps(x[2]); |
83 | wf_v = _mm512_cvtepi32_ps(x[3]); |
84 | } |
85 | |
86 | __m512 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; |
87 | |
88 | x_scaled_v = _mm512_mul_ps(xf_v, act_times_w_div_c_v); |
89 | y_scaled_v = _mm512_mul_ps(yf_v, act_times_w_div_c_v); |
90 | z_scaled_v = _mm512_mul_ps(zf_v, act_times_w_div_c_v); |
91 | w_scaled_v = _mm512_mul_ps(wf_v, act_times_w_div_c_v); |
92 | |
93 | __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); |
94 | __m512i y_rounded_v = _mm512_cvtps_epi32(y_scaled_v); |
95 | __m512i z_rounded_v = _mm512_cvtps_epi32(z_scaled_v); |
96 | __m512i w_rounded_v = _mm512_cvtps_epi32(w_scaled_v); |
97 | |
98 | __m512i xy_packed_v = _mm512_adds_epi16( |
99 | _mm512_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v); |
100 | __m512i zw_packed_v = _mm512_adds_epi16( |
101 | _mm512_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v); |
102 | // _mm512_packus_epi16 takes care of saturating to uint8 range |
103 | __m512i xyzw_clamped_v = _mm512_packus_epi16(xy_packed_v, zw_packed_v); |
104 | if (FUSE_RELU) { |
105 | xyzw_clamped_v = _mm512_max_epu8(C_zero_point_epi8_v, xyzw_clamped_v); |
106 | } |
107 | |
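  // The packs/packus instructions above interleave their inputs across
  // 128-bit lanes; permute 32-bit groups back into ascending column order.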
108 | xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); |
109 | return xyzw_clamped_v; |
110 | } |
111 | |
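// Gathers every 4th byte within each 128-bit lane, so bytes {0,4,8,12},
// {1,5,9,13}, ... become contiguous. This is the byte-level fixup applied
// after the 32-bit transpose in interleave_4rows below.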
112 | static inline __m512i permute_row(__m512i row) { |
113 | // clang-format off |
114 | __m256i shuffle_256v = _mm256_set_epi8( |
115 | 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, |
116 | 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); |
117 | // clang-format on |
118 | __m512i shuffle_512v = _mm512_castsi256_si512(shuffle_256v); |
119 | row = _mm512_shuffle_epi8( |
120 | row, _mm512_inserti64x4(shuffle_512v, shuffle_256v, 1)); |
121 | return row; |
122 | } |
123 | |
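// Interleaves 4 rows of 64 uint8 values each so that every group of 4
// consecutive output bytes holds one byte from each input row (a 4x64 ->
// 64x4 byte transpose). This is the row-interleaved layout that
// _mm512_maddubs_epi16 consumes in loopOverReductionDim.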
124 | static inline void interleave_4rows(__m512i data[]) { |
125 | __m512i __t0 = _mm512_unpacklo_epi32(data[0], data[1]); |
126 | __m512i __t1 = _mm512_unpackhi_epi32(data[0], data[1]); |
127 | __m512i __t2 = _mm512_unpacklo_epi32(data[2], data[3]); |
128 | __m512i __t3 = _mm512_unpackhi_epi32(data[2], data[3]); |
129 | __m512i __tt0 = _mm512_unpacklo_epi64(__t0, __t2); |
130 | __m512i __tt1 = _mm512_unpacklo_epi64(__t1, __t3); |
131 | __m512i __tt2 = _mm512_unpackhi_epi64(__t0, __t2); |
132 | __m512i __tt3 = _mm512_unpackhi_epi64(__t1, __t3); |
133 | __m512i row0 = _mm512_permutex2var_epi64( |
134 | __tt0, _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0), __tt2); |
135 | __m512i row1 = _mm512_permutex2var_epi64( |
136 | __tt1, _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0), __tt3); |
137 | __m512i row2 = _mm512_permutex2var_epi64( |
138 | __tt0, _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4), __tt2); |
139 | __m512i row3 = _mm512_permutex2var_epi64( |
140 | __tt1, _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4), __tt3); |
141 | __m512i row0i = _mm512_shuffle_i64x2(row0, row1, 0x44); |
142 | __m512i row1i = _mm512_shuffle_i64x2(row0, row1, 0xEE); |
143 | __m512i row2i = _mm512_shuffle_i64x2(row2, row3, 0x44); |
144 | __m512i row3i = _mm512_shuffle_i64x2(row2, row3, 0xEE); |
145 | // End of int32 transpose |
146 | // Now we only need a simple row permutation to get the right result |
147 | data[0] = permute_row(row0i); |
148 | data[1] = permute_row(row1i); |
149 | data[2] = permute_row(row2i); |
150 | data[3] = permute_row(row3i); |
151 | return; |
152 | } |
153 | |
// By default, we process 4 column blocks (the default value of COLBLOCKS).
// Each block is of size 64 (VLEN_INT8) uint8 values.
// We can change the number of blocks we process by
// using a different value for COLBLOCKS.
// For example, for cases with 16 < N <= 32 where 4 rows are interleaved,
// we only need to process the first 32 columns, which means
// COLBLOCKS = 32 * 4 (row_interleave) / 64 (VLEN_INT8) = 2.
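// Each nonzero block of A is a 1x4 int8 block, broadcast into every 32-bit
// lane; _mm512_maddubs_epi16 multiplies it against the 4-row-interleaved B
// tile (uint8 x int8 -> summed int16 pairs) and _mm512_madd_epi16 with a
// vector of ones widens and accumulates the pairs into int32, producing a
// complete 4-deep dot product per output column.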
161 | template <int UNROLL = 3, int COLBLOCKS = 4> |
162 | static inline void loopOverReductionDim( |
163 | const int* row_ptr, |
164 | int rowIdx, |
165 | const int* col_idx, |
166 | const int8_t* values, |
167 | const uint8_t* interleave_buffer, |
168 | __m512i one_16bit_v, |
169 | __m512i c_v[]) { |
170 | constexpr int VLEN_INT8 = 64; |
171 | |
172 | int r = row_ptr[rowIdx]; |
173 | int r_end_aligned = row_ptr[rowIdx] + |
174 | (row_ptr[rowIdx + 1] - row_ptr[rowIdx]) / UNROLL * UNROLL; |
175 | for (; r < r_end_aligned; r += UNROLL) { |
176 | __m512i a_v[UNROLL]; |
177 | int acbr_block[UNROLL]; |
178 | for (int i = 0; i < UNROLL; ++i) { |
179 | acbr_block[i] = col_idx[r + i]; |
180 | int32_t v = reinterpret_cast<const int32_t*>(values)[r + i]; |
181 | a_v[i] = _mm512_set1_epi32(v); |
182 | } |
183 | |
184 | __m512i br_v[UNROLL][COLBLOCKS]; |
185 | for (int i = 0; i < UNROLL; ++i) { |
186 | for (int idx = 0; idx < COLBLOCKS; ++idx) { |
187 | br_v[i][idx] = _mm512_loadu_si512( |
188 | interleave_buffer + (acbr_block[i] * COLBLOCKS + idx) * VLEN_INT8); |
189 | } |
190 | } |
191 | |
192 | for (int i = 0; i < UNROLL; ++i) { |
193 | for (int idx = 0; idx < COLBLOCKS; ++idx) { |
194 | __m512i c_i16_v = _mm512_maddubs_epi16(br_v[i][idx], a_v[i]); |
195 | __m512i c_i32_v = _mm512_madd_epi16(one_16bit_v, c_i16_v); |
196 | c_v[idx] = _mm512_add_epi32(c_v[idx], c_i32_v); |
197 | } |
198 | } |
199 | } |
200 | // remainder loop |
201 | for (; r < row_ptr[rowIdx + 1]; ++r) { |
202 | int acbr_block = col_idx[r]; |
203 | int32_t v = reinterpret_cast<const int32_t*>(values)[r]; |
204 | __m512i a_v = _mm512_set1_epi32(v); |
205 | |
206 | __m512i br_v[COLBLOCKS]; |
207 | for (int idx = 0; idx < COLBLOCKS; ++idx) { |
208 | br_v[idx] = _mm512_loadu_si512( |
209 | interleave_buffer + (acbr_block * COLBLOCKS + idx) * VLEN_INT8); |
210 | } |
211 | |
212 | for (int idx = 0; idx < COLBLOCKS; ++idx) { |
213 | __m512i c_i16_v = _mm512_maddubs_epi16(br_v[idx], a_v); |
214 | __m512i c_i32_v = _mm512_madd_epi16(one_16bit_v, c_i16_v); |
215 | c_v[idx] = _mm512_add_epi32(c_v[idx], c_i32_v); |
216 | } |
217 | } |
218 | } |
219 | |
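// Loads ROWSIZE rows of B (64 bytes each, stride ld) into br_v. With
// MASKLOAD, a masked load handles the remainder columns at the right edge
// of B; rows beyond ROWSIZE are zeroed so interleave_4rows always sees
// 4 valid vectors.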
220 | template <int ROWSIZE = 4, bool MASKLOAD = false> |
221 | static inline void loadBRows( |
222 | __m512i br_v[], |
223 | const uint8_t* B_start, |
224 | int ld, |
225 | __mmask64 mask_int8_v = 0) { |
226 | int idx = 0; |
227 | for (; idx < ROWSIZE; ++idx) { |
228 | if (MASKLOAD) { |
229 | br_v[idx] = _mm512_maskz_loadu_epi8(mask_int8_v, B_start + idx * ld); |
230 | } else { |
231 | br_v[idx] = _mm512_loadu_si512(B_start + idx * ld); |
232 | } |
233 | } |
  // zero out the remaining rows
235 | for (; idx < 4; ++idx) { |
236 | br_v[idx] = _mm512_set1_epi32(0); |
237 | } |
238 | } |
239 | |
240 | // For COLBLOCKS, see description at loopOverReductionDim |
241 | template <int COLBLOCKS = 4> |
242 | static inline void |
243 | storeToInterleaveBuffer(__m512i br_v[], uint8_t* interleave_start, int ld) { |
244 | for (int idx = 0; idx < COLBLOCKS; ++idx) { |
245 | _mm512_storeu_si512(interleave_start + idx * ld, br_v[idx]); |
246 | } |
247 | } |
248 | |
249 | // For COLBLOCKS, see description at loopOverReductionDim |
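// Interleaves a kSize-row tile of B starting at column col_start (up to
// 64 = VLEN_INT8 columns) into interleave_buffer, 4 rows at a time. The
// masked path handles the case where fewer than 64 columns remain; a
// trailing group of 1-3 rows is zero-padded before interleaving.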
250 | template <int COLBLOCKS = 4> |
251 | static inline void interleave4RowsTile( |
252 | int N, |
253 | int kSize, |
254 | const uint8_t* B, |
255 | uint8_t* interleave_buffer, |
256 | int ld, |
257 | const int col_start) { |
258 | constexpr int VLEN_INT8 = 64; |
259 | constexpr int colBlockSize = 4; |
  assert(colBlockSize == 4 && "column block size should be 4");
261 | const int kBlocks = kSize / colBlockSize; |
262 | if (col_start < N / VLEN_INT8 * VLEN_INT8) { |
263 | __m512i br_v[4]; |
264 | int i = 0; |
265 | for (; i < kBlocks; ++i) { |
266 | loadBRows<4, false>(br_v, B + i * colBlockSize * ld + col_start, ld); |
267 | interleave_4rows(br_v); |
268 | storeToInterleaveBuffer<4>( |
269 | br_v, interleave_buffer + i * colBlockSize * VLEN_INT8, VLEN_INT8); |
270 | } |
271 | int rem = kSize - i * colBlockSize; |
272 | if (rem > 0) { |
273 | if (rem == 3) { |
274 | loadBRows<3, false>(br_v, B + i * colBlockSize * ld + col_start, ld); |
275 | } else if (rem == 2) { |
276 | loadBRows<2, false>(br_v, B + i * colBlockSize * ld + col_start, ld); |
277 | } else { |
278 | loadBRows<1, false>(br_v, B + i * colBlockSize * ld + col_start, ld); |
279 | } |
280 | interleave_4rows(br_v); |
281 | storeToInterleaveBuffer<4>( |
282 | br_v, interleave_buffer + i * colBlockSize * VLEN_INT8, VLEN_INT8); |
283 | } |
284 | } else { |
285 | int rem_int8 = N - col_start; |
    // rem_int8 < 64 in this branch, so the shift cannot overflow.
    __mmask64 mask_int8_v = (1ULL << rem_int8) - 1;
287 | __m512i br_v[4]; |
288 | int i = 0; |
289 | for (; i < kBlocks; ++i) { |
290 | loadBRows<4, true>( |
291 | br_v, B + i * colBlockSize * ld + col_start, ld, mask_int8_v); |
292 | interleave_4rows(br_v); |
293 | storeToInterleaveBuffer<COLBLOCKS>( |
294 | br_v, interleave_buffer + i * COLBLOCKS * VLEN_INT8, VLEN_INT8); |
295 | } |
296 | int rem = kSize - i * colBlockSize; |
297 | if (rem > 0) { |
298 | if (rem == 3) { |
299 | loadBRows<3, true>( |
300 | br_v, B + i * colBlockSize * ld + col_start, ld, mask_int8_v); |
301 | } else if (rem == 2) { |
302 | loadBRows<2, true>( |
303 | br_v, B + i * colBlockSize * ld + col_start, ld, mask_int8_v); |
304 | } else { |
305 | loadBRows<1, true>( |
306 | br_v, B + i * colBlockSize * ld + col_start, ld, mask_int8_v); |
307 | } |
308 | interleave_4rows(br_v); |
309 | storeToInterleaveBuffer<COLBLOCKS>( |
310 | br_v, interleave_buffer + i * COLBLOCKS * VLEN_INT8, VLEN_INT8); |
311 | } |
312 | } |
313 | return; |
314 | } |
315 | |
316 | template <bool FUSE_RELU, QuantizationGranularity Q_GRAN> |
317 | void SparseDenseInt8MMAvx512( |
318 | int N, |
319 | const std::unique_ptr<BCSRMatrix<>>& bcsr, |
320 | const uint8_t* B, |
321 | int ldb, |
322 | int32_t* C_i32, |
323 | uint8_t* C_u8, |
324 | int ldc, |
325 | trRequantizationParams_t& rParams, |
326 | bool accum, |
327 | int thread_id, |
328 | int num_threads) { |
  // Use the dedicated GEMV kernel when B is a single column.
330 | if (N == 1 && ldb == 1 && ldc == 1 && bcsr->C % 4 == 0) { |
331 | return SparseDenseInt8MVAvx512<FUSE_RELU, Q_GRAN>( |
332 | bcsr, B, ldb, C_i32, C_u8, rParams, accum, thread_id, num_threads); |
333 | } |
334 | |
  // Calculates accum ? C += A * B : C = A * B
336 | constexpr int VLEN_INT8 = 64; |
337 | constexpr int VLEN_INT32 = 16; |
338 | |
339 | constexpr int colTileSize = BCSRMatrix<>::COLTILE; |
340 | // Number of columns in the sparse matrix A |
341 | int K = bcsr->C; |
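  // Number of rows in the sparse matrix A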
342 | int M = bcsr->R; |
  assert((K > 0) && "K needs to be positive");
344 | int kTiles = (K + colTileSize - 1) / colTileSize; |
345 | const int* row_ptr = bcsr->rowBPtr.data(); |
346 | const int* col_idx = bcsr->colBIdx.data(); |
347 | const int8_t* values = bcsr->values.data(); |
348 | |
349 | constexpr int buffer_size = BCSRMatrix<>::COLTILE * VLEN_INT8; |
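  // Per-thread scratch buffer holding the interleaved B tile; allocated
  // lazily on first use and reused across calls.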
350 | static thread_local uint8_t* interleave_buffer_ = nullptr; |
351 | |
352 | if (interleave_buffer_ == nullptr) { |
353 | interleave_buffer_ = |
354 | static_cast<uint8_t*>(fbgemmAlignedAlloc(64, buffer_size)); |
355 | } |
356 | |
  assert(
      (interleave_buffer_ != nullptr) &&
      "interleave_buffer_ cannot be nullptr");
360 | |
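  // All-ones int16 vector used by _mm512_madd_epi16 to widen and accumulate
  // adjacent int16 products into int32.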
361 | __m512i one_16bit_v = _mm512_set1_epi16(1); |
362 | int j = 0; |
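  // Main loop over full 64-column (VLEN_INT8) tiles of the output.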
363 | for (; j < N / VLEN_INT8 * VLEN_INT8; j += VLEN_INT8) { |
364 | for (int kt = 0; kt < kTiles; ++kt) { |
365 | int curKSize = std::min(K - kt * colTileSize, colTileSize); |
366 | interleave4RowsTile<4 /*COLBLOCKS*/>( |
367 | N, curKSize, B + kt * colTileSize * ldb, interleave_buffer_, ldb, j); |
368 | for (int i = 0; i < M; ++i) { |
369 | __m512i c_v[4]; |
370 | if (accum || kt > 0) { |
371 | for (int idx = 0; idx < 4; ++idx) { |
372 | c_v[idx] = _mm512_loadu_si512(C_i32 + i * ldb + idx * VLEN_INT32); |
373 | } |
374 | } else { |
375 | for (int idx = 0; idx < 4; ++idx) { |
376 | c_v[idx] = _mm512_set1_epi32(0); |
377 | } |
378 | } |
379 | |
380 | loopOverReductionDim<2 /*UNROLL*/, 4 /*COLBLOCKS*/>( |
381 | row_ptr + kt * M, |
382 | i, |
383 | col_idx, |
384 | values, |
385 | interleave_buffer_, |
386 | one_16bit_v, |
387 | c_v); |
388 | |
389 | if (kt == kTiles - 1) { |
390 | // Requantize after last ktile |
391 | __m512i res; |
392 | if (rParams.bias == nullptr) { |
393 | if (rParams.act_zero_point) { |
394 | res = requantizeForMM<FUSE_RELU, false, false, Q_GRAN>( |
395 | c_v, i, rParams); |
396 | } else { |
397 | res = requantizeForMM<FUSE_RELU, true, false, Q_GRAN>( |
398 | c_v, i, rParams); |
399 | } |
400 | } else { |
401 | if (rParams.act_zero_point) { |
402 | res = requantizeForMM<FUSE_RELU, false, true, Q_GRAN>( |
403 | c_v, i, rParams); |
404 | } else { |
405 | res = requantizeForMM<FUSE_RELU, true, true, Q_GRAN>( |
406 | c_v, i, rParams); |
407 | } |
408 | } |
409 | _mm512_storeu_si512(C_u8 + i * ldc + j, res); |
410 | } else { |
411 | // store the results |
412 | for (int idx = 0; idx < 4; ++idx) { |
413 | _mm512_storeu_si512(C_i32 + i * ldb + idx * VLEN_INT32, c_v[idx]); |
414 | } |
415 | } |
416 | } |
417 | } |
418 | } |
  // Handle the remainder of the j loop (N not a multiple of VLEN_INT8).
420 | int rem_int8 = N - j; |
421 | int rem_int32 = N % VLEN_INT32; |
422 | int colBlocks = (rem_int8 + VLEN_INT32 - 1) / VLEN_INT32; |
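  // After 4-row interleaving, each 64-byte column block covers 16 output
  // columns, so colBlocks = ceil(rem_int8 / VLEN_INT32).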
423 | if (rem_int8 > 0) { |
424 | for (int kt = 0; kt < kTiles; ++kt) { |
      // The last k tile may have fewer than colTileSize columns of the A
      // matrix (i.e., rows of B).
427 | int curKSize = std::min(K - kt * colTileSize, colTileSize); |
428 | switch (colBlocks) { |
429 | case 1: |
430 | interleave4RowsTile<1>( |
431 | N, |
432 | curKSize, |
433 | B + kt * colTileSize * ldb, |
434 | interleave_buffer_, |
435 | ldb, |
436 | j); |
437 | break; |
438 | case 2: |
439 | interleave4RowsTile<2>( |
440 | N, |
441 | curKSize, |
442 | B + kt * colTileSize * ldb, |
443 | interleave_buffer_, |
444 | ldb, |
445 | j); |
446 | break; |
447 | case 3: |
448 | interleave4RowsTile<3>( |
449 | N, |
450 | curKSize, |
451 | B + kt * colTileSize * ldb, |
452 | interleave_buffer_, |
453 | ldb, |
454 | j); |
455 | break; |
456 | case 4: |
457 | interleave4RowsTile<4>( |
458 | N, |
459 | curKSize, |
460 | B + kt * colTileSize * ldb, |
461 | interleave_buffer_, |
462 | ldb, |
463 | j); |
464 | break; |
465 | default: |
466 | // not reachable |
467 | break; |
468 | } |
469 | |
      __mmask16 mask_int32_v = (1ULL << rem_int32) - 1;
      __mmask64 mask_int8_v = (1ULL << rem_int8) - 1;
472 | for (int i = 0; i < M; ++i) { |
473 | __m512i c_v[4] = {}; |
474 | if (accum || kt > 0) { |
475 | int idx = 0; |
476 | for (; idx < rem_int8 / VLEN_INT32; ++idx) { |
477 | c_v[idx] = _mm512_loadu_si512(C_i32 + i * ldb + idx * VLEN_INT32); |
478 | } |
479 | c_v[idx] = _mm512_maskz_loadu_epi32( |
480 | mask_int32_v, C_i32 + i * ldb + idx * VLEN_INT32); |
481 | } |
482 | |
483 | switch (colBlocks) { |
484 | case 1: |
485 | loopOverReductionDim<3 /*UNROLL*/, 1 /*colBlocks*/>( |
486 | row_ptr + M * kt, |
487 | i, |
488 | col_idx, |
489 | values, |
490 | interleave_buffer_, |
491 | one_16bit_v, |
492 | c_v); |
493 | break; |
494 | case 2: |
495 | loopOverReductionDim<3 /*UNROLL*/, 2 /*colBlocks*/>( |
496 | row_ptr + M * kt, |
497 | i, |
498 | col_idx, |
499 | values, |
500 | interleave_buffer_, |
501 | one_16bit_v, |
502 | c_v); |
503 | break; |
504 | case 3: |
505 | loopOverReductionDim<2 /*UNROLL*/, 3 /*colBlocks*/>( |
506 | row_ptr + M * kt, |
507 | i, |
508 | col_idx, |
509 | values, |
510 | interleave_buffer_, |
511 | one_16bit_v, |
512 | c_v); |
513 | break; |
514 | case 4: |
515 | loopOverReductionDim<2 /*UNROLL*/, 4 /*colBlocks*/>( |
516 | row_ptr + M * kt, |
517 | i, |
518 | col_idx, |
519 | values, |
520 | interleave_buffer_, |
521 | one_16bit_v, |
522 | c_v); |
523 | break; |
524 | default: |
525 | // not reachable |
526 | break; |
527 | } |
528 | |
529 | if (kt == kTiles - 1) { |
530 | // Requantize after last ktile |
531 | __m512i res; |
532 | if (rParams.bias == nullptr) { |
533 | if (rParams.act_zero_point) { |
534 | res = requantizeForMM<FUSE_RELU, false, false, Q_GRAN>( |
535 | c_v, i, rParams); |
536 | } else { |
537 | res = requantizeForMM<FUSE_RELU, true, false, Q_GRAN>( |
538 | c_v, i, rParams); |
539 | } |
540 | } else { |
541 | if (rParams.act_zero_point) { |
542 | res = requantizeForMM<FUSE_RELU, false, true, Q_GRAN>( |
543 | c_v, i, rParams); |
544 | } else { |
545 | res = requantizeForMM<FUSE_RELU, true, true, Q_GRAN>( |
546 | c_v, i, rParams); |
547 | } |
548 | } |
549 | _mm512_mask_storeu_epi8(C_u8 + i * ldc + j, mask_int8_v, res); |
550 | } else { |
551 | int idx = 0; |
552 | for (; idx < rem_int8 / VLEN_INT32; ++idx) { |
553 | _mm512_storeu_si512(C_i32 + i * ldb + idx * VLEN_INT32, c_v[idx]); |
554 | } |
555 | _mm512_mask_storeu_epi32( |
556 | C_i32 + i * ldb + idx * VLEN_INT32, mask_int32_v, c_v[idx]); |
557 | } |
558 | } |
559 | } |
560 | } |
561 | } |
562 | |
563 | #define CREATE_INSTANCE(FUSE_RELU, QGRAN) \ |
564 | template void SparseDenseInt8MMAvx512<FUSE_RELU, QGRAN>( \ |
565 | int N, \ |
566 | const std::unique_ptr<BCSRMatrix<>>& bcsr, \ |
567 | const uint8_t* B, \ |
568 | int ldb, \ |
569 | int32_t* C_i32, \ |
570 | uint8_t* C_u8, \ |
571 | int ldc, \ |
572 | trRequantizationParams_t& rParams, \ |
573 | bool accum, \ |
574 | int thread_id, \ |
575 | int num_threads); |
576 | CREATE_INSTANCE(true, QuantizationGranularity::TENSOR) |
577 | CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL) |
578 | CREATE_INSTANCE(false, QuantizationGranularity::TENSOR) |
579 | CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL) |
580 | #undef CREATE_INSTANCE |
581 | |
582 | } // namespace internal |
583 | } // namespace fbgemm |
584 | |