1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/Fbgemm.h"
9#include <cpuinfo.h>
10#include <functional>
11#include <stdexcept>
12#include "./ExecuteKernel.h"
13
// Wall-clock accumulators, compiled in only when FBGEMM_MEASURE_TIME_BREAKDOWN
// is defined. fbgemmPacked adds elapsed nanosecond counts to them:
// packing_time after each packA.pack() call, computing_time after each
// kernel execute() call, and run_time once per fbgemmPacked invocation.
// NOTE(review): these are plain doubles updated with `+=` and no
// synchronization is visible in this file — confirm they are never updated
// from several threads at once when the breakdown is enabled.
14#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
15double packing_time = 0.0;
16double computing_time = 0.0;
17double run_time = 0.0;
18#endif
19
20namespace fbgemm {
21
22template <
23 typename packingAMatrix,
24 typename packingBMatrix,
25 typename cT,
26 typename processOutputType>
27void fbgemmPacked(
28 PackMatrix<
29 packingAMatrix,
30 typename packingAMatrix::inpType,
31 typename packingAMatrix::accType>& packA,
32 PackMatrix<
33 packingBMatrix,
34 typename packingBMatrix::inpType,
35 typename packingBMatrix::accType>& packB,
36 cT* C,
37 int32_t* C_buffer,
38 uint32_t ldc,
39 const processOutputType& outProcess,
40 int thread_id,
41 int num_threads,
42 const BlockingFactors* blocking_params) {
43 static_assert(
44 std::is_same<
45 typename packingAMatrix::accType,
46 typename packingBMatrix::accType>::value,
47 "Accumulation type of both matrices should be the same");
48
49 // Run time CPU detection
50 if (!cpuinfo_initialize()) {
51 throw std::runtime_error("Failed to initialize cpuinfo!");
52 }
53 if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
54 !fbgemmHasAvx2Support())) {
55 assert(0 && "unknown architecure");
56 throw std::runtime_error("unknown architecure");
57 }
58
59 int64_t MCB;
60 int KCB;
61 int MR;
62
63 if (blocking_params) {
64 MCB = blocking_params->MCB;
65 KCB = blocking_params->KCB;
66 MR = blocking_params->MR;
67 } else {
68 const inst_set_t isa = fbgemmInstructionSet();
69 switch (isa) {
70 case inst_set_t::avx512_vnni:
71 std::tie(MCB, KCB, MR) = PackingTraits<
72 typename packingAMatrix::inpType,
73 typename packingAMatrix::accType,
74 inst_set_t::avx512_vnni>::getCacheBlockParams();
75 break;
76
77 case inst_set_t::avx512_vnni_ymm:
78 std::tie(MCB, KCB, MR) = PackingTraits<
79 typename packingAMatrix::inpType,
80 typename packingAMatrix::accType,
81 inst_set_t::avx512_vnni_ymm>::getCacheBlockParams();
82 break;
83
84 case inst_set_t::avx512:
85 std::tie(MCB, KCB, MR) = PackingTraits<
86 typename packingAMatrix::inpType,
87 typename packingAMatrix::accType,
88 inst_set_t::avx512>::getCacheBlockParams();
89 break;
90
91 case inst_set_t::avx512_ymm:
92 std::tie(MCB, KCB, MR) = PackingTraits<
93 typename packingAMatrix::inpType,
94 typename packingAMatrix::accType,
95 inst_set_t::avx512_ymm>::getCacheBlockParams();
96 break;
97
98 case inst_set_t::avx2:
99 std::tie(MCB, KCB, MR) = PackingTraits<
100 typename packingAMatrix::inpType,
101 typename packingAMatrix::accType,
102 inst_set_t::avx2>::getCacheBlockParams();
103 break;
104
105 default:
106 assert(0 && "unknown architecure");
107 throw std::runtime_error("unknown architecure");
108 }
109 }
110
111 if (!packB.isPrePacked()) {
112 throw std::runtime_error("B matrix must be prepacked");
113 }
114 int G = packA.numGroups();
115 if (G != packB.numGroups()) {
116 throw std::runtime_error(
117 "A.groups = " + std::to_string(G) + " and B.groups = " +
118 std::to_string(packB.numGroups()) + " are not the same");
119 }
120
121 int MDim = packA.numRows();
122 int KDimPerGroup = packB.numRows() / G;
123 int NDim = packB.numCols();
124
125 int kBlocks = (KDimPerGroup + KCB - 1) / KCB;
126
127 // remainders
128 int _kc = KDimPerGroup % KCB;
129
130 int kc, mc;
131
132 block_type_t blockA{0, 0, 0, 0};
133
134#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
135 std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start,
136 t_start, t_end;
137 double dt;
138 t_start = std::chrono::high_resolution_clock::now();
139 t_very_start = std::chrono::high_resolution_clock::now();
140#endif
141
142 thread_type_t th_info =
143 fbgemmGetThreadPartition(G, MDim, NDim, thread_id, num_threads);
144 // if (thread_id == 0)
145 // std::cout << ", " << th_info.toString();
146
147 int64_t g_begin, g_end, i_begin, i_end;
148
149 // Calculate the begin and end index along the group dimension
150 fbgemmPartition1D(
151 th_info.g_thread_id, th_info.g_num_threads, G, g_begin, g_end);
152 // Calculate the begin and end index along the m dimension
153 fbgemmPartition1DBlocked(
154 th_info.m_thread_id, th_info.m_num_threads, MDim, MR, i_begin, i_end);
155
156 for (int g = g_begin; g < g_end; ++g) {
157 ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType>
158 exeKernelObj(
159 packA,
160 packB,
161 C,
162 C_buffer,
163 ldc,
164 outProcess,
165 th_info,
166 blocking_params);
167 for (int i = i_begin; i < i_end; i += MCB) { // i is the element index
168 mc = std::min(i_end - i, MCB);
169 for (int kb = 0; kb < kBlocks; ++kb) { // kb is the block index
170 kc = (kb != kBlocks - 1 || _kc == 0) ? KCB : _kc;
171 // pack A matrix
172 blockA = {i, mc, g * KDimPerGroup + kb * KCB, kc};
173 packA.pack(blockA);
174
175#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
176 t_end = std::chrono::high_resolution_clock::now();
177 dt = std::chrono::duration_cast<std::chrono::nanoseconds>(
178 t_end - t_start)
179 .count();
180 packing_time += (dt);
181 t_start = std::chrono::high_resolution_clock::now();
182#endif
183
184 exeKernelObj.execute(g * kBlocks + kb);
185
186#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
187 t_end = std::chrono::high_resolution_clock::now();
188 dt = std::chrono::duration_cast<std::chrono::nanoseconds>(
189 t_end - t_start)
190 .count();
191 computing_time += (dt);
192 t_start = std::chrono::high_resolution_clock::now();
193#endif
194 }
195 }
196 } // for each group
197
198#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
199 t_end = std::chrono::high_resolution_clock::now();
200 dt =
201 std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start)
202 .count();
203 run_time += (dt);
204 t_start = std::chrono::high_resolution_clock::now();
205#endif
206}
207
208template <int SPATIAL_DIM>
209bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
210 if (SPATIAL_DIM == 1)
211 return false;
212
213 int C_per_G = conv_p.IC / conv_p.G;
214 int K_per_G = conv_p.OC / conv_p.G;
215
216 int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
217 numOfGroupsTogether(conv_p);
218
219 auto areEqual = [](int a, int b) { return a == b; };
220
221 return (C_per_G == K_per_G) &&
222 (C_per_G == 2 || C_per_G == 4 || C_per_G == 8 || C_per_G == 16) &&
223 (conv_p.G >= G_together) &&
224
225 std::all_of(
226 conv_p.K.begin(),
227 conv_p.K.end(),
228 std::bind(areEqual, std::placeholders::_1, 3)) &&
229
230 std::all_of(
231 conv_p.pad.begin(),
232 conv_p.pad.end(),
233 std::bind(areEqual, std::placeholders::_1, 1)) &&
234
235 std::all_of(
236 conv_p.dilation.begin(),
237 conv_p.dilation.end(),
238 std::bind(areEqual, std::placeholders::_1, 1)) &&
239
240 // Height/Width strides should be the same and
241 // should be either 1 or 2
242 // Temporal stride can be anything.
243 (std::all_of(
244 conv_p.stride.begin() + SPATIAL_DIM - 2,
245 conv_p.stride.end(),
246 std::bind(areEqual, std::placeholders::_1, 1)) ||
247 std::all_of(
248 conv_p.stride.begin() + SPATIAL_DIM - 2,
249 conv_p.stride.end(),
250 std::bind(areEqual, std::placeholders::_1, 2))) &&
251 !conv_p.transposed;
252}
253
254template FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<1>& conv_p);
255template FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p);
256template FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p);
257
258bool fbgemmSupportedCPU() {
259#if defined(__x86_64__) || defined(__i386__) || \
260 (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
261 return (cpuinfo_initialize() && fbgemmHasAvx2Support());
262#else
263 return cpuinfo_initialize();
264#endif
265}
266
// Explicit class-template instantiations, each marked FBGEMM_API: memCopy,
// DoNothing for three element-type pairs, and ReQuantizeForFloat for
// RELU = false/true crossed with the TENSOR, GROUP and OUT_CHANNEL
// granularities (six in total).
267////////////////////////////////////////////////////////////////////////////////
268// Added for Windows DLL for implicit template parameter instantiation
269template class FBGEMM_API memCopy<std::int32_t, std::int32_t>;
270template class FBGEMM_API DoNothing<std::int32_t, std::int32_t>;
271template class FBGEMM_API DoNothing<float, float>;
272template class FBGEMM_API DoNothing<std::uint8_t, std::uint8_t>;
273template class FBGEMM_API
274 ReQuantizeForFloat<false, QuantizationGranularity::TENSOR>;
275template class FBGEMM_API
276 ReQuantizeForFloat<false, QuantizationGranularity::GROUP>;
277template class FBGEMM_API
278 ReQuantizeForFloat<false, QuantizationGranularity::OUT_CHANNEL>;
279template class FBGEMM_API
280 ReQuantizeForFloat<true, QuantizationGranularity::TENSOR>;
281template class FBGEMM_API
282 ReQuantizeForFloat<true, QuantizationGranularity::GROUP>;
283template class FBGEMM_API
284 ReQuantizeForFloat<true, QuantizationGranularity::OUT_CHANNEL>;
285
// Explicit class instantiations of DoSpmdmOnInpBuffer and DoSConvOnInpBuffer
// with <uint8_t, int32_t, ReQuantizeOutput<RELU, Q_GRAN>>, expanded for
// RELU = false/true and the three granularities (6 per class template, 12 in
// total). The helper macros are #undef'd right after use. One additional
// DoSpmdmOnInpBuffer instantiation with float output and
// ReQuantizeForFloat<false, TENSOR> follows.
286#define INSTANTIATE_BASE(FNAME, RELU, Q_GRAN) \
287 template class FBGEMM_API \
288 FNAME<std::uint8_t, std::int32_t, ReQuantizeOutput<RELU, Q_GRAN>>;
289
290#define INSTANTIATE_Q_GRAN(FNAME, RELU) \
291 INSTANTIATE_BASE(FNAME, RELU, QuantizationGranularity::TENSOR) \
292 INSTANTIATE_BASE(FNAME, RELU, QuantizationGranularity::GROUP) \
293 INSTANTIATE_BASE(FNAME, RELU, QuantizationGranularity::OUT_CHANNEL)
294
295#define INSTANTIATE_RELU(FNAME) \
296 INSTANTIATE_Q_GRAN(FNAME, false) \
297 INSTANTIATE_Q_GRAN(FNAME, true)
298
299INSTANTIATE_RELU(DoSpmdmOnInpBuffer)
300INSTANTIATE_RELU(DoSConvOnInpBuffer)
301
302#undef INSTANTIATE_RELU
303#undef INSTANTIATE_Q_GRAN
304#undef INSTANTIATE_BASE
305
306template class FBGEMM_API DoSpmdmOnInpBuffer<
307 float,
308 std::int32_t,
309 ReQuantizeForFloat<false, QuantizationGranularity::TENSOR>>;
310
// Explicit class instantiations of ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>
// for RELU = false/true, the three granularities, and bias types
// std::int32_t and float (12 in total). The helper macros are #undef'd
// right after use.
311#define INSTANTIATE_BASE(RELU, Q_GRAN, BIAS_TYPE) \
312 template class FBGEMM_API ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>;
313
314#define INSTANTIATE_BIAS_T(RELU, Q_GRAN) \
315 INSTANTIATE_BASE(RELU, Q_GRAN, std::int32_t) \
316 INSTANTIATE_BASE(RELU, Q_GRAN, float)
317
318#define INSTANTIATE_Q_GRAN(RELU) \
319 INSTANTIATE_BIAS_T(RELU, QuantizationGranularity::TENSOR) \
320 INSTANTIATE_BIAS_T(RELU, QuantizationGranularity::GROUP) \
321 INSTANTIATE_BIAS_T(RELU, QuantizationGranularity::OUT_CHANNEL)
322
323INSTANTIATE_Q_GRAN(false)
324INSTANTIATE_Q_GRAN(true)
325
326#undef INSTANTIATE_Q_GRAN
327#undef INSTANTIATE_BIAS_T
328#undef INSTANTIATE_BASE
329
// Explicit instantiations of fbgemmPacked with uint8_t output and a
// ReQuantizeOutput post-processor. Expanded for PACK_A in {PackAMatrix,
// PackAWithRowOffset}, accumulation type in {int32_t, int16_t},
// RELU = false/true, the three granularities, and bias type in
// {float, int32_t}: 2 * 2 * 2 * 3 * 2 = 48 instantiations.
330// ReQuantizeOutput
331#define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \
332 template FBGEMM_API void fbgemmPacked( \
333 PackMatrix<PACK_A<uint8_t, ACC_T>, uint8_t, ACC_T>& packA, \
334 PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
335 uint8_t* C, \
336 int32_t* C_buffer, \
337 uint32_t ldc, \
338 const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
339 int thread_id, \
340 int num_threads, \
341 const BlockingFactors* blocking_params);
342
343#define INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \
344 INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float) \
345 INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t)
346
347#define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \
348 INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR) \
349 INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP) \
350 INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL)
351
352#define INSTANTIATE_RELU(PACK_A, ACC_T) \
353 INSTANTIATE_Q_GRANS(PACK_A, ACC_T, false) \
354 INSTANTIATE_Q_GRANS(PACK_A, ACC_T, true)
355
356#define INSTANTIATE_ACC_T(PACK_A) \
357 INSTANTIATE_RELU(PACK_A, int32_t) \
358 INSTANTIATE_RELU(PACK_A, int16_t)
359
360INSTANTIATE_ACC_T(PackAMatrix)
361INSTANTIATE_ACC_T(PackAWithRowOffset)
362
363#undef INSTANTIATE_ACC_T
364#undef INSTANTIATE_RELU
365#undef INSTANTIATE_Q_GRANS
366#undef INSTANTIATE_BIAS_T
367#undef INSTANTIATE_BASE
368
// Explicit instantiations of fbgemmPacked with PackAWithIm2Col as the A
// packer, uint8_t output and a ReQuantizeOutput post-processor. Expanded for
// accumulation type in {int32_t, int16_t}, RELU = false/true,
// SPATIAL_DIM in {1, 2, 3}, the three granularities, and bias type in
// {float, int32_t}: 2 * 2 * 3 * 3 * 2 = 72 instantiations.
369#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \
370 template FBGEMM_API void fbgemmPacked( \
371 PackMatrix< \
372 PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
373 uint8_t, \
374 ACC_T>& packA, \
375 PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
376 uint8_t* C, \
377 int32_t* C_buffer, \
378 uint32_t ldc, \
379 const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
380 int thread_id, \
381 int num_threads, \
382 const BlockingFactors* blocking_params);
383
384#define INSTANTIATE_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
385 INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float) \
386 INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t)
387
388#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
389 INSTANTIATE_BIAS_T( \
390 ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR) \
391 INSTANTIATE_BIAS_T(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP) \
392 INSTANTIATE_BIAS_T( \
393 ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL)
394
395#define INSTANTIATE_SPATIAL_DIM(ACC_T, RELU) \
396 INSTANTIATE_Q_GRANS(ACC_T, RELU, 1) \
397 INSTANTIATE_Q_GRANS(ACC_T, RELU, 2) \
398 INSTANTIATE_Q_GRANS(ACC_T, RELU, 3)
399
400#define INSTANTIATE_RELU(ACC_T) \
401 INSTANTIATE_SPATIAL_DIM(ACC_T, false) \
402 INSTANTIATE_SPATIAL_DIM(ACC_T, true)
403
404INSTANTIATE_RELU(int32_t)
405INSTANTIATE_RELU(int16_t)
406
407#undef INSTANTIATE_RELU
408#undef INSTANTIATE_SPATIAL_DIM
409#undef INSTANTIATE_Q_GRANS
410#undef INSTANTIATE_BIAS_T
411#undef INSTANTIATE_BASE
412
413////////////////////////////////////////////////////////////////////////////////
414// ReQuantizeForFloat
415#define INSTANTIATE_BASE(PACK_A, RELU, Q_GRAN) \
416 template FBGEMM_API void fbgemmPacked( \
417 PackMatrix<PACK_A<uint8_t, int32_t>, uint8_t, int32_t>& packA, \
418 PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, \
419 float* C, \
420 int32_t* C_buffer, \
421 uint32_t ldc, \
422 const ReQuantizeForFloat<RELU, Q_GRAN>& outProcess, \
423 int thread_id, \
424 int num_threads, \
425 const BlockingFactors* blocking_params);
426
427#define INSTANTIATE_Q_GRANS(PACK_A, RELU) \
428 INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR) \
429 INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::GROUP) \
430 INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::OUT_CHANNEL)
431
432#define INSTANTIATE_RELU(PACK_A) \
433 INSTANTIATE_Q_GRANS(PACK_A, false) \
434 INSTANTIATE_Q_GRANS(PACK_A, true)
435
436INSTANTIATE_RELU(PackAWithRowOffset)
437INSTANTIATE_RELU(PackAWithQuantRowOffset);
438
439#undef INSTANTIATE_RELU
440#undef INSTANTIATE_Q_GRANS
441#undef INSTANTIATE_BASE
442
// Explicit instantiations of fbgemmPacked with PackAWithIm2Col as the A
// packer, float output and a ReQuantizeForFloat post-processor. Expanded for
// accumulation type in {int32_t, int16_t}, RELU = false/true,
// SPATIAL_DIM in {1, 2, 3} and the three granularities:
// 2 * 2 * 3 * 3 = 36 instantiations.
443#define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
444 template FBGEMM_API void fbgemmPacked( \
445 PackMatrix< \
446 PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
447 uint8_t, \
448 ACC_T>& packA, \
449 PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
450 float* C, \
451 int32_t* C_buffer, \
452 uint32_t ldc, \
453 const ReQuantizeForFloat<RELU, Q_GRAN>& outProcess, \
454 int thread_id, \
455 int num_threads, \
456 const BlockingFactors* blocking_params);
457
458#define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
459 INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR) \
460 INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP) \
461 INSTANTIATE_BASE( \
462 ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL)
463
464#define INSTANTIATE_SPATIAL_DIM(ACC_T, RELU) \
465 INSTANTIATE_Q_GRANS(ACC_T, RELU, 1) \
466 INSTANTIATE_Q_GRANS(ACC_T, RELU, 2) \
467 INSTANTIATE_Q_GRANS(ACC_T, RELU, 3)
468
469#define INSTANTIATE_RELU(ACC_T) \
470 INSTANTIATE_SPATIAL_DIM(ACC_T, false) \
471 INSTANTIATE_SPATIAL_DIM(ACC_T, true)
472
473INSTANTIATE_RELU(int32_t)
474INSTANTIATE_RELU(int16_t)
475
476#undef INSTANTIATE_RELU
477#undef INSTANTIATE_SPATIAL_DIM
478#undef INSTANTIATE_Q_GRANS
479#undef INSTANTIATE_BASE
480
// Single explicit instantiation: PackAWithRowOffset with int16_t
// accumulation, float output, post-processed by ReQuantizeForFloat<false>
// (remaining template arguments take that template's defaults).
481template FBGEMM_API void fbgemmPacked(
482 PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
483 PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
484 float* C,
485 int32_t* C_buffer,
486 uint32_t ldc,
487 const ReQuantizeForFloat<false>& outProcess,
488 int thread_id,
489 int num_threads,
490 const BlockingFactors* blocking_params);
491
// Explicit instantiations of fbgemmPacked with int16_t accumulation, uint8_t
// output and DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<...>> as
// the post-processor. Expanded for PACK_A in {PackAMatrix,
// PackAWithRowOffset}, RELU = false/true and the three granularities:
// 2 * 2 * 3 = 12 instantiations.
492////////////////////////////////////////////////////////////////////////////////
493// DoSpmdmOnInpBuffer
494#define INSTANTIATE_BASE(PACK_A, RELU, Q_GRAN) \
495 template FBGEMM_API void fbgemmPacked( \
496 PackMatrix<PACK_A<uint8_t, int16_t>, uint8_t, int16_t>& packA, \
497 PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, \
498 uint8_t* C, \
499 int32_t* C_buffer, \
500 uint32_t ldc, \
501 const DoSpmdmOnInpBuffer< \
502 uint8_t, \
503 int32_t, \
504 ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \
505 int thread_id, \
506 int num_threads, \
507 const BlockingFactors* blocking_params);
508
509#define INSTANTIATE_Q_GRANS(PACK_A, RELU) \
510 INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR) \
511 INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::GROUP) \
512 INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::OUT_CHANNEL)
513
514#define INSTANTIATE_RELU(PACK_A) \
515 INSTANTIATE_Q_GRANS(PACK_A, false) \
516 INSTANTIATE_Q_GRANS(PACK_A, true)
517
518INSTANTIATE_RELU(PackAMatrix)
519INSTANTIATE_RELU(PackAWithRowOffset)
520
521#undef INSTANTIATE_Q_GRANS
522#undef INSTANTIATE_BASE
523#undef INSTANTIATE_RELU
524
// Explicit instantiations of fbgemmPacked with PackAWithIm2Col<uint8_t,
// int16_t> as the A packer (its remaining template arguments take their
// defaults), uint8_t output, and DoSConvOnInpBuffer<uint8_t, int32_t,
// ReQuantizeOutput<...>> as the post-processor. Expanded for
// RELU = false/true and the three granularities: 6 instantiations.
525#define INSTANTIATE_BASE(RELU, Q_GRAN) \
526 template FBGEMM_API void fbgemmPacked( \
527 PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>& packA, \
528 PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, \
529 uint8_t* C, \
530 int32_t* C_buffer, \
531 uint32_t ldc, \
532 const DoSConvOnInpBuffer< \
533 uint8_t, \
534 int32_t, \
535 ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \
536 int thread_id, \
537 int num_threads, \
538 const BlockingFactors* blocking_params);
539
540#define INSTANTIATE_Q_GRANS(RELU) \
541 INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR) \
542 INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP) \
543 INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL)
544
545INSTANTIATE_Q_GRANS(false)
546INSTANTIATE_Q_GRANS(true)
547
548#undef INSTANTIATE_Q_GRANS
549#undef INSTANTIATE_BASE
550
// Single explicit instantiation: PackAWithRowOffset with int16_t
// accumulation, float output, post-processed by
// DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>.
551template FBGEMM_API void fbgemmPacked(
552 PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
553 PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
554 float* C,
555 int32_t* C_buffer,
556 uint32_t ldc,
557 const DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>&
558 outProcess,
559 int thread_id,
560 int num_threads,
561 const BlockingFactors* blocking_params);
562
// Explicit instantiations of fbgemmPacked with int32_t output and memCopy<>
// as the post-processor. Expanded for PACK_A in {PackAMatrix,
// PackAWithRowOffset} and accumulation type in {int32_t, int16_t}:
// 2 * 2 = 4 instantiations.
563////////////////////////////////////////////////////////////////////////////////
564// memCopy
565#define INSTANTIATE_BASE(PACK_A, ACC_T) \
566 template FBGEMM_API void fbgemmPacked( \
567 PackMatrix<PACK_A<uint8_t, ACC_T>, uint8_t, ACC_T>& packA, \
568 PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
569 int32_t* C, \
570 int32_t* C_buffer, \
571 uint32_t ldc, \
572 const memCopy<>& outProcess, \
573 int thread_id, \
574 int num_threads, \
575 const BlockingFactors* blocking_params);
576
577#define INSTANTIATE_ACC_T(PACK_A) \
578 INSTANTIATE_BASE(PACK_A, int32_t) \
579 INSTANTIATE_BASE(PACK_A, int16_t)
580
581INSTANTIATE_ACC_T(PackAMatrix)
582INSTANTIATE_ACC_T(PackAWithRowOffset)
583
584#undef INSTANTIATE_ACC_T
585#undef INSTANTIATE_BASE
586
// Explicit instantiations of fbgemmPacked with PackAWithIm2Col as the A
// packer, int32_t output and memCopy<> as the post-processor. Expanded for
// accumulation type in {int32_t, int16_t} and SPATIAL_DIM in {1, 2, 3}:
// 2 * 3 = 6 instantiations.
587#define INSTANTIATE_BASE(ACC_T, SPATIAL_DIM) \
588 template FBGEMM_API void fbgemmPacked( \
589 PackMatrix< \
590 PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
591 uint8_t, \
592 ACC_T>& packA, \
593 PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \
594 int32_t* C, \
595 int32_t* C_buffer, \
596 uint32_t ldc, \
597 const memCopy<>& outProcess, \
598 int thread_id, \
599 int num_threads, \
600 const BlockingFactors* blocking_params);
601
602#define INSTANTIATE_SPATIAL_DIM(ACC_T) \
603 INSTANTIATE_BASE(ACC_T, 1) \
604 INSTANTIATE_BASE(ACC_T, 2) \
605 INSTANTIATE_BASE(ACC_T, 3)
606
607INSTANTIATE_SPATIAL_DIM(int32_t)
608INSTANTIATE_SPATIAL_DIM(int16_t)
609
610#undef INSTANTIATE_SPATIAL_DIM
611#undef INSTANTIATE_BASE
612
// Two stand-alone explicit instantiations of fbgemmPacked with int32_t
// output: PackAWithQuantRowOffset (int32_t accumulation) with memCopy<>, and
// PackAMatrix (int16_t accumulation) with DoNothing<int32_t, int32_t>.
613template FBGEMM_API void fbgemmPacked(
614 PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
615 packA,
616 PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
617 int32_t* C,
618 int32_t* C_buffer,
619 uint32_t ldc,
620 const memCopy<>& outProcess,
621 int thread_id,
622 int num_threads,
623 const BlockingFactors* blocking_params);
624
625template FBGEMM_API void fbgemmPacked(
626 PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
627 PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
628 int32_t* C,
629 int32_t* C_buffer,
630 uint32_t ldc,
631 const DoNothing<int32_t, int32_t>& outProcess,
632 int thread_id,
633 int num_threads,
634 const BlockingFactors* blocking_params);
635
636} // namespace fbgemm
637