/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmSparse.h"

#include "fbgemm/Utils.h"
#include "fbgemm/spmmUtilsAvx2.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <algorithm> // for min and max
#include <cassert>

namespace fbgemm {

namespace internal {

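// Requantizes 4 vectors of 16 int32 accumulators (one 64-wide chunk of an
// output row) down to 64 uint8 values: subtract the activation-zero-point
// row offset, add the bias (pre-divided by act_scale * weight_scale so it can
// be applied before rescaling), scale by act_scale * weight_scale / C_scale,
// round, add the output zero point, saturate to uint8 (with an optional fused
// ReLU), and undo the lane interleaving introduced by the 32->16->8 bit
// packing instructions.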
template <
    bool FUSE_RELU,
    bool ACT_ZP_0, // is activation zero point 0?
    bool HAS_BIAS,
    QuantizationGranularity Q_GRAN>
static inline __m512i
requantizeForMM(__m512i x[], int rowIdx, trRequantizationParams_t& rParams) {
  __m512i C_zero_point_epi8_v = _mm512_set1_epi8(rParams.C_zero_point);
  __m512i C_zero_point_epi16_v = _mm512_set1_epi16(rParams.C_zero_point);
  // clang-format off
  __m512i permute_mask_v = _mm512_set_epi32(
      0x0F, 0x0B, 0x07, 0x03,
      0x0E, 0x0A, 0x06, 0x02,
      0x0D, 0x09, 0x05, 0x01,
      0x0C, 0x08, 0x04, 0x00);
  // clang-format on
  int32_t row_offset = 0;
  if (!ACT_ZP_0) {
    row_offset = rParams.act_zero_point * rParams.weight_row_offsets[rowIdx];
  }
  __m512i row_offset_v = _mm512_set1_epi32(row_offset);

  int weight_zeropoint_idx = 0;
  if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
    weight_zeropoint_idx = rowIdx;
  }

  __m512 bias_v;
  if (HAS_BIAS) {
    float bias =
        rParams.bias[rowIdx] / rParams.act_times_w_scale[weight_zeropoint_idx];
    bias_v = _mm512_set1_ps(bias);
  }

  __m512 act_times_w_div_c_v;
  if (Q_GRAN == QuantizationGranularity::TENSOR) {
    act_times_w_div_c_v =
        _mm512_set1_ps(rParams.act_times_w_scale[0] / rParams.C_scale);
  } else {
    act_times_w_div_c_v = _mm512_set1_ps(
        rParams.act_times_w_scale[weight_zeropoint_idx] / rParams.C_scale);
  }
  if (!ACT_ZP_0) {
    x[0] = _mm512_sub_epi32(x[0], row_offset_v);
    x[1] = _mm512_sub_epi32(x[1], row_offset_v);
    x[2] = _mm512_sub_epi32(x[2], row_offset_v);
    x[3] = _mm512_sub_epi32(x[3], row_offset_v);
  }

  __m512 xf_v, yf_v, zf_v, wf_v;
  if (HAS_BIAS) {
    xf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x[0]), bias_v);
    yf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x[1]), bias_v);
    zf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x[2]), bias_v);
    wf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x[3]), bias_v);
  } else {
    xf_v = _mm512_cvtepi32_ps(x[0]);
    yf_v = _mm512_cvtepi32_ps(x[1]);
    zf_v = _mm512_cvtepi32_ps(x[2]);
    wf_v = _mm512_cvtepi32_ps(x[3]);
  }

  __m512 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;

  x_scaled_v = _mm512_mul_ps(xf_v, act_times_w_div_c_v);
  y_scaled_v = _mm512_mul_ps(yf_v, act_times_w_div_c_v);
  z_scaled_v = _mm512_mul_ps(zf_v, act_times_w_div_c_v);
  w_scaled_v = _mm512_mul_ps(wf_v, act_times_w_div_c_v);

  __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v);
  __m512i y_rounded_v = _mm512_cvtps_epi32(y_scaled_v);
  __m512i z_rounded_v = _mm512_cvtps_epi32(z_scaled_v);
  __m512i w_rounded_v = _mm512_cvtps_epi32(w_scaled_v);

  __m512i xy_packed_v = _mm512_adds_epi16(
      _mm512_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
  __m512i zw_packed_v = _mm512_adds_epi16(
      _mm512_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
  // _mm512_packus_epi16 takes care of saturating to uint8 range
  __m512i xyzw_clamped_v = _mm512_packus_epi16(xy_packed_v, zw_packed_v);
  if (FUSE_RELU) {
    xyzw_clamped_v = _mm512_max_epu8(C_zero_point_epi8_v, xyzw_clamped_v);
  }

  xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v);
  return xyzw_clamped_v;
}

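// Performs a 4x4 byte transpose within each 128-bit lane: after the int32
// transpose in interleave_4rows, each lane holds one int32 from each of the
// four interleaved B rows; this shuffle regroups them so every consecutive
// 4-byte group contains one byte from each of the four rows, which is the
// layout consumed by vpmaddubsw against a broadcast 1x4 block of A.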
static inline __m512i permute_row(__m512i row) {
  // clang-format off
  __m256i shuffle_256v = _mm256_set_epi8(
      15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
      15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
  // clang-format on
  __m512i shuffle_512v = _mm512_castsi256_si512(shuffle_256v);
  row = _mm512_shuffle_epi8(
      row, _mm512_inserti64x4(shuffle_512v, shuffle_256v, 1));
  return row;
}

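// Interleaves 4 rows of 64 uint8 values (data[0..3]) in place so that the
// output holds, for every column n, the 4 bytes data[0][n], data[1][n],
// data[2][n], data[3][n] contiguously. Implemented as a 4x16 int32 transpose
// followed by the per-lane byte permutation in permute_row.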
static inline void interleave_4rows(__m512i data[]) {
  __m512i __t0 = _mm512_unpacklo_epi32(data[0], data[1]);
  __m512i __t1 = _mm512_unpackhi_epi32(data[0], data[1]);
  __m512i __t2 = _mm512_unpacklo_epi32(data[2], data[3]);
  __m512i __t3 = _mm512_unpackhi_epi32(data[2], data[3]);
  __m512i __tt0 = _mm512_unpacklo_epi64(__t0, __t2);
  __m512i __tt1 = _mm512_unpacklo_epi64(__t1, __t3);
  __m512i __tt2 = _mm512_unpackhi_epi64(__t0, __t2);
  __m512i __tt3 = _mm512_unpackhi_epi64(__t1, __t3);
  __m512i row0 = _mm512_permutex2var_epi64(
      __tt0, _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0), __tt2);
  __m512i row1 = _mm512_permutex2var_epi64(
      __tt1, _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0), __tt3);
  __m512i row2 = _mm512_permutex2var_epi64(
      __tt0, _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4), __tt2);
  __m512i row3 = _mm512_permutex2var_epi64(
      __tt1, _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4), __tt3);
  __m512i row0i = _mm512_shuffle_i64x2(row0, row1, 0x44);
  __m512i row1i = _mm512_shuffle_i64x2(row0, row1, 0xEE);
  __m512i row2i = _mm512_shuffle_i64x2(row2, row3, 0x44);
  __m512i row3i = _mm512_shuffle_i64x2(row2, row3, 0xEE);
  // End of int32 transpose
  // Now we only need a simple row permutation to get the right result
  data[0] = permute_row(row0i);
  data[1] = permute_row(row1i);
  data[2] = permute_row(row2i);
  data[3] = permute_row(row3i);
  return;
}

// By default, we process 4 column blocks (default value of COLBLOCKS).
// Each block is of size 64 (VLEN_INT8) uint8 values.
// We can change the number of blocks we process by
// using a different value for COLBLOCKS.
// For example, for cases with 16 < N <= 32 where 4 rows are interleaved,
// we only need to process the first 32 columns, which means
// COLBLOCKS = 32 * 4 (row_interleave) / 64 (VLEN_INT8) = 2.
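// Walks the nonzero 1x4 blocks of row `rowIdx` of the sparse A matrix
// (row_ptr/col_idx/values in BCSR form): each 4-byte group of A values is
// broadcast and multiplied against the matching 64-byte blocks of the
// row-interleaved B buffer with vpmaddubsw, widened to int32 with vpmaddwd
// against ones, and accumulated into c_v. The main loop is unrolled by
// UNROLL, with a scalar remainder loop at the end.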
template <int UNROLL = 3, int COLBLOCKS = 4>
static inline void loopOverReductionDim(
    const int* row_ptr,
    int rowIdx,
    const int* col_idx,
    const int8_t* values,
    const uint8_t* interleave_buffer,
    __m512i one_16bit_v,
    __m512i c_v[]) {
  constexpr int VLEN_INT8 = 64;

  int r = row_ptr[rowIdx];
  int r_end_aligned = row_ptr[rowIdx] +
      (row_ptr[rowIdx + 1] - row_ptr[rowIdx]) / UNROLL * UNROLL;
  for (; r < r_end_aligned; r += UNROLL) {
    __m512i a_v[UNROLL];
    int acbr_block[UNROLL];
    for (int i = 0; i < UNROLL; ++i) {
      acbr_block[i] = col_idx[r + i];
      int32_t v = reinterpret_cast<const int32_t*>(values)[r + i];
      a_v[i] = _mm512_set1_epi32(v);
    }

    __m512i br_v[UNROLL][COLBLOCKS];
    for (int i = 0; i < UNROLL; ++i) {
      for (int idx = 0; idx < COLBLOCKS; ++idx) {
        br_v[i][idx] = _mm512_loadu_si512(
            interleave_buffer + (acbr_block[i] * COLBLOCKS + idx) * VLEN_INT8);
      }
    }

    for (int i = 0; i < UNROLL; ++i) {
      for (int idx = 0; idx < COLBLOCKS; ++idx) {
        __m512i c_i16_v = _mm512_maddubs_epi16(br_v[i][idx], a_v[i]);
        __m512i c_i32_v = _mm512_madd_epi16(one_16bit_v, c_i16_v);
        c_v[idx] = _mm512_add_epi32(c_v[idx], c_i32_v);
      }
    }
  }
  // remainder loop
  for (; r < row_ptr[rowIdx + 1]; ++r) {
    int acbr_block = col_idx[r];
    int32_t v = reinterpret_cast<const int32_t*>(values)[r];
    __m512i a_v = _mm512_set1_epi32(v);

    __m512i br_v[COLBLOCKS];
    for (int idx = 0; idx < COLBLOCKS; ++idx) {
      br_v[idx] = _mm512_loadu_si512(
          interleave_buffer + (acbr_block * COLBLOCKS + idx) * VLEN_INT8);
    }

    for (int idx = 0; idx < COLBLOCKS; ++idx) {
      __m512i c_i16_v = _mm512_maddubs_epi16(br_v[idx], a_v);
      __m512i c_i32_v = _mm512_madd_epi16(one_16bit_v, c_i16_v);
      c_v[idx] = _mm512_add_epi32(c_v[idx], c_i32_v);
    }
  }
}

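// Loads up to ROWSIZE (<= 4) rows of 64 uint8 values of B into br_v, using a
// masked load for a column remainder when MASKLOAD is true, and zero-fills
// any of the 4 vectors that have no corresponding row (partial K blocks).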
template <int ROWSIZE = 4, bool MASKLOAD = false>
static inline void loadBRows(
    __m512i br_v[],
    const uint8_t* B_start,
    int ld,
    __mmask64 mask_int8_v = 0) {
  int idx = 0;
  for (; idx < ROWSIZE; ++idx) {
    if (MASKLOAD) {
      br_v[idx] = _mm512_maskz_loadu_epi8(mask_int8_v, B_start + idx * ld);
    } else {
      br_v[idx] = _mm512_loadu_si512(B_start + idx * ld);
    }
  }
  // set the rest to 0
  for (; idx < 4; ++idx) {
    br_v[idx] = _mm512_set1_epi32(0);
  }
}

// For COLBLOCKS, see description at loopOverReductionDim
template <int COLBLOCKS = 4>
static inline void
storeToInterleaveBuffer(__m512i br_v[], uint8_t* interleave_start, int ld) {
  for (int idx = 0; idx < COLBLOCKS; ++idx) {
    _mm512_storeu_si512(interleave_start + idx * ld, br_v[idx]);
  }
}

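// Packs one K-tile of the dense B matrix (kSize rows, starting at column
// col_start) into interleave_buffer: 4 consecutive K rows at a time are
// loaded (with a masked load when fewer than 64 columns remain), interleaved
// with interleave_4rows, and stored as COLBLOCKS 64-byte blocks per 4-row
// group.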
// For COLBLOCKS, see description at loopOverReductionDim
template <int COLBLOCKS = 4>
static inline void interleave4RowsTile(
    int N,
    int kSize,
    const uint8_t* B,
    uint8_t* interleave_buffer,
    int ld,
    const int col_start) {
  constexpr int VLEN_INT8 = 64;
  constexpr int colBlockSize = 4;
  assert(colBlockSize == 4 && "column block size should be 4");
  const int kBlocks = kSize / colBlockSize;
  if (col_start < N / VLEN_INT8 * VLEN_INT8) {
    __m512i br_v[4];
    int i = 0;
    for (; i < kBlocks; ++i) {
      loadBRows<4, false>(br_v, B + i * colBlockSize * ld + col_start, ld);
      interleave_4rows(br_v);
      storeToInterleaveBuffer<4>(
          br_v, interleave_buffer + i * colBlockSize * VLEN_INT8, VLEN_INT8);
    }
    int rem = kSize - i * colBlockSize;
    if (rem > 0) {
      if (rem == 3) {
        loadBRows<3, false>(br_v, B + i * colBlockSize * ld + col_start, ld);
      } else if (rem == 2) {
        loadBRows<2, false>(br_v, B + i * colBlockSize * ld + col_start, ld);
      } else {
        loadBRows<1, false>(br_v, B + i * colBlockSize * ld + col_start, ld);
      }
      interleave_4rows(br_v);
      storeToInterleaveBuffer<4>(
          br_v, interleave_buffer + i * colBlockSize * VLEN_INT8, VLEN_INT8);
    }
  } else {
    int rem_int8 = N - col_start;
    __mmask64 mask_int8_v = (((long long)1) << rem_int8) - 1;
    __m512i br_v[4];
    int i = 0;
    for (; i < kBlocks; ++i) {
      loadBRows<4, true>(
          br_v, B + i * colBlockSize * ld + col_start, ld, mask_int8_v);
      interleave_4rows(br_v);
      storeToInterleaveBuffer<COLBLOCKS>(
          br_v, interleave_buffer + i * COLBLOCKS * VLEN_INT8, VLEN_INT8);
    }
    int rem = kSize - i * colBlockSize;
    if (rem > 0) {
      if (rem == 3) {
        loadBRows<3, true>(
            br_v, B + i * colBlockSize * ld + col_start, ld, mask_int8_v);
      } else if (rem == 2) {
        loadBRows<2, true>(
            br_v, B + i * colBlockSize * ld + col_start, ld, mask_int8_v);
      } else {
        loadBRows<1, true>(
            br_v, B + i * colBlockSize * ld + col_start, ld, mask_int8_v);
      }
      interleave_4rows(br_v);
      storeToInterleaveBuffer<COLBLOCKS>(
          br_v, interleave_buffer + i * COLBLOCKS * VLEN_INT8, VLEN_INT8);
    }
  }
  return;
}

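// Computes C = A * B (or C += A * B when accum is true), where A is the
// sparse int8 matrix stored in `bcsr` (1x4 blocks) and B is a dense uint8
// matrix. B is processed in 64-column stripes; for each stripe, every K-tile
// of B (BCSRMatrix<>::COLTILE rows) is interleaved into a thread-local
// buffer, the sparse rows of A are multiplied against it, partial sums are
// kept in C_i32 across K-tiles, and the last K-tile is requantized into C_u8.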
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
void SparseDenseInt8MMAvx512(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int thread_id,
    int num_threads) {
  // gemv
  if (N == 1 && ldb == 1 && ldc == 1 && bcsr->C % 4 == 0) {
    return SparseDenseInt8MVAvx512<FUSE_RELU, Q_GRAN>(
        bcsr, B, ldb, C_i32, C_u8, rParams, accum, thread_id, num_threads);
  }

  // Calculates accum ? C += A * B : C = A * B
  constexpr int VLEN_INT8 = 64;
  constexpr int VLEN_INT32 = 16;

  constexpr int colTileSize = BCSRMatrix<>::COLTILE;
  // Number of columns in the sparse matrix A
  int K = bcsr->C;
  int M = bcsr->R;
  assert((K > 0) && "K needs to be positive");
  int kTiles = (K + colTileSize - 1) / colTileSize;
  const int* row_ptr = bcsr->rowBPtr.data();
  const int* col_idx = bcsr->colBIdx.data();
  const int8_t* values = bcsr->values.data();

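  // Thread-local scratch holding one interleaved K-tile of B
  // (COLTILE rows x 64 uint8 columns); allocated lazily once per thread and
  // reused across calls.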
  constexpr int buffer_size = BCSRMatrix<>::COLTILE * VLEN_INT8;
  static thread_local uint8_t* interleave_buffer_ = nullptr;

  if (interleave_buffer_ == nullptr) {
    interleave_buffer_ =
        static_cast<uint8_t*>(fbgemmAlignedAlloc(64, buffer_size));
  }

  assert(
      (interleave_buffer_ != nullptr) &&
      "interleave_buffer_ cannot be nullptr");

  __m512i one_16bit_v = _mm512_set1_epi16(1);
  int j = 0;
  for (; j < N / VLEN_INT8 * VLEN_INT8; j += VLEN_INT8) {
    for (int kt = 0; kt < kTiles; ++kt) {
      int curKSize = std::min(K - kt * colTileSize, colTileSize);
      interleave4RowsTile<4 /*COLBLOCKS*/>(
          N, curKSize, B + kt * colTileSize * ldb, interleave_buffer_, ldb, j);
      for (int i = 0; i < M; ++i) {
        __m512i c_v[4];
        if (accum || kt > 0) {
          for (int idx = 0; idx < 4; ++idx) {
            c_v[idx] = _mm512_loadu_si512(C_i32 + i * ldb + idx * VLEN_INT32);
          }
        } else {
          for (int idx = 0; idx < 4; ++idx) {
            c_v[idx] = _mm512_set1_epi32(0);
          }
        }

        loopOverReductionDim<2 /*UNROLL*/, 4 /*COLBLOCKS*/>(
            row_ptr + kt * M,
            i,
            col_idx,
            values,
            interleave_buffer_,
            one_16bit_v,
            c_v);

        if (kt == kTiles - 1) {
          // Requantize after last ktile
          __m512i res;
          if (rParams.bias == nullptr) {
            if (rParams.act_zero_point) {
              res = requantizeForMM<FUSE_RELU, false, false, Q_GRAN>(
                  c_v, i, rParams);
            } else {
              res = requantizeForMM<FUSE_RELU, true, false, Q_GRAN>(
                  c_v, i, rParams);
            }
          } else {
            if (rParams.act_zero_point) {
              res = requantizeForMM<FUSE_RELU, false, true, Q_GRAN>(
                  c_v, i, rParams);
            } else {
              res = requantizeForMM<FUSE_RELU, true, true, Q_GRAN>(
                  c_v, i, rParams);
            }
          }
          _mm512_storeu_si512(C_u8 + i * ldc + j, res);
        } else {
          // store the results
          for (int idx = 0; idx < 4; ++idx) {
            _mm512_storeu_si512(C_i32 + i * ldb + idx * VLEN_INT32, c_v[idx]);
          }
        }
      }
    }
  }
  // Handle remainder j loop
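  // rem_int8 is the number of output columns left (< 64), rem_int32 the
  // columns beyond the last full 16-wide int32 vector, and colBlocks the
  // number of 64-byte interleaved blocks needed to cover the remainder
  // (each block covers 16 output columns x 4 interleaved rows).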
  int rem_int8 = N - j;
  int rem_int32 = N % VLEN_INT32;
  int colBlocks = (rem_int8 + VLEN_INT32 - 1) / VLEN_INT32;
  if (rem_int8 > 0) {
    for (int kt = 0; kt < kTiles; ++kt) {
      // last k tile may have less than colTileSize columns of A matrix (aka
      // rows of B)
      int curKSize = std::min(K - kt * colTileSize, colTileSize);
      switch (colBlocks) {
        case 1:
          interleave4RowsTile<1>(
              N,
              curKSize,
              B + kt * colTileSize * ldb,
              interleave_buffer_,
              ldb,
              j);
          break;
        case 2:
          interleave4RowsTile<2>(
              N,
              curKSize,
              B + kt * colTileSize * ldb,
              interleave_buffer_,
              ldb,
              j);
          break;
        case 3:
          interleave4RowsTile<3>(
              N,
              curKSize,
              B + kt * colTileSize * ldb,
              interleave_buffer_,
              ldb,
              j);
          break;
        case 4:
          interleave4RowsTile<4>(
              N,
              curKSize,
              B + kt * colTileSize * ldb,
              interleave_buffer_,
              ldb,
              j);
          break;
        default:
          // not reachable
          break;
      }

      __mmask16 mask_int32_v = (((long long)1) << rem_int32) - 1;
      __mmask64 mask_int8_v = (((long long)1) << rem_int8) - 1;
      for (int i = 0; i < M; ++i) {
        __m512i c_v[4] = {};
        if (accum || kt > 0) {
          int idx = 0;
          for (; idx < rem_int8 / VLEN_INT32; ++idx) {
            c_v[idx] = _mm512_loadu_si512(C_i32 + i * ldb + idx * VLEN_INT32);
          }
          c_v[idx] = _mm512_maskz_loadu_epi32(
              mask_int32_v, C_i32 + i * ldb + idx * VLEN_INT32);
        }

        switch (colBlocks) {
          case 1:
            loopOverReductionDim<3 /*UNROLL*/, 1 /*colBlocks*/>(
                row_ptr + M * kt,
                i,
                col_idx,
                values,
                interleave_buffer_,
                one_16bit_v,
                c_v);
            break;
          case 2:
            loopOverReductionDim<3 /*UNROLL*/, 2 /*colBlocks*/>(
                row_ptr + M * kt,
                i,
                col_idx,
                values,
                interleave_buffer_,
                one_16bit_v,
                c_v);
            break;
          case 3:
            loopOverReductionDim<2 /*UNROLL*/, 3 /*colBlocks*/>(
                row_ptr + M * kt,
                i,
                col_idx,
                values,
                interleave_buffer_,
                one_16bit_v,
                c_v);
            break;
          case 4:
            loopOverReductionDim<2 /*UNROLL*/, 4 /*colBlocks*/>(
                row_ptr + M * kt,
                i,
                col_idx,
                values,
                interleave_buffer_,
                one_16bit_v,
                c_v);
            break;
          default:
            // not reachable
            break;
        }

        if (kt == kTiles - 1) {
          // Requantize after last ktile
          __m512i res;
          if (rParams.bias == nullptr) {
            if (rParams.act_zero_point) {
              res = requantizeForMM<FUSE_RELU, false, false, Q_GRAN>(
                  c_v, i, rParams);
            } else {
              res = requantizeForMM<FUSE_RELU, true, false, Q_GRAN>(
                  c_v, i, rParams);
            }
          } else {
            if (rParams.act_zero_point) {
              res = requantizeForMM<FUSE_RELU, false, true, Q_GRAN>(
                  c_v, i, rParams);
            } else {
              res = requantizeForMM<FUSE_RELU, true, true, Q_GRAN>(
                  c_v, i, rParams);
            }
          }
          _mm512_mask_storeu_epi8(C_u8 + i * ldc + j, mask_int8_v, res);
        } else {
          int idx = 0;
          for (; idx < rem_int8 / VLEN_INT32; ++idx) {
            _mm512_storeu_si512(C_i32 + i * ldb + idx * VLEN_INT32, c_v[idx]);
          }
          _mm512_mask_storeu_epi32(
              C_i32 + i * ldb + idx * VLEN_INT32, mask_int32_v, c_v[idx]);
        }
      }
    }
  }
}

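// Explicit instantiations for every supported combination of FUSE_RELU and
// quantization granularity.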
#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                   \
  template void SparseDenseInt8MMAvx512<FUSE_RELU, QGRAN>(  \
      int N,                                                \
      const std::unique_ptr<BCSRMatrix<>>& bcsr,            \
      const uint8_t* B,                                     \
      int ldb,                                              \
      int32_t* C_i32,                                       \
      uint8_t* C_u8,                                        \
      int ldc,                                              \
      trRequantizationParams_t& rParams,                    \
      bool accum,                                           \
      int thread_id,                                        \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE

} // namespace internal
} // namespace fbgemm