/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include "./ExecuteKernelU8S8.h"
#include <cpuinfo.h>
#include <chrono>

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
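// Global accumulators for the time-breakdown measurements: execute() below
// adds each kernel and output-processing duration into these.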
double kernel_time = 0.0;
double postprocessing_time = 0.0;
#endif

namespace fbgemm {

template <typename packingAMatrix, typename cT, typename processOutputType>
ExecuteKernel<
    packingAMatrix,
    PackBMatrix<int8_t, typename packingAMatrix::accType>,
    cT,
    processOutputType>::
    ExecuteKernel(
        PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
            packA,
        PackMatrix<
            PackBMatrix<int8_t, typename packingAMatrix::accType>,
            int8_t,
            typename packingAMatrix::accType>& packB,
        cT* matC,
        int32_t* C_buffer,
        int32_t ldc,
        const processOutputType& outputProcess,
        thread_type_t th_info,
        const BlockingFactors* params)
    : CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>(
          params),
      packedA_(packA),
      packedB_(packB),
      matC_(matC),
      C_buffer_(C_buffer),
      ldc_(ldc),
      outputProcess_(outputProcess),
      th_info_(th_info) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
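  // Caller-supplied blocking factors take precedence over the per-ISA
  // defaults; they are only honored when the CPU has at least AVX2 support.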
  if (params) {
    if (fbgemmHasAvx2Support()) {
      mbSize_ = params->MCB;
      nbSize_ = params->NCB;
      nrMinSize_ = params->NR_MIN;
      nrSize_ = params->NR;
    } else {
      // TODO: Have default slower path
      assert(0 && "unsupported architecture");
      throw std::runtime_error("unsupported architecture");
    }
  } else {
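    // No user-supplied blocking factors: pick the defaults for the detected
    // instruction set from PackingTraits.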
    const inst_set_t isa = fbgemmInstructionSet();
    switch (isa) {
      case inst_set_t::avx512_vnni:
        std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx512_vnni>::getKernelParams();
        break;

      case inst_set_t::avx512_vnni_ymm:
        std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx512_vnni_ymm>::getKernelParams();
        break;

      case inst_set_t::avx512:
        std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx512>::getKernelParams();
        break;

      case inst_set_t::avx512_ymm:
        std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx512_ymm>::getKernelParams();
        break;

      case inst_set_t::avx2:
        std::tie(mbSize_, nbSize_, nrMinSize_, nrSize_) = PackingTraits<
            typename packingAMatrix::inpType,
            typename packingAMatrix::accType,
            inst_set_t::avx2>::getKernelParams();
        break;

      default:
        assert(0 && "unknown architecture");
        throw std::runtime_error("unknown architecture");
    }
  }
}

template <typename packingAMatrix, typename cT, typename processOutputType>
void ExecuteKernel<
    packingAMatrix,
    PackBMatrix<int8_t, typename packingAMatrix::accType>,
    cT,
    processOutputType>::execute(int kBlock) {
  // packedA_.printPackedMatrix("packedA from kernel");
  // packedB_.printPackedMatrix("packedB from kernel");

  int32_t bColBlocks = packedB_.blockCols();

  int8_t* bBuf;
  int8_t* bBuf_pf;

  uint8_t* aBuf = packedA_.getBuf(0);

  int32_t packed_rows_A = packedA_.numPackedRows();
  int32_t row_start_A = packedA_.packedRowStart();

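  // kBlock indexes k blocks across all groups. Derive the group, whether this
  // is the group's last k block (which triggers output processing below), and
  // whether the kernel should accumulate on top of earlier k blocks.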
  int group = kBlock / packedB_.blockRows();
  int NDim = packedB_.numCols();
  bool lastKBlock = packedB_.isThisLastKBlock(kBlock % packedB_.blockRows());
  bool accum = (kBlock % packedB_.blockRows()) > 0;

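  // Split the column blocks of B across threads along the n dimension; this
  // thread handles [jb_begin, jb_end).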
  int64_t jb_begin, jb_end;
  fbgemmPartition1D(
      th_info_.n_thread_id,
      th_info_.n_num_threads,
      bColBlocks,
      jb_begin,
      jb_end);
  if (jb_end == jb_begin) {
    return;
  }

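  // Fetch (or JIT-generate on first use) the micro-kernel for the detected
  // ISA at the full column-block width; the possibly narrower last column
  // block is handled inside the loop below.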
  typename BaseType::jit_micro_kernel_fp fn;

  const inst_set_t isa = fbgemmInstructionSet();
  switch (isa) {
    case inst_set_t::avx512_vnni:
      if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
        // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
        CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
        fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
            accum,
            packed_rows_A,
            packedB_.blockColSize(),
            packedA_.numPackedCols());
      } else {
        fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
            accum,
            packed_rows_A,
            packedB_.blockColSize(),
            packedA_.numPackedCols());
      }
      break;

    case inst_set_t::avx512_vnni_ymm:
      if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
        // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
        CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
        fn = codeObj.getOrCreate<inst_set_t::avx512_vnni_ymm>(
            accum,
            packed_rows_A,
            packedB_.blockColSize(),
            packedA_.numPackedCols());
      } else {
        fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni_ymm>(
            accum,
            packed_rows_A,
            packedB_.blockColSize(),
            packedA_.numPackedCols());
      }
      break;

    case inst_set_t::avx512:
      fn = BaseType::template getOrCreate<inst_set_t::avx512>(
          accum,
          packed_rows_A,
          packedB_.blockColSize(),
          packedA_.numPackedCols());
      break;

    case inst_set_t::avx512_ymm:
      fn = BaseType::template getOrCreate<inst_set_t::avx512_ymm>(
          accum,
          packed_rows_A,
          packedB_.blockColSize(),
          packedA_.numPackedCols());
      break;

    case inst_set_t::avx2:
      fn = BaseType::template getOrCreate<inst_set_t::avx2>(
          accum,
          packed_rows_A,
          packedB_.blockColSize(),
          packedA_.numPackedCols());
      break;

    default:
      // TODO: Have default slower path
      assert(0 && "unsupported architecture");
      throw std::runtime_error("unsupported architecture");
  }

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
  std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end;
  double dt;
  t_start = std::chrono::high_resolution_clock::now();
#endif

  for (int jb = jb_begin; jb < jb_end; ++jb) {
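    // The last column block of B can be narrower than nbSize_. Round its
    // width up to a multiple of nrMinSize_ and, if that differs from the full
    // block width, regenerate the micro-kernel for the smaller width.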
    if (jb == bColBlocks - 1) {
      int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
      if (nc != nbSize_) {
        switch (isa) {
          case inst_set_t::avx512_vnni:
            if (std::is_same<typename packingAMatrix::accType, std::int16_t>::
                    value) {
              // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
              CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
              fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
                  accum, packed_rows_A, nc, packedA_.numPackedCols());
            } else {
              fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
                  accum, packed_rows_A, nc, packedA_.numPackedCols());
            }
            break;

          case inst_set_t::avx512_vnni_ymm:
            if (std::is_same<typename packingAMatrix::accType, std::int16_t>::
                    value) {
              // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
              CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
              fn = codeObj.getOrCreate<inst_set_t::avx512_vnni_ymm>(
                  accum, packed_rows_A, nc, packedA_.numPackedCols());
            } else {
              fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni_ymm>(
                  accum, packed_rows_A, nc, packedA_.numPackedCols());
            }
            break;

          case inst_set_t::avx512:
            fn = BaseType::template getOrCreate<inst_set_t::avx512>(
                accum, packed_rows_A, nc, packedA_.numPackedCols());
            break;

          case inst_set_t::avx512_ymm:
            fn = BaseType::template getOrCreate<inst_set_t::avx512_ymm>(
                accum, packed_rows_A, nc, packedA_.numPackedCols());
            break;

          case inst_set_t::avx2:
            fn = BaseType::template getOrCreate<inst_set_t::avx2>(
                accum, packed_rows_A, nc, packedA_.numPackedCols());
            break;

          default:
            // TODO: Have default slower path
            assert(0 && "unsupported architecture");
            throw std::runtime_error("unsupported architecture");
        }
      }
    }

    bBuf = packedB_.getBuf(jb, kBlock);
    // prefetch the address of the next packed block of the B matrix
    bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock);

    // If the accumulation buffer C_buffer_ is the same as matC_ (in-place
    // output processing), each thread uses a different part of the output
    // buffer matC_.
    // Otherwise, each thread uses a different portion of the accumulation
    // buffer C_buffer_. If m is large enough (m >= m_nthreads * MC), we only
    // need an (m_nthreads * MC) x n portion of C_buffer_ and each thread
    // starts its rows at tid * MC * ldc_; when m is very small, we just use
    // the whole m x n C_buffer_, with each thread using a different portion.
    int32_t* C_buffer_row_start = C_buffer_ +
        ((C_buffer_ == reinterpret_cast<int32_t*>(matC_) ||
          th_info_.m_num_threads * mbSize_ > packedA_.numRows())
             ? row_start_A * ldc_ + NDim * group
             : th_info_.m_thread_id * mbSize_ * ldc_ + NDim * group);

    int32_t* C_buffer_start = C_buffer_row_start + jb * nbSize_;
    int32_t leadingDim = ldc_;
    static thread_local std::vector<int32_t> C_tile_;
    if (packedB_.isThereColRemainder() && (jb == bColBlocks - 1)) {
      // We would otherwise access memory past C_buffer_, so use the C_tile_
      // scratchpad instead.
      C_tile_.resize(mbSize_ * nbSize_);
      C_buffer_start = C_tile_.data();
      leadingDim = nbSize_;
    }

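    // Invoke the JIT-generated micro-kernel: multiply the packed A buffer by
    // this packed B column block, writing (or accumulating) int32 results at
    // C_buffer_start.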
    fn(aBuf,
       bBuf,
       bBuf_pf,
       C_buffer_start,
       packedA_.numPackedCols(),
       leadingDim);

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
    t_end = std::chrono::high_resolution_clock::now();
    dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
             .count();
    kernel_time += (dt);
    t_start = std::chrono::high_resolution_clock::now();
#endif

    // Output processing is done only once per row block to amortize overhead
    // and for better spatial locality.
    if (lastKBlock && jb == jb_end - 1) {
      // When C_tile_ was used for the last column block, that block needs
      // separate handling (see below), so exclude it from nSize here.
      int32_t nSize =
          (C_buffer_start == C_tile_.data() ? (jb - jb_begin) * nbSize_
                                            : (jb_end - jb_begin) * nbSize_);
      if (nSize) {
        if (fbgemmHasAvx2Support()) {
          // TODO: avx512 path
          // Currently uses the avx2 code
          outputProcess_.template f<inst_set_t::avx2>(
              matC_,
              C_buffer_row_start + jb_begin * nbSize_,
              {row_start_A,
               packed_rows_A,
               static_cast<int>(NDim * group + jb_begin * nbSize_),
               nSize},
              ldc_,
              ldc_);
        } else {
          // TODO: Have default slower path
          assert(0 && "unsupported architecture");
          throw std::runtime_error("unsupported architecture");
        }
      }

      if (C_buffer_start == C_tile_.data()) {
        // The C_tile_ scratchpad was used to avoid accessing memory past
        // C_buffer_; process that last column block now.
        if (fbgemmHasAvx2Support()) {
          // TODO: avx512 path
          // Currently uses the avx2 code
          outputProcess_.template f<inst_set_t::avx2>(
              matC_,
              C_tile_.data(),
              {row_start_A,
               packed_rows_A,
               NDim * group + jb * nbSize_,
               packedB_.lastBcol()},
              ldc_,
              leadingDim);
        } else {
          // TODO: Have default slower path
          assert(0 && "unsupported architecture");
          throw std::runtime_error("unsupported architecture");
        }
      }
    } // output processing

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
    t_end = std::chrono::high_resolution_clock::now();
    dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
             .count();
    postprocessing_time += (dt);
    t_start = std::chrono::high_resolution_clock::now();
#endif

  } // for each j block
}

////////////////////////////////////////////////////////////////////////////////
// ReQuantizeOutput
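// The nested INSTANTIATE_* macros below spell out explicit template
// instantiations of ExecuteKernel over the cross product of packing type,
// accumulation type, ReLU fusion, quantization granularity, and (where
// applicable) bias type, so that the specializations needed elsewhere in the
// library are compiled into this translation unit.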
#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \
  template class ExecuteKernel< \
      PACK_A<uint8_t, ACC_T>, \
      PackBMatrix<int8_t, ACC_T>, \
      uint8_t, \
      ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;

#define INSTANTIATE_REQUANT_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \
  INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \
  INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t);

#define INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, RELU) \
  INSTANTIATE_REQUANT_BIAS_T( \
      PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
  INSTANTIATE_REQUANT_BIAS_T( \
      PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
  INSTANTIATE_REQUANT_BIAS_T( \
      PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);

#define INSTANTIATE_REQUANT_RELU(PACK_A, ACC_T) \
  INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, false); \
  INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, true);

#define INSTANTIATE_REQUANT_ACC_T(PACK_A) \
  INSTANTIATE_REQUANT_RELU(PACK_A, int32_t); \
  INSTANTIATE_REQUANT_RELU(PACK_A, int16_t);

INSTANTIATE_REQUANT_ACC_T(PackAMatrix);
INSTANTIATE_REQUANT_ACC_T(PackAWithRowOffset);

#undef INSTANTIATE_REQUANT_ACC_T
#undef INSTANTIATE_REQUANT_RELU
#undef INSTANTIATE_REQUANT_Q_GRANS
#undef INSTANTIATE_REQUANT_BIAS_T
#undef INSTANTIATE_REQUANT_BASE

#define INSTANTIATE_IM2COL_REQUANT_BASE( \
    ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \
  template class ExecuteKernel< \
      PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
      PackBMatrix<int8_t, ACC_T>, \
      uint8_t, \
      ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;

#define INSTANTIATE_IM2COL_REQUANT_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
  INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \
  INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t);

#define INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
  INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
  INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
  INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);

#define INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, RELU) \
  INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, 1); \
  INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, 2); \
  INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, 3);

#define INSTANTIATE_IM2COL_REQUANT_RELU(ACC_T) \
  INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, false); \
  INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, true);

INSTANTIATE_IM2COL_REQUANT_RELU(int32_t);
INSTANTIATE_IM2COL_REQUANT_RELU(int16_t);

#undef INSTANTIATE_IM2COL_REQUANT_RELU
#undef INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM
#undef INSTANTIATE_IM2COL_REQUANT_Q_GRANS
#undef INSTANTIATE_IM2COL_REQUANT_BIAS_T
#undef INSTANTIATE_IM2COL_REQUANT_BASE

////////////////////////////////////////////////////////////////////////////////
// ReQuantizeForFloat
#define INSTANTIATE_REQUANT_FLOAT_BASE(PACK_A, RELU, Q_GRAN) \
  template class ExecuteKernel< \
      PACK_A<uint8_t, int32_t>, \
      PackBMatrix<int8_t, int32_t>, \
      float, \
      ReQuantizeForFloat<RELU, Q_GRAN>>;

#define INSTANTIATE_REQUANT_FLOAT_Q_GRANS(PACK_A, RELU) \
  INSTANTIATE_REQUANT_FLOAT_BASE( \
      PACK_A, RELU, QuantizationGranularity::TENSOR); \
  INSTANTIATE_REQUANT_FLOAT_BASE( \
      PACK_A, RELU, QuantizationGranularity::GROUP); \
  INSTANTIATE_REQUANT_FLOAT_BASE( \
      PACK_A, RELU, QuantizationGranularity::OUT_CHANNEL);

#define INSTANTIATE_REQUANT_FLOAT_RELU(PACK_A) \
  INSTANTIATE_REQUANT_FLOAT_Q_GRANS(PACK_A, false); \
  INSTANTIATE_REQUANT_FLOAT_Q_GRANS(PACK_A, true);

INSTANTIATE_REQUANT_FLOAT_RELU(PackAWithRowOffset);
INSTANTIATE_REQUANT_FLOAT_RELU(PackAWithQuantRowOffset);

#undef INSTANTIATE_REQUANT_FLOAT_RELU
#undef INSTANTIATE_REQUANT_FLOAT_Q_GRANS
#undef INSTANTIATE_REQUANT_FLOAT_BASE

#define INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE( \
    ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
  template class ExecuteKernel< \
      PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
      PackBMatrix<int8_t, ACC_T>, \
      float, \
      ReQuantizeForFloat<RELU, Q_GRAN>>;

#define INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
  INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE( \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
  INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE( \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
  INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE( \
      ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);

#define INSTANTIATE_REQUANT_FLOAT_IM2COL_SPATIAL_DIM(ACC_T, RELU) \
  INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS(ACC_T, RELU, 1); \
  INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS(ACC_T, RELU, 2); \
  INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS(ACC_T, RELU, 3);

#define INSTANTIATE_REQUANT_FLOAT_IM2COL_RELU(ACC_T) \
  INSTANTIATE_REQUANT_FLOAT_IM2COL_SPATIAL_DIM(ACC_T, false); \
  INSTANTIATE_REQUANT_FLOAT_IM2COL_SPATIAL_DIM(ACC_T, true);

INSTANTIATE_REQUANT_FLOAT_IM2COL_RELU(int32_t);
INSTANTIATE_REQUANT_FLOAT_IM2COL_RELU(int16_t);

#undef INSTANTIATE_REQUANT_FLOAT_IM2COL_RELU
#undef INSTANTIATE_REQUANT_FLOAT_IM2COL_SPATIAL_DIM
#undef INSTANTIATE_REQUANT_FLOAT_IM2COL_Q_GRANS
#undef INSTANTIATE_REQUANT_FLOAT_IM2COL_BASE

template class ExecuteKernel<
    PackAWithRowOffset<uint8_t, int16_t>,
    PackBMatrix<int8_t, int16_t>,
    float,
    ReQuantizeForFloat<false /* FUSE_RELU*/>>;

////////////////////////////////////////////////////////////////////////////////
// DoSpmdmOnInpBuffer
#define INSTANTIATE_SPMDM_BASE(PACK_A, RELU, Q_GRAN) \
  template class ExecuteKernel< \
      PACK_A<uint8_t, int16_t>, \
      PackBMatrix<int8_t, int16_t>, \
      uint8_t, \
      DoSpmdmOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<RELU, Q_GRAN>>>;

#define INSTANTIATE_SPMDM_Q_GRANS(PACK_A, RELU) \
  INSTANTIATE_SPMDM_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR); \
  INSTANTIATE_SPMDM_BASE(PACK_A, RELU, QuantizationGranularity::GROUP); \
  INSTANTIATE_SPMDM_BASE(PACK_A, RELU, QuantizationGranularity::OUT_CHANNEL);

#define INSTANTIATE_SPMDM_RELU(PACK_A) \
  INSTANTIATE_SPMDM_Q_GRANS(PACK_A, false); \
  INSTANTIATE_SPMDM_Q_GRANS(PACK_A, true);

INSTANTIATE_SPMDM_RELU(PackAMatrix);
INSTANTIATE_SPMDM_RELU(PackAWithRowOffset);

#undef INSTANTIATE_SPMDM_RELU
#undef INSTANTIATE_SPMDM_Q_GRANS
#undef INSTANTIATE_SPMDM_BASE

#define INSTANTIATE_SCONV_BASE(RELU, Q_GRAN) \
  template class ExecuteKernel< \
      PackAWithIm2Col<uint8_t, int16_t>, \
      PackBMatrix<int8_t, int16_t>, \
      uint8_t, \
      DoSConvOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<RELU, Q_GRAN>>>;

#define INSTANTIATE_SCONV_Q_GRANS(RELU) \
  INSTANTIATE_SCONV_BASE(RELU, QuantizationGranularity::TENSOR); \
  INSTANTIATE_SCONV_BASE(RELU, QuantizationGranularity::GROUP); \
  INSTANTIATE_SCONV_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);

INSTANTIATE_SCONV_Q_GRANS(false);
INSTANTIATE_SCONV_Q_GRANS(true);

#undef INSTANTIATE_SCONV_Q_GRANS
#undef INSTANTIATE_SCONV_BASE

template class ExecuteKernel<
    PackAWithRowOffset<uint8_t, int16_t>,
    PackBMatrix<int8_t, int16_t>,
    float,
    DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>>;

////////////////////////////////////////////////////////////////////////////////
// memCopy
#define INSTANTIATE_MEMCPY_BASE(PACK_A, ACC_T) \
  template class ExecuteKernel< \
      PACK_A<uint8_t, ACC_T>, \
      PackBMatrix<int8_t, ACC_T>, \
      int32_t, \
      memCopy<>>;

#define INSTANTIATE_MEMCPY_ACC_T(PACK_A) \
  INSTANTIATE_MEMCPY_BASE(PACK_A, int32_t) \
  INSTANTIATE_MEMCPY_BASE(PACK_A, int16_t)

INSTANTIATE_MEMCPY_ACC_T(PackAMatrix);
INSTANTIATE_MEMCPY_ACC_T(PackAWithRowOffset);

#undef INSTANTIATE_MEMCPY_ACC_T
#undef INSTANTIATE_MEMCPY_BASE

#define INSTANTIATE_MEMCPY_IM2COL_BASE(ACC_T, SPATIAL_DIM) \
  template class ExecuteKernel< \
      PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
      PackBMatrix<int8_t, ACC_T>, \
      int32_t, \
      memCopy<>>;

#define INSTANTIATE_MEMCPY_IM2COL_SPATIAL_DIM(ACC_T) \
  INSTANTIATE_MEMCPY_IM2COL_BASE(ACC_T, 1); \
  INSTANTIATE_MEMCPY_IM2COL_BASE(ACC_T, 2); \
  INSTANTIATE_MEMCPY_IM2COL_BASE(ACC_T, 3);

INSTANTIATE_MEMCPY_IM2COL_SPATIAL_DIM(int32_t);
INSTANTIATE_MEMCPY_IM2COL_SPATIAL_DIM(int16_t);

#undef INSTANTIATE_MEMCPY_IM2COL_SPATIAL_DIM
#undef INSTANTIATE_MEMCPY_IM2COL_BASE

template class ExecuteKernel<
    PackAWithQuantRowOffset<uint8_t, int32_t>,
    PackBMatrix<int8_t, int32_t>,
    int32_t,
    memCopy<>>;

template class ExecuteKernel<
    PackAMatrix<uint8_t, int16_t>,
    PackBMatrix<int8_t, int16_t>,
    int32_t,
    DoNothing<int32_t, int32_t>>;

} // namespace fbgemm