1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/Fbgemm.h" |
9 | #include <cpuinfo.h> |
10 | #include <functional> |
11 | #include <stdexcept> |
12 | #include "./ExecuteKernel.h" |
13 | |
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
// Accumulators for the optional timing breakdown. fbgemmPacked() adds
// nanosecond counts to them: time spent packing A, time spent in the kernel,
// and the duration of the whole call.
double packing_time{0.0};
double computing_time{0.0};
double run_time{0.0};
#endif
19 | |
20 | namespace fbgemm { |
21 | |
22 | template < |
23 | typename packingAMatrix, |
24 | typename packingBMatrix, |
25 | typename cT, |
26 | typename processOutputType> |
27 | void fbgemmPacked( |
28 | PackMatrix< |
29 | packingAMatrix, |
30 | typename packingAMatrix::inpType, |
31 | typename packingAMatrix::accType>& packA, |
32 | PackMatrix< |
33 | packingBMatrix, |
34 | typename packingBMatrix::inpType, |
35 | typename packingBMatrix::accType>& packB, |
36 | cT* C, |
37 | int32_t* C_buffer, |
38 | uint32_t ldc, |
39 | const processOutputType& outProcess, |
40 | int thread_id, |
41 | int num_threads, |
42 | const BlockingFactors* blocking_params) { |
43 | static_assert( |
44 | std::is_same< |
45 | typename packingAMatrix::accType, |
46 | typename packingBMatrix::accType>::value, |
47 | "Accumulation type of both matrices should be the same" ); |
48 | |
49 | // Run time CPU detection |
50 | if (!cpuinfo_initialize()) { |
51 | throw std::runtime_error("Failed to initialize cpuinfo!" ); |
52 | } |
53 | if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && |
54 | !fbgemmHasAvx2Support())) { |
55 | assert(0 && "unknown architecure" ); |
56 | throw std::runtime_error("unknown architecure" ); |
57 | } |
58 | |
59 | int64_t MCB; |
60 | int KCB; |
61 | int MR; |
62 | |
63 | if (blocking_params) { |
64 | MCB = blocking_params->MCB; |
65 | KCB = blocking_params->KCB; |
66 | MR = blocking_params->MR; |
67 | } else { |
68 | const inst_set_t isa = fbgemmInstructionSet(); |
69 | switch (isa) { |
70 | case inst_set_t::avx512_vnni: |
71 | std::tie(MCB, KCB, MR) = PackingTraits< |
72 | typename packingAMatrix::inpType, |
73 | typename packingAMatrix::accType, |
74 | inst_set_t::avx512_vnni>::getCacheBlockParams(); |
75 | break; |
76 | |
77 | case inst_set_t::avx512_vnni_ymm: |
78 | std::tie(MCB, KCB, MR) = PackingTraits< |
79 | typename packingAMatrix::inpType, |
80 | typename packingAMatrix::accType, |
81 | inst_set_t::avx512_vnni_ymm>::getCacheBlockParams(); |
82 | break; |
83 | |
84 | case inst_set_t::avx512: |
85 | std::tie(MCB, KCB, MR) = PackingTraits< |
86 | typename packingAMatrix::inpType, |
87 | typename packingAMatrix::accType, |
88 | inst_set_t::avx512>::getCacheBlockParams(); |
89 | break; |
90 | |
91 | case inst_set_t::avx512_ymm: |
92 | std::tie(MCB, KCB, MR) = PackingTraits< |
93 | typename packingAMatrix::inpType, |
94 | typename packingAMatrix::accType, |
95 | inst_set_t::avx512_ymm>::getCacheBlockParams(); |
96 | break; |
97 | |
98 | case inst_set_t::avx2: |
99 | std::tie(MCB, KCB, MR) = PackingTraits< |
100 | typename packingAMatrix::inpType, |
101 | typename packingAMatrix::accType, |
102 | inst_set_t::avx2>::getCacheBlockParams(); |
103 | break; |
104 | |
105 | default: |
106 | assert(0 && "unknown architecure" ); |
107 | throw std::runtime_error("unknown architecure" ); |
108 | } |
109 | } |
110 | |
111 | if (!packB.isPrePacked()) { |
112 | throw std::runtime_error("B matrix must be prepacked" ); |
113 | } |
114 | int G = packA.numGroups(); |
115 | if (G != packB.numGroups()) { |
116 | throw std::runtime_error( |
117 | "A.groups = " + std::to_string(G) + " and B.groups = " + |
118 | std::to_string(packB.numGroups()) + " are not the same" ); |
119 | } |
120 | |
121 | int MDim = packA.numRows(); |
122 | int KDimPerGroup = packB.numRows() / G; |
123 | int NDim = packB.numCols(); |
124 | |
125 | int kBlocks = (KDimPerGroup + KCB - 1) / KCB; |
126 | |
127 | // remainders |
128 | int _kc = KDimPerGroup % KCB; |
129 | |
130 | int kc, mc; |
131 | |
132 | block_type_t blockA{0, 0, 0, 0}; |
133 | |
134 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
135 | std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start, |
136 | t_start, t_end; |
137 | double dt; |
138 | t_start = std::chrono::high_resolution_clock::now(); |
139 | t_very_start = std::chrono::high_resolution_clock::now(); |
140 | #endif |
141 | |
142 | thread_type_t th_info = |
143 | fbgemmGetThreadPartition(G, MDim, NDim, thread_id, num_threads); |
144 | // if (thread_id == 0) |
145 | // std::cout << ", " << th_info.toString(); |
146 | |
147 | int64_t g_begin, g_end, i_begin, i_end; |
148 | |
149 | // Calculate the begin and end index along the group dimension |
150 | fbgemmPartition1D( |
151 | th_info.g_thread_id, th_info.g_num_threads, G, g_begin, g_end); |
152 | // Calculate the begin and end index along the m dimension |
153 | fbgemmPartition1DBlocked( |
154 | th_info.m_thread_id, th_info.m_num_threads, MDim, MR, i_begin, i_end); |
155 | |
156 | for (int g = g_begin; g < g_end; ++g) { |
157 | ExecuteKernel<packingAMatrix, packingBMatrix, cT, processOutputType> |
158 | exeKernelObj( |
159 | packA, |
160 | packB, |
161 | C, |
162 | C_buffer, |
163 | ldc, |
164 | outProcess, |
165 | th_info, |
166 | blocking_params); |
167 | for (int i = i_begin; i < i_end; i += MCB) { // i is the element index |
168 | mc = std::min(i_end - i, MCB); |
169 | for (int kb = 0; kb < kBlocks; ++kb) { // kb is the block index |
170 | kc = (kb != kBlocks - 1 || _kc == 0) ? KCB : _kc; |
171 | // pack A matrix |
172 | blockA = {i, mc, g * KDimPerGroup + kb * KCB, kc}; |
173 | packA.pack(blockA); |
174 | |
175 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
176 | t_end = std::chrono::high_resolution_clock::now(); |
177 | dt = std::chrono::duration_cast<std::chrono::nanoseconds>( |
178 | t_end - t_start) |
179 | .count(); |
180 | packing_time += (dt); |
181 | t_start = std::chrono::high_resolution_clock::now(); |
182 | #endif |
183 | |
184 | exeKernelObj.execute(g * kBlocks + kb); |
185 | |
186 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
187 | t_end = std::chrono::high_resolution_clock::now(); |
188 | dt = std::chrono::duration_cast<std::chrono::nanoseconds>( |
189 | t_end - t_start) |
190 | .count(); |
191 | computing_time += (dt); |
192 | t_start = std::chrono::high_resolution_clock::now(); |
193 | #endif |
194 | } |
195 | } |
196 | } // for each group |
197 | |
198 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
199 | t_end = std::chrono::high_resolution_clock::now(); |
200 | dt = |
201 | std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start) |
202 | .count(); |
203 | run_time += (dt); |
204 | t_start = std::chrono::high_resolution_clock::now(); |
205 | #endif |
206 | } |
207 | |
208 | template <int SPATIAL_DIM> |
209 | bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) { |
210 | if (SPATIAL_DIM == 1) |
211 | return false; |
212 | |
213 | int C_per_G = conv_p.IC / conv_p.G; |
214 | int K_per_G = conv_p.OC / conv_p.G; |
215 | |
216 | int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>:: |
217 | numOfGroupsTogether(conv_p); |
218 | |
219 | auto areEqual = [](int a, int b) { return a == b; }; |
220 | |
221 | return (C_per_G == K_per_G) && |
222 | (C_per_G == 2 || C_per_G == 4 || C_per_G == 8 || C_per_G == 16) && |
223 | (conv_p.G >= G_together) && |
224 | |
225 | std::all_of( |
226 | conv_p.K.begin(), |
227 | conv_p.K.end(), |
228 | std::bind(areEqual, std::placeholders::_1, 3)) && |
229 | |
230 | std::all_of( |
231 | conv_p.pad.begin(), |
232 | conv_p.pad.end(), |
233 | std::bind(areEqual, std::placeholders::_1, 1)) && |
234 | |
235 | std::all_of( |
236 | conv_p.dilation.begin(), |
237 | conv_p.dilation.end(), |
238 | std::bind(areEqual, std::placeholders::_1, 1)) && |
239 | |
240 | // Height/Width strides should be the same and |
241 | // should be either 1 or 2 |
242 | // Temporal stride can be anything. |
243 | (std::all_of( |
244 | conv_p.stride.begin() + SPATIAL_DIM - 2, |
245 | conv_p.stride.end(), |
246 | std::bind(areEqual, std::placeholders::_1, 1)) || |
247 | std::all_of( |
248 | conv_p.stride.begin() + SPATIAL_DIM - 2, |
249 | conv_p.stride.end(), |
250 | std::bind(areEqual, std::placeholders::_1, 2))) && |
251 | !conv_p.transposed; |
252 | } |
253 | |
254 | template FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<1>& conv_p); |
255 | template FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p); |
256 | template FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p); |
257 | |
258 | bool fbgemmSupportedCPU() { |
259 | #if defined(__x86_64__) || defined(__i386__) || \ |
260 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
261 | return (cpuinfo_initialize() && fbgemmHasAvx2Support()); |
262 | #else |
263 | return cpuinfo_initialize(); |
264 | #endif |
265 | } |
266 | |
267 | //////////////////////////////////////////////////////////////////////////////// |
268 | // Added for Windows DLL for implicit template parameter instantiation |
269 | template class FBGEMM_API memCopy<std::int32_t, std::int32_t>; |
270 | template class FBGEMM_API DoNothing<std::int32_t, std::int32_t>; |
271 | template class FBGEMM_API DoNothing<float, float>; |
272 | template class FBGEMM_API DoNothing<std::uint8_t, std::uint8_t>; |
273 | template class FBGEMM_API |
274 | ReQuantizeForFloat<false, QuantizationGranularity::TENSOR>; |
275 | template class FBGEMM_API |
276 | ReQuantizeForFloat<false, QuantizationGranularity::GROUP>; |
277 | template class FBGEMM_API |
278 | ReQuantizeForFloat<false, QuantizationGranularity::OUT_CHANNEL>; |
279 | template class FBGEMM_API |
280 | ReQuantizeForFloat<true, QuantizationGranularity::TENSOR>; |
281 | template class FBGEMM_API |
282 | ReQuantizeForFloat<true, QuantizationGranularity::GROUP>; |
283 | template class FBGEMM_API |
284 | ReQuantizeForFloat<true, QuantizationGranularity::OUT_CHANNEL>; |
285 | |
286 | #define INSTANTIATE_BASE(FNAME, RELU, Q_GRAN) \ |
287 | template class FBGEMM_API \ |
288 | FNAME<std::uint8_t, std::int32_t, ReQuantizeOutput<RELU, Q_GRAN>>; |
289 | |
290 | #define INSTANTIATE_Q_GRAN(FNAME, RELU) \ |
291 | INSTANTIATE_BASE(FNAME, RELU, QuantizationGranularity::TENSOR) \ |
292 | INSTANTIATE_BASE(FNAME, RELU, QuantizationGranularity::GROUP) \ |
293 | INSTANTIATE_BASE(FNAME, RELU, QuantizationGranularity::OUT_CHANNEL) |
294 | |
295 | #define INSTANTIATE_RELU(FNAME) \ |
296 | INSTANTIATE_Q_GRAN(FNAME, false) \ |
297 | INSTANTIATE_Q_GRAN(FNAME, true) |
298 | |
299 | INSTANTIATE_RELU(DoSpmdmOnInpBuffer) |
300 | INSTANTIATE_RELU(DoSConvOnInpBuffer) |
301 | |
302 | #undef INSTANTIATE_RELU |
303 | #undef INSTANTIATE_Q_GRAN |
304 | #undef INSTANTIATE_BASE |
305 | |
306 | template class FBGEMM_API DoSpmdmOnInpBuffer< |
307 | float, |
308 | std::int32_t, |
309 | ReQuantizeForFloat<false, QuantizationGranularity::TENSOR>>; |
310 | |
311 | #define INSTANTIATE_BASE(RELU, Q_GRAN, BIAS_TYPE) \ |
312 | template class FBGEMM_API ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>; |
313 | |
314 | #define INSTANTIATE_BIAS_T(RELU, Q_GRAN) \ |
315 | INSTANTIATE_BASE(RELU, Q_GRAN, std::int32_t) \ |
316 | INSTANTIATE_BASE(RELU, Q_GRAN, float) |
317 | |
318 | #define INSTANTIATE_Q_GRAN(RELU) \ |
319 | INSTANTIATE_BIAS_T(RELU, QuantizationGranularity::TENSOR) \ |
320 | INSTANTIATE_BIAS_T(RELU, QuantizationGranularity::GROUP) \ |
321 | INSTANTIATE_BIAS_T(RELU, QuantizationGranularity::OUT_CHANNEL) |
322 | |
323 | INSTANTIATE_Q_GRAN(false) |
324 | INSTANTIATE_Q_GRAN(true) |
325 | |
326 | #undef INSTANTIATE_Q_GRAN |
327 | #undef INSTANTIATE_BIAS_T |
328 | #undef INSTANTIATE_BASE |
329 | |
330 | // ReQuantizeOutput |
331 | #define INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \ |
332 | template FBGEMM_API void fbgemmPacked( \ |
333 | PackMatrix<PACK_A<uint8_t, ACC_T>, uint8_t, ACC_T>& packA, \ |
334 | PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ |
335 | uint8_t* C, \ |
336 | int32_t* C_buffer, \ |
337 | uint32_t ldc, \ |
338 | const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ |
339 | int thread_id, \ |
340 | int num_threads, \ |
341 | const BlockingFactors* blocking_params); |
342 | |
343 | #define INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \ |
344 | INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float) \ |
345 | INSTANTIATE_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t) |
346 | |
347 | #define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \ |
348 | INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR) \ |
349 | INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP) \ |
350 | INSTANTIATE_BIAS_T(PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL) |
351 | |
352 | #define INSTANTIATE_RELU(PACK_A, ACC_T) \ |
353 | INSTANTIATE_Q_GRANS(PACK_A, ACC_T, false) \ |
354 | INSTANTIATE_Q_GRANS(PACK_A, ACC_T, true) |
355 | |
356 | #define INSTANTIATE_ACC_T(PACK_A) \ |
357 | INSTANTIATE_RELU(PACK_A, int32_t) \ |
358 | INSTANTIATE_RELU(PACK_A, int16_t) |
359 | |
360 | INSTANTIATE_ACC_T(PackAMatrix) |
361 | INSTANTIATE_ACC_T(PackAWithRowOffset) |
362 | |
363 | #undef INSTANTIATE_ACC_T |
364 | #undef INSTANTIATE_RELU |
365 | #undef INSTANTIATE_Q_GRANS |
366 | #undef INSTANTIATE_BIAS_T |
367 | #undef INSTANTIATE_BASE |
368 | |
369 | #define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \ |
370 | template FBGEMM_API void fbgemmPacked( \ |
371 | PackMatrix< \ |
372 | PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ |
373 | uint8_t, \ |
374 | ACC_T>& packA, \ |
375 | PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ |
376 | uint8_t* C, \ |
377 | int32_t* C_buffer, \ |
378 | uint32_t ldc, \ |
379 | const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \ |
380 | int thread_id, \ |
381 | int num_threads, \ |
382 | const BlockingFactors* blocking_params); |
383 | |
384 | #define INSTANTIATE_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ |
385 | INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float) \ |
386 | INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t) |
387 | |
388 | #define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ |
389 | INSTANTIATE_BIAS_T( \ |
390 | ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR) \ |
391 | INSTANTIATE_BIAS_T(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP) \ |
392 | INSTANTIATE_BIAS_T( \ |
393 | ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL) |
394 | |
395 | #define INSTANTIATE_SPATIAL_DIM(ACC_T, RELU) \ |
396 | INSTANTIATE_Q_GRANS(ACC_T, RELU, 1) \ |
397 | INSTANTIATE_Q_GRANS(ACC_T, RELU, 2) \ |
398 | INSTANTIATE_Q_GRANS(ACC_T, RELU, 3) |
399 | |
400 | #define INSTANTIATE_RELU(ACC_T) \ |
401 | INSTANTIATE_SPATIAL_DIM(ACC_T, false) \ |
402 | INSTANTIATE_SPATIAL_DIM(ACC_T, true) |
403 | |
404 | INSTANTIATE_RELU(int32_t) |
405 | INSTANTIATE_RELU(int16_t) |
406 | |
407 | #undef INSTANTIATE_RELU |
408 | #undef INSTANTIATE_SPATIAL_DIM |
409 | #undef INSTANTIATE_Q_GRANS |
410 | #undef INSTANTIATE_BIAS_T |
411 | #undef INSTANTIATE_BASE |
412 | |
413 | //////////////////////////////////////////////////////////////////////////////// |
414 | // ReQuantizeForFloat |
415 | #define INSTANTIATE_BASE(PACK_A, RELU, Q_GRAN) \ |
416 | template FBGEMM_API void fbgemmPacked( \ |
417 | PackMatrix<PACK_A<uint8_t, int32_t>, uint8_t, int32_t>& packA, \ |
418 | PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, \ |
419 | float* C, \ |
420 | int32_t* C_buffer, \ |
421 | uint32_t ldc, \ |
422 | const ReQuantizeForFloat<RELU, Q_GRAN>& outProcess, \ |
423 | int thread_id, \ |
424 | int num_threads, \ |
425 | const BlockingFactors* blocking_params); |
426 | |
427 | #define INSTANTIATE_Q_GRANS(PACK_A, RELU) \ |
428 | INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR) \ |
429 | INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::GROUP) \ |
430 | INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::OUT_CHANNEL) |
431 | |
432 | #define INSTANTIATE_RELU(PACK_A) \ |
433 | INSTANTIATE_Q_GRANS(PACK_A, false) \ |
434 | INSTANTIATE_Q_GRANS(PACK_A, true) |
435 | |
436 | INSTANTIATE_RELU(PackAWithRowOffset) |
437 | INSTANTIATE_RELU(PackAWithQuantRowOffset); |
438 | |
439 | #undef INSTANTIATE_RELU |
440 | #undef INSTANTIATE_Q_GRANS |
441 | #undef INSTANTIATE_BASE |
442 | |
443 | #define INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ |
444 | template FBGEMM_API void fbgemmPacked( \ |
445 | PackMatrix< \ |
446 | PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ |
447 | uint8_t, \ |
448 | ACC_T>& packA, \ |
449 | PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ |
450 | float* C, \ |
451 | int32_t* C_buffer, \ |
452 | uint32_t ldc, \ |
453 | const ReQuantizeForFloat<RELU, Q_GRAN>& outProcess, \ |
454 | int thread_id, \ |
455 | int num_threads, \ |
456 | const BlockingFactors* blocking_params); |
457 | |
458 | #define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ |
459 | INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR) \ |
460 | INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP) \ |
461 | INSTANTIATE_BASE( \ |
462 | ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL) |
463 | |
464 | #define INSTANTIATE_SPATIAL_DIM(ACC_T, RELU) \ |
465 | INSTANTIATE_Q_GRANS(ACC_T, RELU, 1) \ |
466 | INSTANTIATE_Q_GRANS(ACC_T, RELU, 2) \ |
467 | INSTANTIATE_Q_GRANS(ACC_T, RELU, 3) |
468 | |
469 | #define INSTANTIATE_RELU(ACC_T) \ |
470 | INSTANTIATE_SPATIAL_DIM(ACC_T, false) \ |
471 | INSTANTIATE_SPATIAL_DIM(ACC_T, true) |
472 | |
473 | INSTANTIATE_RELU(int32_t) |
474 | INSTANTIATE_RELU(int16_t) |
475 | |
476 | #undef INSTANTIATE_RELU |
477 | #undef INSTANTIATE_SPATIAL_DIM |
478 | #undef INSTANTIATE_Q_GRANS |
479 | #undef INSTANTIATE_BASE |
480 | |
481 | template FBGEMM_API void fbgemmPacked( |
482 | PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, |
483 | PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, |
484 | float* C, |
485 | int32_t* C_buffer, |
486 | uint32_t ldc, |
487 | const ReQuantizeForFloat<false>& outProcess, |
488 | int thread_id, |
489 | int num_threads, |
490 | const BlockingFactors* blocking_params); |
491 | |
492 | //////////////////////////////////////////////////////////////////////////////// |
493 | // DoSpmdmOnInpBuffer |
494 | #define INSTANTIATE_BASE(PACK_A, RELU, Q_GRAN) \ |
495 | template FBGEMM_API void fbgemmPacked( \ |
496 | PackMatrix<PACK_A<uint8_t, int16_t>, uint8_t, int16_t>& packA, \ |
497 | PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, \ |
498 | uint8_t* C, \ |
499 | int32_t* C_buffer, \ |
500 | uint32_t ldc, \ |
501 | const DoSpmdmOnInpBuffer< \ |
502 | uint8_t, \ |
503 | int32_t, \ |
504 | ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \ |
505 | int thread_id, \ |
506 | int num_threads, \ |
507 | const BlockingFactors* blocking_params); |
508 | |
509 | #define INSTANTIATE_Q_GRANS(PACK_A, RELU) \ |
510 | INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR) \ |
511 | INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::GROUP) \ |
512 | INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::OUT_CHANNEL) |
513 | |
514 | #define INSTANTIATE_RELU(PACK_A) \ |
515 | INSTANTIATE_Q_GRANS(PACK_A, false) \ |
516 | INSTANTIATE_Q_GRANS(PACK_A, true) |
517 | |
518 | INSTANTIATE_RELU(PackAMatrix) |
519 | INSTANTIATE_RELU(PackAWithRowOffset) |
520 | |
521 | #undef INSTANTIATE_Q_GRANS |
522 | #undef INSTANTIATE_BASE |
523 | #undef INSTANTIATE_RELU |
524 | |
525 | #define INSTANTIATE_BASE(RELU, Q_GRAN) \ |
526 | template FBGEMM_API void fbgemmPacked( \ |
527 | PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>& packA, \ |
528 | PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, \ |
529 | uint8_t* C, \ |
530 | int32_t* C_buffer, \ |
531 | uint32_t ldc, \ |
532 | const DoSConvOnInpBuffer< \ |
533 | uint8_t, \ |
534 | int32_t, \ |
535 | ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \ |
536 | int thread_id, \ |
537 | int num_threads, \ |
538 | const BlockingFactors* blocking_params); |
539 | |
540 | #define INSTANTIATE_Q_GRANS(RELU) \ |
541 | INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR) \ |
542 | INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP) \ |
543 | INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL) |
544 | |
545 | INSTANTIATE_Q_GRANS(false) |
546 | INSTANTIATE_Q_GRANS(true) |
547 | |
548 | #undef INSTANTIATE_Q_GRANS |
549 | #undef INSTANTIATE_BASE |
550 | |
551 | template FBGEMM_API void fbgemmPacked( |
552 | PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA, |
553 | PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, |
554 | float* C, |
555 | int32_t* C_buffer, |
556 | uint32_t ldc, |
557 | const DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>& |
558 | outProcess, |
559 | int thread_id, |
560 | int num_threads, |
561 | const BlockingFactors* blocking_params); |
562 | |
563 | //////////////////////////////////////////////////////////////////////////////// |
564 | // memCopy |
565 | #define INSTANTIATE_BASE(PACK_A, ACC_T) \ |
566 | template FBGEMM_API void fbgemmPacked( \ |
567 | PackMatrix<PACK_A<uint8_t, ACC_T>, uint8_t, ACC_T>& packA, \ |
568 | PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ |
569 | int32_t* C, \ |
570 | int32_t* C_buffer, \ |
571 | uint32_t ldc, \ |
572 | const memCopy<>& outProcess, \ |
573 | int thread_id, \ |
574 | int num_threads, \ |
575 | const BlockingFactors* blocking_params); |
576 | |
577 | #define INSTANTIATE_ACC_T(PACK_A) \ |
578 | INSTANTIATE_BASE(PACK_A, int32_t) \ |
579 | INSTANTIATE_BASE(PACK_A, int16_t) |
580 | |
581 | INSTANTIATE_ACC_T(PackAMatrix) |
582 | INSTANTIATE_ACC_T(PackAWithRowOffset) |
583 | |
584 | #undef INSTANTIATE_ACC_T |
585 | #undef INSTANTIATE_BASE |
586 | |
587 | #define INSTANTIATE_BASE(ACC_T, SPATIAL_DIM) \ |
588 | template FBGEMM_API void fbgemmPacked( \ |
589 | PackMatrix< \ |
590 | PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ |
591 | uint8_t, \ |
592 | ACC_T>& packA, \ |
593 | PackMatrix<PackBMatrix<int8_t, ACC_T>, int8_t, ACC_T>& packB, \ |
594 | int32_t* C, \ |
595 | int32_t* C_buffer, \ |
596 | uint32_t ldc, \ |
597 | const memCopy<>& outProcess, \ |
598 | int thread_id, \ |
599 | int num_threads, \ |
600 | const BlockingFactors* blocking_params); |
601 | |
602 | #define INSTANTIATE_SPATIAL_DIM(ACC_T) \ |
603 | INSTANTIATE_BASE(ACC_T, 1) \ |
604 | INSTANTIATE_BASE(ACC_T, 2) \ |
605 | INSTANTIATE_BASE(ACC_T, 3) |
606 | |
607 | INSTANTIATE_SPATIAL_DIM(int32_t) |
608 | INSTANTIATE_SPATIAL_DIM(int16_t) |
609 | |
610 | #undef INSTANTIATE_SPATIAL_DIM |
611 | #undef INSTANTIATE_BASE |
612 | |
613 | template FBGEMM_API void fbgemmPacked( |
614 | PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& |
615 | packA, |
616 | PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, |
617 | int32_t* C, |
618 | int32_t* C_buffer, |
619 | uint32_t ldc, |
620 | const memCopy<>& outProcess, |
621 | int thread_id, |
622 | int num_threads, |
623 | const BlockingFactors* blocking_params); |
624 | |
625 | template FBGEMM_API void fbgemmPacked( |
626 | PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA, |
627 | PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, |
628 | int32_t* C, |
629 | int32_t* C_buffer, |
630 | uint32_t ldc, |
631 | const DoNothing<int32_t, int32_t>& outProcess, |
632 | int thread_id, |
633 | int num_threads, |
634 | const BlockingFactors* blocking_params); |
635 | |
636 | } // namespace fbgemm |
637 | |