1 | /* |
2 | * Copyright (c) Facebook, Inc. and its affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmI8DirectconvAvx2.h" |
9 | |
10 | #if defined(__x86_64__) || defined(__i386__) || \ |
11 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
12 | #include <immintrin.h> |
13 | #endif |
#include <cassert>
#include <cstring>
15 | |
16 | #include "./DirectConv.h" |
17 | #include "./ExecuteKernel.h" |
18 | #include "./MaskAvx2.h" |
19 | #include "fbgemm/ConvUtils.h" |
20 | #include "fbgemm/Fbgemm.h" |
21 | #include "fbgemm/FbgemmBuild.h" |
22 | #include "fbgemm/UtilsAvx2.h" |
23 | |
24 | #include "./CodeGenHelpers.h" |
25 | #include "./OptimizedKernelsAvx2.h" |
26 | #include "./RefImplementations.h" |
27 | #include "./TransposeUtils.h" |
28 | #include "fbgemm/QuantUtilsAvx512.h" |
29 | namespace fbgemm { |
30 | |
31 | PackedDirectConvMatrix::PackedDirectConvMatrix( |
32 | int IC_per_G, |
33 | int OC_per_G, |
34 | int filter_prod, |
35 | const int8_t* smat) { |
36 | // Allocate packed arrays |
37 | int kernel_prod_aligned = (filter_prod + 1) / 2 * 2; |
38 | pmat_ = static_cast<int8_t*>(fbgemmAlignedAlloc( |
39 | 64, |
40 | ((OC_per_G + 31) / 32 * 32) * kernel_prod_aligned * IC_per_G * |
41 | sizeof(int8_t))); |
42 | |
43 | // the transposed weight layout: W[oc/8][r][s][ic/4][8][4] |
44 | for (int g = 0; g < /* G */ 1; ++g) { |
45 | for (int k = 0; k < OC_per_G; ++k) { |
46 | for (int f = 0; f < filter_prod; ++f) { |
47 | for (int c = 0; c < IC_per_G; ++c) { |
48 | int ocB = k / 8; |
49 | int ocb = k % 8; |
50 | int icB = c / 4; |
51 | int icb = c % 4; |
52 | pmat_ |
53 | [((((g * (OC_per_G / 8) + ocB) * filter_prod + f) * |
54 | (IC_per_G / 4) + |
55 | icB) * |
56 | 8 + |
57 | ocb) * |
58 | 4 + |
59 | icb] = |
60 | smat[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c]; |
61 | } |
62 | } |
63 | } |
64 | } |
65 | } |
66 | |
PackedDirectConvMatrix::~PackedDirectConvMatrix() {
  // pmat_ was obtained from fbgemmAlignedAlloc in the constructor, so it
  // must be released with the matching aligned free.
  fbgemmAlignedFree(pmat_);
}
70 | |
template <int kSpatialDim>
void PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT(
    const fbgemm::conv_param_t<kSpatialDim>& conv_p,
    std::int32_t* B_zero_point,
    std::vector<int32_t>& col_offsets,
    int ncols_per_quant_group) {
  // if use direct convolution implementation, compute the col_offsets
  // of the weight matrix at the first time of inference.
  // We need to know the shape of output matrix
  // to compute col_offsets for direct convolution.
  // Hence it cannot be called from inside weight packing function
  // at initialization stage like other quantized conv implementation.
  // Thus the col_offsets computation will be invoked at forward pass,
  // and only the first pass will prepare the col_offsets.
  if (first_call == false) {
    return;
  }
  int IC = conv_p.IC;
  int OC = conv_p.OC;

  int IN_DIM0 = conv_p.IN_DIM[0];
  int IN_DIM1 = conv_p.IN_DIM[1];
  int OUT_DIM0 = conv_p.OUT_DIM[0];
  int OUT_DIM1 = conv_p.OUT_DIM[1];
  int K0 = conv_p.K[0];
  int K1 = conv_p.K[1];
  int stride0 = conv_p.stride[0];
  int stride1 = conv_p.stride[1];

  // Output matrix is MDim x NDim (per group); col_offsets is stored per
  // output spatial position and output channel.
  int MDim = conv_p.MB * OUT_DIM0 * OUT_DIM1;
  int NDim = conv_p.OC / conv_p.G;
  // int KDim = K[0] * K[1] * conv_p.IC;

  col_offsets.resize(MDim * NDim, 0);
  std::fill(col_offsets.begin(), col_offsets.end(), 0);
  // count[m * OC + oc] tracks how many weight elements contributed to each
  // output position, needed because (transposed-conv) border positions are
  // covered by fewer kernel taps than interior positions.
  std::vector<int> count(MDim * NDim, 0);

  // Accumulate, for every output position, the sum of the packed weights
  // that touch it. Iteration is over input positions and kernel taps with
  // transposed-convolution indexing: oh = ih * stride0 + kh.
  for (int oc = 0; oc < OC; oc++) {
    for (int ih = 0; ih < IN_DIM0; ih++) {
      for (int iw = 0; iw < IN_DIM1; iw++) {
        for (int kh = 0; kh < K0; kh++) {
          for (int kw = 0; kw < K1; kw++) {
            for (int ic = 0; ic < IC; ic++) {
              int oh = ih * stride0 + kh;
              int ow = iw * stride1 + kw;
              // Index into the packed W[oc/8][r][s][ic/4][8][4] layout
              // produced by the constructor.
              col_offsets[(oh * OUT_DIM1 + ow) * OC + oc] += pmat_
                  [(((((oc / 8) * K0 + kh) * K1 + kw) * (IC / 4) + ic / 4) * 8 +
                    (oc % 8)) *
                       4 +
                   (ic % 4)];
              count[(oh * OUT_DIM1 + ow) * OC + oc]++;
            }
          }
        }
      }
    }
  }

  // Subtract B_zero_point times the number of contributing elements, per
  // quantization group, to finish the column-offset computation.
  for (int oc = 0; oc < OC; oc++) {
    for (int oh = 0; oh < OUT_DIM0; oh++) {
      for (int ow = 0; ow < OUT_DIM1; ow++) {
        col_offsets[(oh * OUT_DIM1 + ow) * OC + oc] -=
            B_zero_point[oc / ncols_per_quant_group] *
            count[(oh * OUT_DIM1 + ow) * OC + oc];
      }
    }
  }

  // Subsequent forward passes reuse the computed offsets.
  first_call = false;
}
141 | |
// Explicit instantiations for the 1D/2D/3D conv parameter shapes so the
// template definition can live in this translation unit.
template FBGEMM_API void
PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<1>(
    const fbgemm::conv_param_t<1>& conv_p,
    std::int32_t* B_zero_point,
    std::vector<int32_t>& col_offsets,
    int ncols_per_quant_group);

template FBGEMM_API void
PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<2>(
    const fbgemm::conv_param_t<2>& conv_p,
    std::int32_t* B_zero_point,
    std::vector<int32_t>& col_offsets,
    int ncols_per_quant_group);

template FBGEMM_API void
PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<3>(
    const fbgemm::conv_param_t<3>& conv_p,
    std::int32_t* B_zero_point,
    std::vector<int32_t>& col_offsets,
    int ncols_per_quant_group);
162 | |
163 | template <int SPATIAL_DIM> |
164 | void directConvRowSum( |
165 | const conv_param_t<SPATIAL_DIM>& conv_p, |
166 | const uint8_t* A, |
167 | int32_t* inSum, |
168 | int32_t* rowSum) { |
169 | int IN0 = conv_p.IN_DIM[0]; |
170 | int IN1 = conv_p.IN_DIM[1]; |
171 | int IC = conv_p.IC; |
172 | int K0 = conv_p.K[0]; |
173 | int K1 = conv_p.K[1]; |
174 | int OUT0 = conv_p.OUT_DIM[0]; |
175 | int OUT1 = conv_p.OUT_DIM[1]; |
176 | int stride = conv_p.stride[1]; |
177 | |
178 | memset(rowSum, 0, sizeof(int32_t) * OUT0 * OUT1); |
179 | |
180 | for (int ih = 0; ih < IN0; ++ih) { |
181 | for (int iw = 0; iw < IN1; ++iw) { |
182 | inSum[ih * IN1 + iw] = reduceAvx2(A + ih * IN1 * IC + iw * IC, IC); |
183 | } |
184 | } |
185 | |
186 | for (int ih = 0; ih < IN0; ++ih) { |
187 | for (int iw = 0; iw < IN1; iw++) { |
188 | for (int r = 0; r < K0; ++r) { |
189 | for (int s = 0; s < K1; ++s) { |
190 | rowSum[(ih + r) * OUT1 + iw * stride + s] += inSum[ih * IN1 + iw]; |
191 | } |
192 | } |
193 | } |
194 | } |
195 | /* |
196 | compare_buffers( |
197 | rowSum, |
198 | rowoffsets, |
199 | OUT0, |
200 | OUT1, |
201 | OUT1, |
202 | 5); |
203 | */ |
204 | } |
205 | |
// Explicit instantiations of directConvRowSum for 1D/2D/3D conv parameters.
template void directConvRowSum<1>(
    const conv_param_t<1>& conv_p,
    const uint8_t* A,
    int32_t* inSum,
    int32_t* rowSum);

template void directConvRowSum<2>(
    const conv_param_t<2>& conv_p,
    const uint8_t* A,
    int32_t* inSum,
    int32_t* rowSum);

template void directConvRowSum<3>(
    const conv_param_t<3>& conv_p,
    const uint8_t* A,
    int32_t* inSum,
    int32_t* rowSum);
223 | |
// Top-level entry for the direct (im2col-free) quantized convolution path.
// Runs the JIT-generated transposed-conv micro-kernel to produce int32
// accumulations in C_buffer, then requantizes into the uint8 output C.
// Currently single-threaded and 2D-transposed-conv only.
template <
    int SPATIAL_DIM,
    QuantizationGranularity Q_GRAN,
    bool FUSE_RELU,
    typename BIAS_TYPE>
void fbgemmDirectConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const uint8_t* Aint8,
    PackedDirectConvMatrix& Bint8_tr,
    uint8_t* C,
    int32_t* C_buffer,
    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
    const BIAS_TYPE* bias,
    // const int32_t* bias,
    int thread_id,
    int num_threads) {
  // support for single thread now,
  // will enable multithread later
  if (thread_id > 0 || thread_id >= num_threads) {
    return;
  }

  if (SPATIAL_DIM != 2) {
    assert(false && "1d/3d direct conv not supported" );
  } else {
    if (conv_p.transposed) {
      // Obtain (and JIT-compile on first use) the transposed direct-conv
      // micro-kernel specialized for this stride and kernel width.
      DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
          jit_micro_kernel_fp_convT fn;
      DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
      /*
      fn = codeObj.getOrCreateDirectConvTrans<inst_set_t::avx2>(
          true, conv_p.stride[1]);
      */
      fn = codeObj.getOrCreateDirectConvTrans<inst_set_t::avx2>(
          true, conv_p.stride[1], conv_p.K[1]);

      // Scratch buffers for the activation row-offset computation used by
      // requantization: per-input-pixel channel sums and per-output-position
      // row sums.
      int32_t* inSum = static_cast<int32_t*>(fbgemmAlignedAlloc(
          64, conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * sizeof(int32_t)));
      int32_t* rowSum = static_cast<int32_t*>(fbgemmAlignedAlloc(
          64, conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * sizeof(int32_t)));

      directConvRowSum(conv_p, Aint8, inSum, rowSum);
      int kernel_dim = conv_p.K[0] * conv_p.K[1];

      // Transposed conv scatters into overlapping output windows, so both
      // the int32 accumulator and the uint8 output must start at zero.
      std::memset(
          C_buffer,
          0,
          sizeof(int32_t) * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
      std::memset(
          C,
          0,
          sizeof(int8_t) * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
      // no-op output process objects
      // Invoke the micro-kernel once per 8-wide output-channel block and
      // per input row; the kernel internally walks the input row and the
      // kernel footprint. The 5th argument is the byte stride from the end
      // of one output row's kernel window to the next (x4 for int32).
      for (int i = 0; i < conv_p.OC; i += 8) {
        for (int j = 0; j < conv_p.IN_DIM[0]; j++) {
          fn(Aint8 + j * conv_p.IC * conv_p.IN_DIM[1],
             Bint8_tr.PackedMat() + i * kernel_dim * conv_p.IC,
             C_buffer + j * conv_p.OUT_DIM[1] * conv_p.OC + i,
             conv_p.IC,
             conv_p.OC,
             (conv_p.OC * conv_p.OUT_DIM[1] - conv_p.OC * conv_p.K[1]) * 4,
             conv_p.IN_DIM[1]);
        }
      }

      int32_t A_zero_point = outProcess.getAZeroPoint();
      const int32_t* B_zero_point = outProcess.getBZeroPoint();
      // const float* C_multiplier = outProcess.getCMultiplier();
      const int32_t* col_offsets = outProcess.getColOffsets();

      /*
      int groups = 1;
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        groups = conv_p.OC;
      }
      */
      // rowSum stands in for the usual row-offset buffer; groups is fixed
      // to 1 for the direct-conv path.
      requantizationParams_t<BIAS_TYPE> reqObj = {
          outProcess.getAZeroPoint(),
          outProcess.getBZeroPoint(),
          outProcess.getCZeroPoint(),
          outProcess.getCMultiplier(),
          rowSum, // rowOffsetBuf,
          outProcess.getColOffsets(),
          (outProcess.getBias()),
          static_cast<std::uint32_t>(conv_p.OC), // outProcess.getNCols(),
          1, // groups
          outProcess.getActWScale()};

      // Dispatch HAS_BIAS, then A_SYMMETRIC (A_zero_point == 0 or no column
      // offsets) and B_SYMMETRIC (tensor-granularity with zero B zero-point)
      // to the matching requantizeOutputProcessingAvx2 specialization. The
      // block descriptor covers the whole output: all OUT_DIM[0]*OUT_DIM[1]
      // rows and all OC columns, with ld == OC for both buffers.
      if (bias == nullptr) {
        // Dispatch A_SYMMETRIC and B_SYMMETRIC
        if (A_zero_point == 0 || col_offsets == nullptr) {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                true,
                true,
                QuantizationGranularity::TENSOR,
                false, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                true,
                false,
                Q_GRAN,
                false, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        } else {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                false,
                true,
                QuantizationGranularity::TENSOR,
                false, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                false,
                false,
                Q_GRAN,
                false, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        }
      } else { // has_bias == true

        // dispatch A_SYMMETRIC and B_SYMMETRIC
        if (A_zero_point == 0 || col_offsets == nullptr) {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                true,
                true,
                QuantizationGranularity::TENSOR,
                true, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                true,
                false,
                Q_GRAN,
                true, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        } else {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                false,
                true,
                QuantizationGranularity::TENSOR,
                true, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                false,
                false,
                Q_GRAN,
                true, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        }
      }
      fbgemmAlignedFree(inSum);
      fbgemmAlignedFree(rowSum);
    } // transposed conv
    else { // non-transposed conv
      assert(false && "non-transposed direct conv not integrated yet." );
    }
  } // else SPATIAL_DIM
}
461 | |
// Explicit-instantiation macro ladder: expands fbgemmDirectConv for every
// combination of spatial dim {1,2,3} x bias type {float,int32_t} x
// quantization granularity {TENSOR,GROUP,OUT_CHANNEL} x relu {true,false}.
#define INSTANTIATE_REQUANTIZE_SPATIAL_DIM( \
    SPATIAL_DIM, Q_GRAN, RELU, BIAS_TYPE) \
  template void FBGEMM_API \
  fbgemmDirectConv<SPATIAL_DIM, Q_GRAN, RELU, BIAS_TYPE>( \
      const conv_param_t<SPATIAL_DIM>& conv_p, \
      const uint8_t* Aint8, \
      PackedDirectConvMatrix& Bint8_tr, \
      uint8_t* C, \
      int32_t* C_buffer, \
      const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
      const BIAS_TYPE* bias, \
      int thread_id, \
      int num_threads);

// Expand over the three supported spatial dimensions.
#define INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, BIAS_TYPE) \
  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(1, Q_GRAN, RELU, BIAS_TYPE) \
  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(2, Q_GRAN, RELU, BIAS_TYPE) \
  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(3, Q_GRAN, RELU, BIAS_TYPE)

// Expand over the two supported bias element types.
#define INSTANTIATE_REQUANTIZE(Q_GRAN, RELU) \
  INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, float) \
  INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, int32_t)

// Expand over the three quantization granularities.
#define INSTANTIATE_Q_GRANS(RELU) \
  INSTANTIATE_REQUANTIZE(QuantizationGranularity::TENSOR, RELU) \
  INSTANTIATE_REQUANTIZE(QuantizationGranularity::GROUP, RELU) \
  INSTANTIATE_REQUANTIZE(QuantizationGranularity::OUT_CHANNEL, RELU)

INSTANTIATE_Q_GRANS(true)
INSTANTIATE_Q_GRANS(false)

#undef INSTANTIATE_REQUANTIZE_SPATIAL_DIM
#undef INSTANTIATE_REQUANTIZE_BIAS_TYPE
#undef INSTANTIATE_REQUANTIZE
#undef INSTANTIATE_Q_GRANS
497 | } // namespace fbgemm |
498 | |