1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#include "./ExecuteKernelU8S8.h"
8#include <cpuinfo.h>
9#include <chrono>
10
11#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
12double kernel_time = 0.0;
13double postprocessing_time = 0.0;
14#endif
15
16namespace fbgemm {
17
18template <typename packingAMatrix, typename cT, typename processOutputType>
19ExecuteKernel<
20 packingAMatrix,
21 PackBMatrix<int8_t, typename packingAMatrix::accType>,
22 cT,
23 processOutputType>::
24 ExecuteKernel(
25 PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
26 packA,
27 PackMatrix<
28 PackBMatrix<int8_t, typename packingAMatrix::accType>,
29 int8_t,
30 typename packingAMatrix::accType>& packB,
31 cT* matC,
32 int32_t* C_buffer,
33 int32_t ldc,
34 const processOutputType& outputProcess,
35 thread_type_t th_info,
36 const BlockingFactors* params)
37 : CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>(
38 params),
39 packedA_(packA),
40 packedB_(packB),
41 matC_(matC),
42 C_buffer_(C_buffer),
43 ldc_(ldc),
44 outputProcess_(outputProcess),
45 th_info_(th_info) {
46 if (!cpuinfo_initialize()) {
47 throw std::runtime_error("Failed to initialize cpuinfo!");
48 }
49 if (params) {
50 if (fbgemmHasAvx2Support()) {
51 mbSize_ = params->MCB;
52 nbSize_ = params->NCB;
53 nrMinSize_ = params->NR_MIN;
54 nrSize_ = params->NR;
55 } else {
56 // TODO: Have default slower path
57 assert(0 && "unsupported architecure");
58 throw std::runtime_error("unsupported architecure");
59 }
60 } else {
61 const inst_set_t isa = fbgemmInstructionSet();
62 switch (isa) {
63 case inst_set_t::avx512_vnni:
64 std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
65 typename packingAMatrix::inpType,
66 typename packingAMatrix::accType,
67 inst_set_t::avx512_vnni>::getKernelParams();
68 break;
69
70 case inst_set_t::avx512_vnni_ymm:
71 std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
72 typename packingAMatrix::inpType,
73 typename packingAMatrix::accType,
74 inst_set_t::avx512_vnni_ymm>::getKernelParams();
75 break;
76
77 case inst_set_t::avx512:
78 std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
79 typename packingAMatrix::inpType,
80 typename packingAMatrix::accType,
81 inst_set_t::avx512>::getKernelParams();
82 break;
83
84 case inst_set_t::avx512_ymm:
85 std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
86 typename packingAMatrix::inpType,
87 typename packingAMatrix::accType,
88 inst_set_t::avx512_ymm>::getKernelParams();
89 break;
90
91 case inst_set_t::avx2:
92 std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
93 typename packingAMatrix::inpType,
94 typename packingAMatrix::accType,
95 inst_set_t::avx2>::getKernelParams();
96 break;
97
98 default:
99 assert(0 && "unknown architecure");
100 throw std::runtime_error("unknown architecure");
101 }
102 }
103}
104
105template <typename packingAMatrix, typename cT, typename processOutputType>
106void ExecuteKernel<
107 packingAMatrix,
108 PackBMatrix<int8_t, typename packingAMatrix::accType>,
109 cT,
110 processOutputType>::execute(int kBlock) {
111 // packedA_.printPackedMatrix("packedA from kernel");
112 // packedB_.printPackedMatrix("packedB from kernel");
113
114 int32_t bColBlocks = packedB_.blockCols();
115
116 int8_t* bBuf;
117 int8_t* bBuf_pf;
118
119 uint8_t* aBuf = packedA_.getBuf(0);
120
121 int32_t packed_rows_A = packedA_.numPackedRows();
122 int32_t row_start_A = packedA_.packedRowStart();
123
124 int group = kBlock / packedB_.blockRows();
125 int NDim = packedB_.numCols();
126 bool lastKBlock = packedB_.isThisLastKBlock(kBlock % packedB_.blockRows());
127 bool accum = (kBlock % packedB_.blockRows()) > 0;
128
129 int64_t jb_begin, jb_end;
130 fbgemmPartition1D(
131 th_info_.n_thread_id,
132 th_info_.n_num_threads,
133 bColBlocks,
134 jb_begin,
135 jb_end);
136 if (jb_end == jb_begin) {
137 return;
138 }
139
140 typename BaseType::jit_micro_kernel_fp fn;
141
142 const inst_set_t isa = fbgemmInstructionSet();
143 switch (isa) {
144 case inst_set_t::avx512_vnni:
145 if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
146 // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
147 CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
148 fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
149 accum,
150 packed_rows_A,
151 packedB_.blockColSize(),
152 packedA_.numPackedCols());
153 } else {
154 fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
155 accum,
156 packed_rows_A,
157 packedB_.blockColSize(),
158 packedA_.numPackedCols());
159 }
160 break;
161
162 case inst_set_t::avx512_vnni_ymm:
163 if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
164 // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
165 CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
166 fn = codeObj.getOrCreate<inst_set_t::avx512_vnni_ymm>(
167 accum,
168 packed_rows_A,
169 packedB_.blockColSize(),
170 packedA_.numPackedCols());
171 } else {
172 fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni_ymm>(
173 accum,
174 packed_rows_A,
175 packedB_.blockColSize(),
176 packedA_.numPackedCols());
177 }
178 break;
179
180 case inst_set_t::avx512:
181 fn = BaseType::template getOrCreate<inst_set_t::avx512>(
182 accum,
183 packed_rows_A,
184 packedB_.blockColSize(),
185 packedA_.numPackedCols());
186 break;
187
188 case inst_set_t::avx512_ymm:
189 fn = BaseType::template getOrCreate<inst_set_t::avx512_ymm>(
190 accum,
191 packed_rows_A,
192 packedB_.blockColSize(),
193 packedA_.numPackedCols());
194 break;
195
196 case inst_set_t::avx2:
197 fn = BaseType::template getOrCreate<inst_set_t::avx2>(
198 accum,
199 packed_rows_A,
200 packedB_.blockColSize(),
201 packedA_.numPackedCols());
202 break;
203
204 default:
205 // TODO: Have default slower path
206 assert(0 && "unsupported architecture");
207 throw std::runtime_error("unsupported architecure");
208 }
209
210#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
211 std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end;
212 double dt;
213 t_start = std::chrono::high_resolution_clock::now();
214#endif
215
216 for (int jb = jb_begin; jb < jb_end; ++jb) {
217 if (jb == bColBlocks - 1) {
218 int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
219 if (nc != nbSize_) {
220 switch (isa) {
221 case inst_set_t::avx512_vnni:
222 if (std::is_same<typename packingAMatrix::accType, std::int16_t>::
223 value) {
224 // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
225 CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
226 fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
227 accum, packed_rows_A, nc, packedA_.numPackedCols());
228 } else {
229 fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
230 accum, packed_rows_A, nc, packedA_.numPackedCols());
231 }
232 break;
233
234 case inst_set_t::avx512_vnni_ymm:
235 if (std::is_same<typename packingAMatrix::accType, std::int16_t>::
236 value) {
237 // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
238 CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
239 fn = codeObj.getOrCreate<inst_set_t::avx512_vnni_ymm>(
240 accum, packed_rows_A, nc, packedA_.numPackedCols());
241 } else {
242 fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni_ymm>(
243 accum, packed_rows_A, nc, packedA_.numPackedCols());
244 }
245 break;
246
247 case inst_set_t::avx512:
248 fn = BaseType::template getOrCreate<inst_set_t::avx512>(
249 accum, packed_rows_A, nc, packedA_.numPackedCols());
250 break;
251
252 case inst_set_t::avx512_ymm:
253 fn = BaseType::template getOrCreate<inst_set_t::avx512_ymm>(
254 accum, packed_rows_A, nc, packedA_.numPackedCols());
255 break;
256
257 case inst_set_t::avx2:
258 fn = BaseType::template getOrCreate<inst_set_t::avx2>(
259 accum, packed_rows_A, nc, packedA_.numPackedCols());
260 break;
261
262 default:
263 // TODO: Have default slower path
264 assert(0 && "unsupported architecture");
265 throw std::runtime_error("unsupported architecure");
266 }
267 }
268 }
269
270 bBuf = packedB_.getBuf(jb, kBlock);
271 // prefetch addr of the next packed block of B matrix
272 bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock);
273
274 // If the accumulation buffer C_buffer_ is the same as matC_ (inplace output
275 // processing), then each thread use the different parts of output buffer
276 // matC_;
277 // Otherwise, each thread uses different portions of the accumulation
278 // buffer C_buffer_. If m is large enough (m >= m_nthreads * MC), then we
279 // only need to use (m_nthreads * MC) x n portion of C_buffer_, each thread
280 // access the C_buffer_row_start as tid * MC * ldc_; else when m is very
281 // small, we juse use the whole m x n C_buffer_: each thread use the
282 // different portion.
283 int32_t* C_buffer_row_start = C_buffer_ +
284 ((C_buffer_ == reinterpret_cast<int32_t*>(matC_) ||
285 th_info_.m_num_threads * mbSize_ > packedA_.numRows())
286 ? row_start_A * ldc_ + NDim * group
287 : th_info_.m_thread_id * mbSize_ * ldc_ + NDim * group);
288
289 int32_t* C_buffer_start = C_buffer_row_start + jb * nbSize_;
290 int32_t leadingDim = ldc_;
291 static thread_local std::vector<int32_t> C_tile_;
292 if (packedB_.isThereColRemainder() && (jb == bColBlocks - 1)) {
293 // In case we will access memory past C_buffer_, we use C_tile_ scratchpad
294 // instead.
295 C_tile_.resize(mbSize_ * nbSize_);
296 C_buffer_start = C_tile_.data();
297 leadingDim = nbSize_;
298 }
299
300 fn(aBuf,
301 bBuf,
302 bBuf_pf,
303 C_buffer_start,
304 packedA_.numPackedCols(),
305 leadingDim);
306
307#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
308 t_end = std::chrono::high_resolution_clock::now();
309 dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
310 .count();
311 kernel_time += (dt);
312 t_start = std::chrono::high_resolution_clock::now();
313#endif
314
315 // Output processing is done only once per rowblock to amortize overhead
316 // and for better spatial locality.
317 if (lastKBlock && jb == jb_end - 1) {
318 // When C_tile_ is used for the last column block, we need a separate
319 // handling for the last column block.
320 int32_t nSize =
321 (C_buffer_start == C_tile_.data() ? (jb - jb_begin) * nbSize_
322 : (jb_end - jb_begin) * nbSize_);
323 if (nSize) {
324 if (fbgemmHasAvx2Support()) {
325 // TODO: avx512 path
326 // Currently use avx2 code
327 outputProcess_.template f<inst_set_t::avx2>(
328 matC_,
329 C_buffer_row_start + jb_begin * nbSize_,
330 {row_start_A,
331 packed_rows_A,
332 static_cast<int>(NDim * group + jb_begin * nbSize_),
333 nSize},
334 ldc_,
335 ldc_);
336 } else {
337 // TODO: Have default slower path
338 assert(0 && "unsupported architecure");
339 throw std::runtime_error("unsupported architecure");
340 }
341 }
342
343 if (C_buffer_start == C_tile_.data()) {
344 // When C_tile_ scratchpad was used to avoid accessing memory past
345 // C_buffer_ .
346 if (fbgemmHasAvx2Support()) {
347 // TODO: avx512 path
348 // Currently use avx2 code
349 outputProcess_.template f<inst_set_t::avx2>(
350 matC_,
351 C_tile_.data(),
352 {row_start_A,
353 packed_rows_A,
354 NDim * group + jb * nbSize_,
355 packedB_.lastBcol()},
356 ldc_,
357 leadingDim);
358 } else {
359 // TODO: Have default slower path
360 assert(0 && "unsupported architecure");
361 throw std::runtime_error("unsupported architecure");
362 }
363 }
364 } // output processing
365
366#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
367 t_end = std::chrono::high_resolution_clock::now();
368 dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
369 .count();
370 postprocessing_time += (dt);
371 t_start = std::chrono::high_resolution_clock::now();
372#endif
373
374 } // for each j block
375}
376
377////////////////////////////////////////////////////////////////////////////////
378// ReQuantizeOutput
379#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \
380 template class ExecuteKernel< \
381 PACK_A<uint8_t, ACC_T>, \
382 PackBMatrix<int8_t, ACC_T>, \
383 uint8_t, \
384 ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;
385
386#define INSTANTIATE_REQUANT_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \
387 INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \
388 INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t);
389
390#define INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, RELU) \
391 INSTANTIATE_REQUANT_BIAS_T( \
392 PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
393 INSTANTIATE_REQUANT_BIAS_T( \
394 PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
395 INSTANTIATE_REQUANT_BIAS_T( \
396 PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);
397
398#define INSTANTIATE_REQUANT_RELU(PACK_A, ACC_T) \
399 INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, false); \
400 INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, true);
401
402#define INSTANTIATE_REQUANT_ACC_T(PACK_A) \
403 INSTANTIATE_REQUANT_RELU(PACK_A, int32_t); \
404 INSTANTIATE_REQUANT_RELU(PACK_A, int16_t);
405
406INSTANTIATE_REQUANT_ACC_T(PackAMatrix);
407INSTANTIATE_REQUANT_ACC_T(PackAWithRowOffset);
408
409#undef INSTANTIATE_REQUANT_ACC_T
410#undef INSTANTIATE_REQUANT_RELU
411#undef INSTANTIATE_REQUANT_Q_GRANS
412#undef INSTANTIATE_REQUANT_BIAS_T
413#undef INSTANTIATE_REQUANT_BASE
414
// Explicit instantiations for the im2col (convolution-as-GEMM) quantized
// output path: same BIAS/Q_GRAN/RELU/ACC_T cross product as above, further
// crossed with SPATIAL_DIM in {1, 2, 3}.
#define INSTANTIATE_IM2COL_REQUANT_BASE(                \
    ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE)        \
  template class ExecuteKernel<                         \
      PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>,     \
      PackBMatrix<int8_t, ACC_T>,                       \
      uint8_t,                                          \
      ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;

#define INSTANTIATE_IM2COL_REQUANT_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN)  \
  INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float);  \
  INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t);

#define INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, SPATIAL_DIM)  \
  INSTANTIATE_IM2COL_REQUANT_BIAS_T(                                  \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR);     \
  INSTANTIATE_IM2COL_REQUANT_BIAS_T(                                  \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP);      \
  INSTANTIATE_IM2COL_REQUANT_BIAS_T(                                  \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);

#define INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, RELU) \
  INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, 1);       \
  INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, 2);       \
  INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, 3);

#define INSTANTIATE_IM2COL_REQUANT_RELU(ACC_T)          \
  INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, false); \
  INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, true);

INSTANTIATE_IM2COL_REQUANT_RELU(int32_t);
INSTANTIATE_IM2COL_REQUANT_RELU(int16_t);

// Helper macros are local to this file; undefine to avoid leakage.
#undef INSTANTIATE_IM2COL_REQUANT_RELU
#undef INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM
#undef INSTANTIATE_IM2COL_REQUANT_Q_GRANS
#undef INSTANTIATE_IM2COL_REQUANT_BIAS_T
#undef INSTANTIATE_IM2COL_REQUANT_BASE
452
453////////////////////////////////////////////////////////////////////////////////
454// ReQuantizeForFloat
455#define INSTANTIATE_REQUANT_FLOAT_BASE(PACK_A, RELU, Q_GRAN) \
456 template class ExecuteKernel< \
457 PACK_A<uint8_t, int32_t>, \
458 PackBMatrix<int8_t, int32_t>, \
459 float, \
460 ReQuantizeForFloat<RELU, Q_GRAN>>;
461
462#define INSTANTIATE_REQUANT_FLOAT_Q_GRANS(PACK_A, RELU) \
463 INSTANTIATE_REQUANT_FLOAT_BASE( \
464 PACK_A, RELU, QuantizationGranularity::TENSOR); \
465 INSTANTIATE_REQUANT_FLOAT_BASE( \
466 PACK_A, RELU, QuantizationGranularity::GROUP); \
467 INSTANTIATE_REQUANT_FLOAT_BASE( \
468 PACK_A, RELU, QuantizationGranularity::OUT_CHANNEL);
469
470#define INSTANTIATE_REQUANT_FLOAT_RELU(PACK_A) \
471 INSTANTIATE_REQUANT_FLOAT_Q_GRANS(PACK_A, false); \
472 INSTANTIATE_REQUANT_FLOAT_Q_GRANS(PACK_A, true);
473
474INSTANTIATE_REQUANT_FLOAT_RELU(PackAWithRowOffset);
475INSTANTIATE_REQUANT_FLOAT_RELU(PackAWithQuantRowOffset);
476
477#undef INSTANTIATE_REQUANT_FLOAT_RELU
478#undef INSTANTIATE_REQUANT_FLOAT_Q_GRANS
479#undef INSTANTIATE_REQUANT_FLOAT_BASE
480
481#define INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE( \
482 ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
483 template class ExecuteKernel< \
484 PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
485 PackBMatrix<int8_t, ACC_T>, \
486 float, \
487 ReQuantizeForFloat<RELU, Q_GRAN>>;
488
489#define INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
490 INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE( \
491 ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
492 INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE( \
493 ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
494 INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE( \
495 ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);
496
497#define INSTANTIATE_REQUANT_FLOAT_IM2COL_SPATIAL_DIM(ACC_T, RELU) \
498 INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS(ACC_T, RELU, 1); \
499 INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS(ACC_T, RELU, 2); \
500 INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS(ACC_T, RELU, 3);
501
502#define INSTANTIATE_REQUANT_FLOAT_IM2COL_RELU(ACC_T) \
503 INSTANTIATE_REQUANT_FLOAT_IM2COL_SPATIAL_DIM(ACC_T, false); \
504 INSTANTIATE_REQUANT_FLOAT_IM2COL_SPATIAL_DIM(ACC_T, true);
505
506INSTANTIATE_REQUANT_FLOAT_IM2COL_RELU(int32_t);
507INSTANTIATE_REQUANT_FLOAT_IM2COL_RELU(int16_t);
508
509#undef INSTANTIATE_REQUANT_FLOAT_IM2COL_RELU
510#undef INSTANTIATE_REQUANT_FLOAT_IM2COL_SPATIAL_DIM
511#undef INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS
512#undef INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE
513
514template class ExecuteKernel<
515 PackAWithRowOffset<uint8_t, int16_t>,
516 PackBMatrix<int8_t, int16_t>,
517 float,
518 ReQuantizeForFloat<false /* FUSE_RELU*/>>;
519
520////////////////////////////////////////////////////////////////////////////////
521// DoSpmdmOnInpBuffer
522#define INSTANTIATE_SPMDM_BASE(PACK_A, RELU, Q_GRAN) \
523 template class ExecuteKernel< \
524 PACK_A<uint8_t, int16_t>, \
525 PackBMatrix<int8_t, int16_t>, \
526 uint8_t, \
527 DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<RELU, Q_GRAN>>>;
528
529#define INSTANTIATE_SPMDM_Q_GRANS(PACK_A, RELU) \
530 INSTANTIATE_SPMDM_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR); \
531 INSTANTIATE_SPMDM_BASE(PACK_A, RELU, QuantizationGranularity::GROUP); \
532 INSTANTIATE_SPMDM_BASE(PACK_A, RELU, QuantizationGranularity::OUT_CHANNEL);
533
534#define INSTANTIATE_SPMDM_RELU(PACK_A) \
535 INSTANTIATE_SPMDM_Q_GRANS(PACK_A, false); \
536 INSTANTIATE_SPMDM_Q_GRANS(PACK_A, true);
537
538INSTANTIATE_SPMDM_RELU(PackAMatrix);
539INSTANTIATE_SPMDM_RELU(PackAWithRowOffset);
540
541#undef INSTANTIATE_SPMDM_RELU
542#undef INSTANTIATE_SPMDM_Q_GRANS
543#undef INSTANTIATE_SPMDM_BASE
544
545#define INSTANTIATE_SCONV_BASE(RELU, Q_GRAN) \
546 template class ExecuteKernel< \
547 PackAWithIm2Col<uint8_t, int16_t>, \
548 PackBMatrix<int8_t, int16_t>, \
549 uint8_t, \
550 DoSConvOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<RELU, Q_GRAN>>>;
551
552#define INSTANTIATE_SCONV_Q_GRANS(RELU) \
553 INSTANTIATE_SCONV_BASE(RELU, QuantizationGranularity::TENSOR); \
554 INSTANTIATE_SCONV_BASE(RELU, QuantizationGranularity::GROUP); \
555 INSTANTIATE_SCONV_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);
556
557INSTANTIATE_SCONV_Q_GRANS(false);
558INSTANTIATE_SCONV_Q_GRANS(true);
559
560#undef INSTANTIATE_SCONV_Q_GRANS
561#undef INSTANTIATE_SCONV_BASE
562
563template class ExecuteKernel<
564 PackAWithRowOffset<uint8_t, int16_t>,
565 PackBMatrix<int8_t, int16_t>,
566 float,
567 DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>>;
568
569////////////////////////////////////////////////////////////////////////////////
570// memCopy
571#define INSTANTIATE_MEMCPY_BASE(PACK_A, ACC_T) \
572 template class ExecuteKernel< \
573 PACK_A<uint8_t, ACC_T>, \
574 PackBMatrix<int8_t, ACC_T>, \
575 int32_t, \
576 memCopy<>>;
577
578#define INSTANTIATE_MEMCPY_ACC_T(PACK_A) \
579 INSTANTIATE_MEMCPY_BASE(PACK_A, int32_t) \
580 INSTANTIATE_MEMCPY_BASE(PACK_A, int16_t)
581
582INSTANTIATE_MEMCPY_ACC_T(PackAMatrix);
583INSTANTIATE_MEMCPY_ACC_T(PackAWithRowOffset);
584
585#undef INSTANTIATE_MEMCPY_ACC_T
586#undef INSTANTIATE_MEMCPY_BASE
587
588#define INSTANTIATE_MEMCPY_IM2COL_BASE(ACC_T, SPATIAL_DIM) \
589 template class ExecuteKernel< \
590 PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
591 PackBMatrix<int8_t, ACC_T>, \
592 int32_t, \
593 memCopy<>>;
594
595#define INSTANTIATE_MEMCPY_IM2COL_SPATIAL_DIM(ACC_T) \
596 INSTANTIATE_MEMCPY_IM2COL_BASE(ACC_T, 1); \
597 INSTANTIATE_MEMCPY_IM2COL_BASE(ACC_T, 2); \
598 INSTANTIATE_MEMCPY_IM2COL_BASE(ACC_T, 3);
599
600INSTANTIATE_MEMCPY_IM2COL_SPATIAL_DIM(int32_t);
601INSTANTIATE_MEMCPY_IM2COL_SPATIAL_DIM(int16_t);
602
603#undef INSTANTIATE_MEMCPY_IM2COL_SPATIAL_DIM
604#undef INSTANTIATE_MEMCPY_IM2COL_BASE
605
606template class ExecuteKernel<
607 PackAWithQuantRowOffset<uint8_t, int32_t>,
608 PackBMatrix<int8_t, int32_t>,
609 int32_t,
610 memCopy<>>;
611
612template class ExecuteKernel<
613 PackAMatrix<uint8_t, int16_t>,
614 PackBMatrix<int8_t, int16_t>,
615 int32_t,
616 DoNothing<int32_t, int32_t>>;
617
618} // namespace fbgemm
619