/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmI8DirectconvAvx2.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <algorithm> // std::fill
#include <cassert>
#include <cstring> // memset / std::memset

#include "./DirectConv.h"
#include "./ExecuteKernel.h"
#include "./MaskAvx2.h"
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmBuild.h"
#include "fbgemm/UtilsAvx2.h"

#include "./CodeGenHelpers.h"
#include "./OptimizedKernelsAvx2.h"
#include "./RefImplementations.h"
#include "./TransposeUtils.h"
#include "fbgemm/QuantUtilsAvx512.h"

namespace fbgemm {

PackedDirectConvMatrix::PackedDirectConvMatrix(
    int IC_per_G,
    int OC_per_G,
    int filter_prod,
    const int8_t* smat) {
  // Allocate packed arrays
  int kernel_prod_aligned = (filter_prod + 1) / 2 * 2;
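  // Note: the allocation below rounds filter_prod up to an even count and
  // OC_per_G up to a multiple of 32, presumably so the 64-byte-aligned buffer
  // has slack for the vectorized kernels. The packing loop itself only writes
  // the [OC_per_G x filter_prod x IC_per_G] entries.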
  pmat_ = static_cast<int8_t*>(fbgemmAlignedAlloc(
      64,
      ((OC_per_G + 31) / 32 * 32) * kernel_prod_aligned * IC_per_G *
          sizeof(int8_t)));

  // The transposed weight layout: W[oc/8][r][s][ic/4][8][4]
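  // In other words, weight (g, oc, f, ic) of the original
  // [G][OC_per_G][filter_prod][IC_per_G] matrix is stored at
  //   pmat_[((((g * (OC_per_G / 8) + oc / 8) * filter_prod + f)
  //            * (IC_per_G / 4) + ic / 4) * 8 + oc % 8) * 4 + ic % 4],
  // i.e. output channels are blocked by 8 and input channels by 4 so that each
  // innermost [8][4] tile is contiguous in memory.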
  for (int g = 0; g < /* G */ 1; ++g) {
    for (int k = 0; k < OC_per_G; ++k) {
      for (int f = 0; f < filter_prod; ++f) {
        for (int c = 0; c < IC_per_G; ++c) {
          int ocB = k / 8;
          int ocb = k % 8;
          int icB = c / 4;
          int icb = c % 4;
          pmat_
              [((((g * (OC_per_G / 8) + ocB) * filter_prod + f) *
                     (IC_per_G / 4) +
                 icB) *
                    8 +
                ocb) *
                   4 +
               icb] =
                  smat[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c];
        }
      }
    }
  }
}

PackedDirectConvMatrix::~PackedDirectConvMatrix() {
  fbgemmAlignedFree(pmat_);
}

template <int kSpatialDim>
void PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT(
    const fbgemm::conv_param_t<kSpatialDim>& conv_p,
    std::int32_t* B_zero_point,
    std::vector<int32_t>& col_offsets,
    int ncols_per_quant_group) {
  // For the direct convolution implementation, the col_offsets of the weight
  // matrix are computed the first time inference runs. Computing col_offsets
  // for direct convolution requires the output shape, so unlike the other
  // quantized conv implementations it cannot be done inside the weight packing
  // function at initialization time. Instead, the computation is invoked
  // during the forward pass, and only the first pass fills col_offsets.
  if (first_call == false) {
    return;
  }
  int IC = conv_p.IC;
  int OC = conv_p.OC;

  int IN_DIM0 = conv_p.IN_DIM[0];
  int IN_DIM1 = conv_p.IN_DIM[1];
  int OUT_DIM0 = conv_p.OUT_DIM[0];
  int OUT_DIM1 = conv_p.OUT_DIM[1];
  int K0 = conv_p.K[0];
  int K1 = conv_p.K[1];
  int stride0 = conv_p.stride[0];
  int stride1 = conv_p.stride[1];

  int MDim = conv_p.MB * OUT_DIM0 * OUT_DIM1;
  int NDim = conv_p.OC / conv_p.G;
  // int KDim = K[0] * K[1] * conv_p.IC;

  col_offsets.resize(MDim * NDim, 0);
  std::fill(col_offsets.begin(), col_offsets.end(), 0);
  std::vector<int> count(MDim * NDim, 0);

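  // For transposed direct conv, output pixel (oh, ow) receives a contribution
  // from input pixel (ih, iw) through kernel tap (kh, kw) when
  // oh == ih * stride0 + kh and ow == iw * stride1 + kw. The loops below
  // therefore sum, for every (output pixel, output channel), all packed
  // weights that can contribute to it, and count how many taps contribute, so
  // that the second loop can subtract B_zero_point * count:
  //   col_offsets[(oh, ow), oc] = sum of contributing weights
  //                               - B_zero_point[oc / group] * count.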
  for (int oc = 0; oc < OC; oc++) {
    for (int ih = 0; ih < IN_DIM0; ih++) {
      for (int iw = 0; iw < IN_DIM1; iw++) {
        for (int kh = 0; kh < K0; kh++) {
          for (int kw = 0; kw < K1; kw++) {
            for (int ic = 0; ic < IC; ic++) {
              int oh = ih * stride0 + kh;
              int ow = iw * stride1 + kw;
              col_offsets[(oh * OUT_DIM1 + ow) * OC + oc] += pmat_
                  [(((((oc / 8) * K0 + kh) * K1 + kw) * (IC / 4) + ic / 4) * 8 +
                    (oc % 8)) *
                       4 +
                   (ic % 4)];
              count[(oh * OUT_DIM1 + ow) * OC + oc]++;
            }
          }
        }
      }
    }
  }

  for (int oc = 0; oc < OC; oc++) {
    for (int oh = 0; oh < OUT_DIM0; oh++) {
      for (int ow = 0; ow < OUT_DIM1; ow++) {
        col_offsets[(oh * OUT_DIM1 + ow) * OC + oc] -=
            B_zero_point[oc / ncols_per_quant_group] *
            count[(oh * OUT_DIM1 + ow) * OC + oc];
      }
    }
  }

  first_call = false;
}

template FBGEMM_API void
PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<1>(
    const fbgemm::conv_param_t<1>& conv_p,
    std::int32_t* B_zero_point,
    std::vector<int32_t>& col_offsets,
    int ncols_per_quant_group);

template FBGEMM_API void
PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<2>(
    const fbgemm::conv_param_t<2>& conv_p,
    std::int32_t* B_zero_point,
    std::vector<int32_t>& col_offsets,
    int ncols_per_quant_group);

template FBGEMM_API void
PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<3>(
    const fbgemm::conv_param_t<3>& conv_p,
    std::int32_t* B_zero_point,
    std::vector<int32_t>& col_offsets,
    int ncols_per_quant_group);

template <int SPATIAL_DIM>
void directConvRowSum(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const uint8_t* A,
    int32_t* inSum,
    int32_t* rowSum) {
  int IN0 = conv_p.IN_DIM[0];
  int IN1 = conv_p.IN_DIM[1];
  int IC = conv_p.IC;
  int K0 = conv_p.K[0];
  int K1 = conv_p.K[1];
  int OUT0 = conv_p.OUT_DIM[0];
  int OUT1 = conv_p.OUT_DIM[1];
  int stride = conv_p.stride[1];

  memset(rowSum, 0, sizeof(int32_t) * OUT0 * OUT1);

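  // Two passes: first reduce each input pixel over its IC channels
  // (inSum[ih][iw] = sum over ic of A[ih][iw][ic]); then scatter each inSum
  // value into every output pixel touched by that input pixel's K0 x K1 kernel
  // footprint (the stride is applied along the width dimension here). The
  // resulting rowSum is consumed later as the per-output-pixel row offset
  // during requantization.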
  for (int ih = 0; ih < IN0; ++ih) {
    for (int iw = 0; iw < IN1; ++iw) {
      inSum[ih * IN1 + iw] = reduceAvx2(A + ih * IN1 * IC + iw * IC, IC);
    }
  }

  for (int ih = 0; ih < IN0; ++ih) {
    for (int iw = 0; iw < IN1; iw++) {
      for (int r = 0; r < K0; ++r) {
        for (int s = 0; s < K1; ++s) {
          rowSum[(ih + r) * OUT1 + iw * stride + s] += inSum[ih * IN1 + iw];
        }
      }
    }
  }
  /*
  compare_buffers(
      rowSum,
      rowoffsets,
      OUT0,
      OUT1,
      OUT1,
      5);
  */
}

template void directConvRowSum<1>(
    const conv_param_t<1>& conv_p,
    const uint8_t* A,
    int32_t* inSum,
    int32_t* rowSum);

template void directConvRowSum<2>(
    const conv_param_t<2>& conv_p,
    const uint8_t* A,
    int32_t* inSum,
    int32_t* rowSum);

template void directConvRowSum<3>(
    const conv_param_t<3>& conv_p,
    const uint8_t* A,
    int32_t* inSum,
    int32_t* rowSum);

template <
    int SPATIAL_DIM,
    QuantizationGranularity Q_GRAN,
    bool FUSE_RELU,
    typename BIAS_TYPE>
void fbgemmDirectConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const uint8_t* Aint8,
    PackedDirectConvMatrix& Bint8_tr,
    uint8_t* C,
    int32_t* C_buffer,
    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
    const BIAS_TYPE* bias,
    // const int32_t* bias,
    int thread_id,
    int num_threads) {
  // Only single-threaded execution is supported for now; multithreading will
  // be enabled later, so every thread other than thread 0 returns immediately.
  if (thread_id > 0 || thread_id >= num_threads) {
    return;
  }

  if (SPATIAL_DIM != 2) {
    assert(false && "1d/3d direct conv not supported");
  } else {
    if (conv_p.transposed) {
      DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
          jit_micro_kernel_fp_convT fn;
      DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
      /*
      fn = codeObj.getOrCreateDirectConvTrans<inst_set_t::avx2>(
          true, conv_p.stride[1]);
      */
      fn = codeObj.getOrCreateDirectConvTrans<inst_set_t::avx2>(
          true, conv_p.stride[1], conv_p.K[1]);

      int32_t* inSum = static_cast<int32_t*>(fbgemmAlignedAlloc(
          64, conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * sizeof(int32_t)));
      int32_t* rowSum = static_cast<int32_t*>(fbgemmAlignedAlloc(
          64, conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * sizeof(int32_t)));

      directConvRowSum(conv_p, Aint8, inSum, rowSum);
      int kernel_dim = conv_p.K[0] * conv_p.K[1];

      std::memset(
          C_buffer,
          0,
          sizeof(int32_t) * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
      std::memset(
          C,
          0,
          sizeof(int8_t) * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
      // no-op output process objects
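      // The JIT'ed micro-kernel is invoked once per (8-wide output-channel
      // block, input row) pair: it reads one input row, the packed weights for
      // that OC block, and accumulates int32 results into C_buffer without any
      // output processing. The sixth argument,
      // (conv_p.OC * conv_p.OUT_DIM[1] - conv_p.OC * conv_p.K[1]) * 4, is
      // presumably the byte stride the kernel uses to step between output rows.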
      for (int i = 0; i < conv_p.OC; i += 8) {
        for (int j = 0; j < conv_p.IN_DIM[0]; j++) {
          fn(Aint8 + j * conv_p.IC * conv_p.IN_DIM[1],
             Bint8_tr.PackedMat() + i * kernel_dim * conv_p.IC,
             C_buffer + j * conv_p.OUT_DIM[1] * conv_p.OC + i,
             conv_p.IC,
             conv_p.OC,
             (conv_p.OC * conv_p.OUT_DIM[1] - conv_p.OC * conv_p.K[1]) * 4,
             conv_p.IN_DIM[1]);
        }
      }

      int32_t A_zero_point = outProcess.getAZeroPoint();
      const int32_t* B_zero_point = outProcess.getBZeroPoint();
      // const float* C_multiplier = outProcess.getCMultiplier();
      const int32_t* col_offsets = outProcess.getColOffsets();

      /*
      int groups = 1;
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        groups = conv_p.OC;
      }
      */
      requantizationParams_t<BIAS_TYPE> reqObj = {
          outProcess.getAZeroPoint(),
          outProcess.getBZeroPoint(),
          outProcess.getCZeroPoint(),
          outProcess.getCMultiplier(),
          rowSum, // rowOffsetBuf
          outProcess.getColOffsets(),
          (outProcess.getBias()),
          static_cast<std::uint32_t>(conv_p.OC), // outProcess.getNCols()
          1, // groups
          outProcess.getActWScale()};

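      // The requantization call is specialized on compile-time knobs:
      //   A_SYMMETRIC: true when A_zero_point == 0 or no col_offsets are given,
      //   B_SYMMETRIC: true for tensor-granularity quantization with
      //                B_zero_point[0] == 0,
      //   HAS_BIAS:    whether a bias pointer was supplied,
      //   and a trailing `true`, which appears to select the direct-conv
      //   variant of the requantization kernel.
      // The eight branches below differ only in those template arguments; the
      // runtime arguments (one block spanning OUT_DIM[0] * OUT_DIM[1] rows and
      // OC columns) are identical.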
      // Dispatch HAS_BIAS
      if (bias == nullptr) {
        // Dispatch A_SYMMETRIC and B_SYMMETRIC
        if (A_zero_point == 0 || col_offsets == nullptr) {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                true,
                true,
                QuantizationGranularity::TENSOR,
                false, // HAS_BIAS
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                true,
                false,
                Q_GRAN,
                false, // HAS_BIAS
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        } else {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                false,
                true,
                QuantizationGranularity::TENSOR,
                false, // HAS_BIAS
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                false,
                false,
                Q_GRAN,
                false, // HAS_BIAS
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        }
      } else { // has_bias == true

        // Dispatch A_SYMMETRIC and B_SYMMETRIC
        if (A_zero_point == 0 || col_offsets == nullptr) {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                true,
                true,
                QuantizationGranularity::TENSOR,
                true, // HAS_BIAS
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                true,
                false,
                Q_GRAN,
                true, // HAS_BIAS
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        } else {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                false,
                true,
                QuantizationGranularity::TENSOR,
                true, // HAS_BIAS
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                false,
                false,
                Q_GRAN,
                true, // HAS_BIAS
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        }
      }
      fbgemmAlignedFree(inSum);
      fbgemmAlignedFree(rowSum);
    } // transposed conv
    else { // non-transposed conv
      assert(false && "non-transposed direct conv not integrated yet.");
    }
  } // else SPATIAL_DIM
}
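
// A minimal usage sketch for the transposed-conv path (hedged: the
// conv_param_t construction and the ReQuantizeOutput setup are only outlined,
// since they are configured outside this file; buffer names are illustrative):
//
//   // conv_p: a conv_param_t<2> with transposed == true, describing
//   // MB, IC, OC, IN_DIM, OUT_DIM, K, stride, etc.
//   PackedDirectConvMatrix packedB(
//       conv_p.IC / conv_p.G,
//       conv_p.OC / conv_p.G,
//       conv_p.K[0] * conv_p.K[1],
//       Bint8 /* original int8 weights */);
//   packedB.col_offsets_with_zero_pt_s8acc32_DirectConvT<2>(
//       conv_p, B_zero_point, col_offsets, ncols_per_quant_group);
//   // outProcess: a ReQuantizeOutput<...> built from the zero points,
//   // multipliers, col_offsets and (optional) bias.
//   // Cuint8 and C_buffer each hold OUT_DIM[0] * OUT_DIM[1] * OC elements.
//   fbgemmDirectConv<2, QuantizationGranularity::TENSOR, false, std::int32_t>(
//       conv_p, Aint8, packedB, Cuint8, C_buffer, outProcess,
//       /*bias=*/nullptr, /*thread_id=*/0, /*num_threads=*/1);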
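
// Explicit instantiations: every combination of spatial dim {1, 2, 3},
// quantization granularity {TENSOR, GROUP, OUT_CHANNEL}, FUSE_RELU
// {true, false}, and bias type {float, int32_t} -- 36 instantiations in total.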
#define INSTANTIATE_REQUANTIZE_SPATIAL_DIM(                        \
    SPATIAL_DIM, Q_GRAN, RELU, BIAS_TYPE)                          \
  template void FBGEMM_API                                         \
  fbgemmDirectConv<SPATIAL_DIM, Q_GRAN, RELU, BIAS_TYPE>(          \
      const conv_param_t<SPATIAL_DIM>& conv_p,                     \
      const uint8_t* Aint8,                                        \
      PackedDirectConvMatrix& Bint8_tr,                            \
      uint8_t* C,                                                  \
      int32_t* C_buffer,                                           \
      const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
      const BIAS_TYPE* bias,                                       \
      int thread_id,                                               \
      int num_threads);

#define INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, BIAS_TYPE) \
  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(1, Q_GRAN, RELU, BIAS_TYPE)  \
  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(2, Q_GRAN, RELU, BIAS_TYPE)  \
  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(3, Q_GRAN, RELU, BIAS_TYPE)

#define INSTANTIATE_REQUANTIZE(Q_GRAN, RELU)            \
  INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, float) \
  INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, int32_t)

#define INSTANTIATE_Q_GRANS(RELU)                                \
  INSTANTIATE_REQUANTIZE(QuantizationGranularity::TENSOR, RELU)  \
  INSTANTIATE_REQUANTIZE(QuantizationGranularity::GROUP, RELU)   \
  INSTANTIATE_REQUANTIZE(QuantizationGranularity::OUT_CHANNEL, RELU)

INSTANTIATE_Q_GRANS(true)
INSTANTIATE_Q_GRANS(false)

#undef INSTANTIATE_REQUANTIZE_SPATIAL_DIM
#undef INSTANTIATE_REQUANTIZE_BIAS_TYPE
#undef INSTANTIATE_REQUANTIZE
#undef INSTANTIATE_Q_GRANS

} // namespace fbgemm