/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "./GroupwiseConv.h"
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <array>
#include <iostream>
#include <map>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include "./CodeGenHelpers.h"
#include "./RefImplementations.h"
#include "./TransposeUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/QuantUtilsAvx512.h"
#include "fbgemm/SimdUtils.h"

namespace fbgemm {

using namespace std;

template <int SPATIAL_DIM>
void calculateRowOffsets(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    const uint8_t* activations,
    int32_t* rowOffsetBuf,
    int32_t a_zero_point,
    int groupNum) {
  int OH = conv_param.OUT_DIM[0];
  int OW = conv_param.OUT_DIM[1];
  int IH = conv_param.IN_DIM[0];
  int IW = conv_param.IN_DIM[1];
  int G = conv_param.G;
  int C_per_G = conv_param.IC / conv_param.G;
  int H_PAD = conv_param.pad[0];
  int W_PAD = conv_param.pad[1];
  // calculate row offset
  for (int h = 0; h < OH; ++h) {
    for (int w = 0; w < OW; ++w) {
      int32_t sum = 0;
      for (int r = 0; r < conv_param.K[0]; ++r) {
        int h_in = -H_PAD + h * conv_param.stride[0] + r;
        for (int s = 0; s < conv_param.K[1]; ++s) {
          int w_in = -W_PAD + w * conv_param.stride[1] + s;
          for (int c = 0; c < C_per_G; ++c) {
            int a_val;
            if (h_in < 0 || h_in >= IH || w_in < 0 || w_in >= IW) {
              a_val = a_zero_point;
            } else {
              a_val = activations
                  [((h_in * IW + w_in) * G + groupNum) * C_per_G + c];
            }
            sum += a_val;
          }
        }
      }
      rowOffsetBuf[h * OW + w] = sum;
    }
  }
}
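
// Worked example for the reference above (illustrative values, not from any
// particular model): with a 3x3 filter and C_per_G == 4, each
// rowOffsetBuf[h * OW + w] is the sum of the 3 * 3 * 4 = 36 uint8 activation
// values in the receptive field of output pixel (h, w) for the given group;
// positions that fall outside the image contribute a_zero_point instead.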

template <int SPATIAL_DIM = 2>
kernel_sig_t getKernelSig(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    bool isAZeroPointZero,
    bool needRowOffset,
    bool isTopEdgeIncluded,
    bool isBottomEdgeIncluded,
    bool isTopBottomEdgeSame,
    bool accum) {
  // The kernel is specialized on the number of input channels per group, the
  // number of output channels per group, whether the stride is 1 or 2,
  // whether the activation zero point is 0, whether row offset calculations
  // are needed, and whether the top and/or bottom edges are included.
  // use_padding_: if false, the right padding on the width side and the
  // bottom padding on the height side are not used (stride == 2 case only).
  // accum: accumulate results into the output and row offset buffers.
  int C_per_G = conv_param.IC / conv_param.G;
  int K_per_G = conv_param.OC / conv_param.G;
  auto kernelSig = make_tuple(
      isAZeroPointZero,
      needRowOffset,
      isTopEdgeIncluded,
      isBottomEdgeIncluded,
      isTopBottomEdgeSame,
      !(conv_param.stride[SPATIAL_DIM - 2] > 1 &&
        conv_param.IN_DIM[SPATIAL_DIM - 2] % 2 == 0),
      !(conv_param.stride[SPATIAL_DIM - 1] > 1 &&
        conv_param.IN_DIM[SPATIAL_DIM - 1] % 2 == 0),
      accum,
      conv_param.G,
      conv_param.stride[0],
      C_per_G,
      K_per_G);
  return kernelSig;
}
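
// For illustration only (an assumed configuration, not taken from the code
// above): a 3x3 groupwise convolution with G == 8, stride 1, C_per_G ==
// K_per_G == 4, a zero activation zero point, row offsets needed, both edges
// included, and no accumulation would produce a signature tuple roughly like
// (true, true, true, true, false, true, true, false, 8, 1, 4, 4), where the
// two `true` padding flags follow from stride == 1.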

template <int SPATIAL_DIM = 2>
jit_conv_kernel_fp getOrCreateConvKernel(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    int a_zero_point,
    bool needRowOffset,
    bool isTopEdgeIncluded,
    bool isBottomEdgeIncluded,
    bool isTopBottomEdgeSame,
    bool accum) {
  // Note: wrong code is generated if this is not one of the supported
  // convolution shapes, so assert that it is.
  assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
  auto kernelSig = getKernelSig(
      conv_param,
      a_zero_point == 0,
      needRowOffset,
      isTopEdgeIncluded,
      isBottomEdgeIncluded,
      isTopBottomEdgeSame,
      accum);

  if (cpuinfo_initialize()) {
    if (fbgemmHasAvx512VnniSupport()) {
      return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>::codeCache_
          .getOrCreate(kernelSig, [&]() {
            auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>(
                conv_param,
                a_zero_point,
                needRowOffset,
                isTopEdgeIncluded,
                isBottomEdgeIncluded,
                isTopBottomEdgeSame,
                accum);
            return genObj.getOrCreate();
          });
    } else if (fbgemmHasAvx512Support()) {
      return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512>::codeCache_
          .getOrCreate(kernelSig, [&]() {
            auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx512>(
                conv_param,
                a_zero_point,
                needRowOffset,
                isTopEdgeIncluded,
                isBottomEdgeIncluded,
                isTopBottomEdgeSame,
                accum);
            return genObj.getOrCreate();
          });
    } else if (fbgemmHasAvx2Support()) {
      return GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>::codeCache_
          .getOrCreate(kernelSig, [&]() {
            auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>(
                conv_param,
                a_zero_point,
                needRowOffset,
                isTopEdgeIncluded,
                isBottomEdgeIncluded,
                isTopBottomEdgeSame,
                accum);
            return genObj.getOrCreate();
          });
    } else {
      // TODO: Have default slower path
      assert(0 && "unsupported architecture");
    }
  } else {
    throw runtime_error("Failed to initialize cpuinfo!");
  }
  return nullptr;
}

template <int SPATIAL_DIM, inst_set_t INST_SET>
jit_conv_kernel_fp GenConvKernel<SPATIAL_DIM, INST_SET>::getOrCreate() {
  asmjit::CodeHolder code;
  code.init(this->runtime().environment());
  x86::Assembler assembler(&code);
  x86::Emitter* a = assembler.as<x86::Emitter>();

  typedef typename simd_info<INST_SET>::vec_reg_t vec_reg_t;
#if defined(FBGEMM_LOG_CODE)
  auto kernelSig = make_tuple(
      this->isAZeroPointZero_,
      this->needRowOffset_,
      this->isTopEdgeIncluded_,
      this->isBottomEdgeIncluded_,
      this->use_bottom_padding_,
      this->use_right_padding_,
      this->accum_,
      this->G_,
      this->STRIDE_,
      this->C_per_G_,
      this->K_per_G_);
  // log code to a file
  FILE* codeLogfile = fopen(this->getCodeLoggingFile(kernelSig).c_str(), "w");
  asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
  if (codeLogger) {
    code.setLogger(codeLogger);
  }
#endif

  // arguments to the function created
  in_acts_R_ = a->zdi();
  wghts_R_ = a->zsi();
  out_acts_R_ = a->zdx();
  a_zero_pt_R_ = a->zcx();
  H_start_R_ = a->gpz(8);
  H_end_R_ = a->gpz(9);
  W_R_ = a->gpz(10);
  row_offset_R_ = a->gpz(11);

  // registers for temporary use
  scratchReg1_ = a->gpz(12);
  scratchReg2_ = a->gpz(13);

  func_.init(
      asmjit::FuncSignatureT<
          void,
          uint8_t*,
          int8_t*,
          int32_t*,
          int32_t,
          int32_t,
          int32_t,
          int32_t,
          int32_t*>(asmjit::CallConvId::kHost),
      a->environment());

  frame_.init(func_);

  frame_.setDirtyRegs(
      asmjit::RegGroup::kVec,
      asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
          asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
  frame_.setDirtyRegs(
      asmjit::RegGroup::kGp,
      asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

  asmjit::FuncArgsAssignment args(&func_);
  args.assignAll(
      in_acts_R_,
      wghts_R_,
      out_acts_R_,
      a_zero_pt_R_,
      H_start_R_,
      H_end_R_,
      W_R_,
      row_offset_R_);

  args.updateFuncFrame(frame_);
  frame_.finalize();

  a->emitProlog(frame_);
  a->emitArgsAssignment(frame_, args);

  // We have run out of registers, so we can't keep
  // this value in one. It is regenerated at
  // each use. Only used for the case of C_per_G == 2 or 4.
  // gen8BitVectorOne(a, oneReg8Bit_V_);
  gen16BitVectorOne<INST_SET, vec_reg_t>(a, oneReg16Bit_V_);

  loopR1_ = a->gpz(14);
  loopR2_ = a->gpz(15);

  if (!this->isAZeroPointZero_) {
    broadcast8Bit<vec_reg_t>(a, a_zero_pt_R_, zeroPTReg_V_);
  }

  genConstForPermutations(a);

  genForLoadingWeights(a);

  // W_R_ is an input to the JIT'ed kernel and is the output image width.
  // W_R_ is passed in on the stack; we reload it inside the kernel where we
  // need it. The following logic computes the input image width in the same
  // register. Only works for stride == 2.
  if (this->STRIDE_ > 1) {
    a->imul(W_R_, W_R_, static_cast<asmjit::Imm>(this->STRIDE_));
    if (!this->use_right_padding_) {
      a->inc(W_R_);
    }
    a->sub(W_R_, static_cast<asmjit::Imm>(this->STRIDE_ - 1));
  }
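
  // Worked example of the width arithmetic above (values chosen for
  // illustration): with STRIDE_ == 2 and an output width of 4 in W_R_, an
  // even input width (use_right_padding_ == false) gives 4 * 2 = 8, inc to
  // 9, minus (2 - 1) = 8; an odd input width (use_right_padding_ == true)
  // gives 4 * 2 = 8, minus (2 - 1) = 7.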

  if (this->isTopEdgeIncluded_) {
    genForTopOrBottomEdge(
        a,
        true /* isTopEdge */,
        this->isTopBottomEdgeSame_ && this->use_bottom_padding_);
  }
  genCoreInsts(a);
  if (this->isBottomEdgeIncluded_ && !this->isTopBottomEdgeSame_) {
    genForTopOrBottomEdge(
        a, false /* isTopEdge */, this->use_bottom_padding_ /* isBottomEdge */);
  }

  a->emitEpilog(frame_);

  jit_conv_kernel_fp fn;
  asmjit::Error err;
  {
    unique_lock<mutex> lock(this->rtMutex_);
    err = this->runtime().add(&fn, &code);
  }

  if (err) {
    cout << "Error: in fn add" << endl;
    return nullptr;
  }

#if defined(FBGEMM_LOG_CODE)
  fclose(codeLogfile);
  delete codeLogger;
#endif

  return fn;
}

template <int SPATIAL_DIM, inst_set_t INST_SET>
void GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleOutput(
    x86::Emitter* a,
    bool isLeft,
    bool isRight,
    bool isTop,
    bool isBottom) {
  // init result regs
  initResultRegs(a);

  // row offset
  if (this->needRowOffset_) {
    a->vpxor(
        rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm());
  }

  bool isWidthMiddle = !isLeft && !isRight;
  bool isHeightMiddle = !isTop && !isBottom;
  int num_rows_advanced = 0;
  for (int r = 0; r < this->R_; ++r) {
    int h_in = r;
    if (isTop) {
      h_in = -this->H_PAD_ + r;
    }
    bool in_image_H = (isTop && !isBottom && h_in >= 0) ||
        (!isTop && isBottom && h_in < (this->R_ - this->H_PAD_)) ||
        (isTop && isBottom && h_in >= 0 &&
         h_in < (this->R_ - 2 * this->H_PAD_)) ||
        isHeightMiddle;
    for (int s = 0; s < this->S_; ++s) {
      int w_in = s;
      if (isLeft) {
        w_in = -this->W_PAD_ + s;
      }
      bool in_image_W = (isLeft && !isRight && w_in >= 0) ||
          (!isLeft && isRight && w_in < (this->S_ - this->W_PAD_)) ||
          (isLeft && isRight && w_in >= 0 &&
           w_in < (this->S_ - 2 * this->W_PAD_)) ||
          isWidthMiddle;
      if (in_image_H && in_image_W) {
        genForSingleFilterPoint(a, r, s, w_in, false);
      } else {
        if (!this->isAZeroPointZero_) {
          genForSingleFilterPoint(a, r, s, w_in, true);
        }
      }
    }
    if (in_image_H) {
      // advance input pointer by one row
      a->imul(
          scratchReg2_,
          W_R_,
          static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
      a->add(in_acts_R_, scratchReg2_);
      ++num_rows_advanced;
    }
  }

  storeResult(a);

  // row offset
  if (this->needRowOffset_) {
    storeOffset(a);
    a->add(
        row_offset_R_, static_cast<asmjit::Imm>(GTogether_ * sizeof(int32_t)));
  }

  // rewind input ptr
  a->imul(
      scratchReg2_,
      W_R_,
      static_cast<asmjit::Imm>(num_rows_advanced * this->C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg2_);

  // advance output pointer
  a->add(out_acts_R_, static_cast<asmjit::Imm>(this->K_ * sizeof(int32_t)));

  // advance input ptr
  if (!isLeft) {
    a->add(
        in_acts_R_,
        static_cast<asmjit::Imm>(this->STRIDE_ * this->C_ * sizeof(uint8_t)));
  } else if (this->STRIDE_ - this->W_PAD_) {
    a->add(
        in_acts_R_,
        static_cast<asmjit::Imm>(
            (this->STRIDE_ - this->W_PAD_) * this->C_ * sizeof(uint8_t)));
  }
}

template <int SPATIAL_DIM, inst_set_t INST_SET>
void GenConvKernel<SPATIAL_DIM, INST_SET>::genForTopOrBottomEdge(
    x86::Emitter* a,
    bool isTopEdge,
    bool isBottomEdge) {
  // Output width was passed in as the 7th argument (i.e., on the stack).
  // Reload it from the same location.
  a->movsxd(
      loopR1_,
      x86::dword_ptr(
          x86::rsp, frame_.saOffsetFromSP() + func_.arg(6).stackOffset()));
  asmjit::Label LoopWStart = a->newLabel();
  asmjit::Label LoopWEnd = a->newLabel();
  asmjit::Label skipRightEdge = a->newLabel();
  asmjit::Label skipRightEdgeTemp = a->newLabel();
  a->cmp(loopR1_, static_cast<asmjit::Imm>(this->W_PAD_));
  a->jle(skipRightEdgeTemp);

  // left corner code
  genForSingleOutput(
      a,
      true, // isLeft
      false, // isRight
      isTopEdge, // isTop
      isBottomEdge // isBottom
  );
  a->jmp(LoopWStart);

  a->bind(skipRightEdgeTemp);
  // narrow case: the left and right corners overlap
  genForSingleOutput(
      a,
      true, // isLeft
      this->use_right_padding_, // isRight
      isTopEdge, // isTop
      isBottomEdge // isBottom
  );
  a->jmp(skipRightEdge);

  // edge excluding corners
  a->bind(LoopWStart);

  a->cmp(loopR1_, static_cast<asmjit::Imm>(2 * this->W_PAD_));
  a->jle(LoopWEnd);

  genForSingleOutput(
      a,
      false, // isLeft
      false, // isRight
      isTopEdge, // isTop
      isBottomEdge // isBottom
  );

  a->dec(loopR1_);
  a->jmp(LoopWStart);
  a->bind(LoopWEnd);

  // right corner code
  genForSingleOutput(
      a,
      false, // isLeft
      this->use_right_padding_, // isRight
      isTopEdge, // isTop
      isBottomEdge // isBottom
  );

  a->bind(skipRightEdge);

  if (this->STRIDE_ > 1) {
    // For STRIDE_ == 2 and even widths,
    // we increase the input pointer by an extra C_;
    // for STRIDE_ == 2 and odd widths, there is nothing to do:
    // the input ptr is already at the right position.
    if (!this->use_right_padding_) {
      a->add(in_acts_R_, static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
    }
  } else {
    // reset input activation pointer by (W_R_ - W_PAD_) * C_
    a->mov(scratchReg2_, W_R_);
    a->imul(scratchReg2_, static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
    a->sub(
        scratchReg2_,
        static_cast<asmjit::Imm>(this->W_PAD_ * this->C_ * sizeof(uint8_t)));
    a->sub(in_acts_R_, scratchReg2_);
  }
}

template <int SPATIAL_DIM, inst_set_t INST_SET>
void GenConvKernel<SPATIAL_DIM, INST_SET>::genCoreInsts(x86::Emitter* a) {
  // Top-edge and bottom-edge calculations are done separately,
  // so start from the next row and leave out the last one.
  if (this->isTopEdgeIncluded_) {
    a->inc(H_start_R_);
  }
  if (this->isBottomEdgeIncluded_) {
    a->dec(H_end_R_);
  }
  // main compute
  asmjit::Label LoopHStart = a->newLabel();
  asmjit::Label LoopHEnd = a->newLabel();
  asmjit::Label LoopWStart = a->newLabel();
  asmjit::Label LoopWEnd = a->newLabel();

  // H loop
  a->mov(loopR1_, H_start_R_);
  a->jmp(LoopHEnd);
  a->bind(LoopHStart);
  a->inc(loopR1_);

  a->movsxd(
      loopR2_,
      x86::dword_ptr(
          x86::rsp, frame_.saOffsetFromSP() + func_.arg(6).stackOffset()));
  asmjit::Label skipRightEdge = a->newLabel();
  asmjit::Label skipRightEdgeTemp = a->newLabel();
  a->cmp(loopR2_, static_cast<asmjit::Imm>(this->W_PAD_));
  a->jle(skipRightEdgeTemp);

  genForSingleOutput(
      a,
      true, // isLeft,
      false, // isRight
      false, // isTop
      false // isBottom
  );
  a->jmp(LoopWStart);

  a->bind(skipRightEdgeTemp);
  genForSingleOutput(
      a,
      true, // isLeft,
      this->use_right_padding_, // isRight
      false, // isTop
      false // isBottom
  );
  a->jmp(skipRightEdge);

  // W loop
  a->bind(LoopWStart);

  a->cmp(loopR2_, static_cast<asmjit::Imm>(2 * this->W_PAD_));
  a->jle(LoopWEnd);

  genForSingleOutput(
      a,
      false, // isLeft,
      false, // isRight
      false, // isTop
      false // isBottom
  );

  a->dec(loopR2_);
  a->jmp(LoopWStart);
  a->bind(LoopWEnd);

  genForSingleOutput(
      a,
      false, // isLeft
      this->use_right_padding_, // isRight
      false, // isTop
      false // isBottom
  );

  a->bind(skipRightEdge);

  if (this->STRIDE_ > 1) {
    // For STRIDE_ == 2 and even widths,
    // we advance by an extra C_;
    // for STRIDE_ == 2 and odd widths, no extra C_.
    assert(this->STRIDE_ == 2 && "Not supported case");
    a->mov(scratchReg2_, W_R_);
    if (!this->use_right_padding_) {
      a->add(scratchReg2_, static_cast<asmjit::Imm>(1));
    }
    a->imul(scratchReg2_, static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
    a->add(in_acts_R_, scratchReg2_);
  } else {
    a->add(in_acts_R_, static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
  }

  a->bind(LoopHEnd);
  a->cmp(loopR1_, H_end_R_);
  a->jl(LoopHStart);
}

template <int SPATIAL_DIM, inst_set_t INST_SET>
void GenConvKernel<SPATIAL_DIM, INST_SET>::initResultRegs(x86::Emitter* a) {
  if (kLoopIters_ > 0) {
    // Take advantage of implicit zeroing out:
    // zeroing out an xmm zeroes out the corresponding ymm and zmm too
    for (int k = 0; k < kLoopIters_; ++k) {
      a->vpxor(x86::Xmm(9 - k), x86::Xmm(9 - k), x86::Xmm(9 - k));
    }
  } else {
    a->vpxor(x86::Xmm(9), x86::Xmm(9), x86::Xmm(9));
  }
}
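
// Illustration of initResultRegs above (a reading of the generated code, not
// additional behavior): with kLoopIters_ == 2 it emits vpxor on xmm9 and
// xmm8; since VEX/EVEX-encoded writes to an xmm register zero the upper bits
// of the corresponding ymm/zmm register, the full-width accumulators start
// out at zero.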

/*
namespace {

template <
    typename packed_W,
    typename outType,
    typename processOutputType,
    int SPATIAL_DIM>
void fbgemmGroupwiseConvBase_(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    const uint8_t* activations,
    int32_t a_zero_point,
    int32_t* rowOffsetBuf,
    packed_W& packed_weights,
    outType* out,
    int32_t* outBuffer,
    const processOutputType& outProcess,
    int thread_id,
    int num_threads) {
  int MB = conv_param.MB;
  int H = conv_param.OUT_DIM[0];
  int W = conv_param.OUT_DIM[1];
  int G = conv_param.G;
  int K_per_G = conv_param.OC / G;
  int C_per_G = conv_param.IC / G;
  int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
  int ih_iw = conv_param.IN_DIM[0] * conv_param.IN_DIM[1];

  assert(SPATIAL_DIM == 2 && "3D groupwise conv not supported yet");

  int32_t* rowOffsetTrDest = rowOffsetBuf ? rowOffsetBuf + 8 * ih_iw : nullptr;
  if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) {
    assert(G % 8 == 0);
    // generate convolution kernel
    jit_conv_kernel_fp fpConv =
        getOrCreateConvKernel<SPATIAL_DIM>(conv_param, a_zero_point);
    // generate row offset kernel
    jit_rowoffset_kernel_fp fpRowoffset =
        getOrCreateRowOffsetKernel(conv_param, a_zero_point);
    for (int i = 0; i < MB; ++i) {
      const uint8_t* actStartBatch = activations + i * ih_iw * conv_param.IC;
      for (int gOuter = 0; gOuter < G; gOuter += 8) {
        // row offsets are calculated for 8 groups at a time.
        // The result is row offsets in the format IH*IW x G
        if (rowOffsetBuf) {
          fpRowoffset(
              actStartBatch + gOuter * C_per_G,
              a_zero_point,
              H,
              W,
              rowOffsetBuf);
          // Transpose to get row offsets in the format G x IH*IW
          internal::transpose_avx2(
              ih_iw,
              8,
              reinterpret_cast<const float*>(rowOffsetBuf),
              8,
              reinterpret_cast<float*>(rowOffsetTrDest),
              ih_iw);
        }
        int gLimit = gOuter + 8;
        // Work on 8 output channels at a time (8 * sizeof(int32_t) == 32B
        // VLEN of AVX2), and we need multiple groups if a group does not
        // have enough channels.
        int gDelta = std::max(8 / C_per_G, 1);
        for (int g = gOuter; g < gLimit; g += gDelta) {
          int32_t* currOutBuf =
              outBuffer + i * oh_ow * conv_param.OC + g * K_per_G;
          const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
          for (int k = 0; k < K_per_G; k += 8) {
            // Don't confuse this with k above, which refers to output
            // channels: k0 and k1 are filter dimensions (commonly 3 and 3).
            int k0 = conv_param.K[0];
            int k1 = conv_param.K[1];
            fpConv(
                actStartGroup,
                // packed weight is in G (C/4) R S K 4 layout for IC_per_G >= 8
                // and in (G/2) R S K (2C) layout for IC_per_G == 4
                packed_weights.getBuf() +
                    (g * (C_per_G / 4) * k0 * k1 * K_per_G + k) * 4,
                currOutBuf + k,
                a_zero_point,
                H,
                W);
          } // k loop

          // Output processing should be called for each group
          for (int j = 0; j < gDelta; ++j) {
            // calculateRowOffsets(
            //     conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j);
            int32_t* rowOffsetForCurG = rowOffsetTrDest
                ? rowOffsetTrDest + ((g - gOuter) + j) * ih_iw
                : nullptr;
            // compare_buffers(rowOffsetBuf, rowOffsetForCurG,
            //     conv_param.IN_DIM[0]*conv_param.IN_DIM[1], 1, 1, 100);

            // outProcess expects rowOffsetBuf to contain row offsets for the
            // current group
            memcpy(rowOffsetBuf, rowOffsetForCurG, ih_iw * sizeof(int32_t));

            if (fbgemmHasAvx512Support()) {
              // Currently use avx2 code
              outProcess.template f<inst_set_t::avx2>(
                  out,
                  currOutBuf + j * K_per_G,
                  {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G},
                  K_per_G * G,
                  K_per_G * G);
            } else if (fbgemmHasAvx2Support()) {
              outProcess.template f<inst_set_t::avx2>(
                  out,
                  currOutBuf + j * K_per_G,
                  {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G},
                  K_per_G * G,
                  K_per_G * G);
            } else {
              // TODO: Have default slower path
              assert(0 && "unsupported architecture");
            }
          } // j loop
        } // g loop
      } // gOuter loop
    } // i loop
  } else {
    // for the not supported cases, just execute the naive C implementation
    conv_ref(
        conv_param,
        activations,
        a_zero_point,
        packed_weights.getBuf(),
        outBuffer);
    for (int i = 0; i < conv_param.MB; ++i) {
      for (int g = 0; g < conv_param.G; ++g) {
        if (rowOffsetBuf) {
          calculateRowOffsets(
              conv_param,
              activations +
                  i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] *
                      conv_param.IC,
              rowOffsetBuf,
              a_zero_point,
              g);
        }
        outProcess.template f<inst_set_t::anyarch>(
            out,
            outBuffer + i * oh_ow * conv_param.OC + g * K_per_G,
            {i * oh_ow, oh_ow, g * K_per_G, K_per_G},
            K_per_G * G,
            K_per_G * G);
      }
    }
  }
}

} // namespace

template <
    typename packed_W,
    typename outType,
    typename processOutputType,
    int SPATIAL_DIM>
void fbgemmGroupwiseConv(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    const uint8_t* activations,
    int32_t a_zero_point,
    int32_t* rowOffsetBuf,
    packed_W& packed_weights,
    outType* out,
    int32_t* outBuffer,
    const processOutputType& outProcess,
    int thread_id,
    int num_threads) {
  // TODO: Remove this when threading is supported.
  if (thread_id > 0) {
    return;
  }

  return fbgemmGroupwiseConvBase_<
      packed_W,
      outType,
      processOutputType,
      SPATIAL_DIM>(
      conv_param,
      activations,
      a_zero_point,
      rowOffsetBuf,
      packed_weights,
      out,
      outBuffer,
      outProcess,
      thread_id,
      num_threads);
}

*/

/*
 * This function performs exactly the same computation as the JIT'ed kernel.
 */
template <int SPATIAL_DIM>
void kernel_compute(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const uint8_t* in_acts,
    int8_t* wghts,
    int32_t* out_acts,
    int32_t a_zero_pt,
    int32_t h_start,
    int32_t h_end,
    int32_t width,
    int32_t* rowOffset,
    bool accum) {
  int IW = conv_p.IN_DIM[1];
  int IC = conv_p.IC;
  int OC = conv_p.OC;
  int G = conv_p.G;
  int R = conv_p.K[0];
  int S = conv_p.K[1];
  int IC_per_G = conv_p.IC / G;
  int OC_per_G = conv_p.OC / G;
  int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
      numOfGroupsTogether(conv_p);
  int paddedICPerG = (IC_per_G + 3) / 4 * 4;
  for (int h = h_start; h < h_end; ++h) {
    for (int w = 0; w < width; ++w) {
      for (int g = 0; g < G_together; ++g) {
        for (int k = 0; k < OC_per_G; ++k) {
          int sum = 0;
          int rowSum = 0;
          for (int r = 0; r < R; ++r) {
            int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
            for (int s = 0; s < S; ++s) {
              int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
              for (int c = 0; c < IC_per_G; ++c) {
                bool out_of_image = h_in < 0 || h_in >= conv_p.IN_DIM[0] ||
                    w_in < 0 || w_in >= conv_p.IN_DIM[1];
                int h_index = h_in;
                if (h_start > 0) {
                  // The caller advanced in_acts past the first h_start output
                  // rows, so re-base the height index. The height stride is
                  // stride[0]; stride[1] is the width stride.
                  h_index = (h - h_start) * conv_p.stride[0] + r;
                }
                int a = out_of_image
                    ? a_zero_pt
                    : in_acts[(h_index * IW + w_in) * IC + g * IC_per_G + c];
                int idx = (((r * S + s) * OC_per_G + k) * G_together + g) *
                        paddedICPerG +
                    c;
                int b = wghts[idx];
                sum += a * b;
                rowSum += a;
              }
            }
          }
          if (accum) {
            out_acts[((h - h_start) * width + w) * OC + g * OC_per_G + k] +=
                sum;
            if (k == 0 && rowOffset != nullptr) {
              // only accumulate for k == 0
              rowOffset[((h - h_start) * width + w) * G_together + g] += rowSum;
            }
          } else {
            out_acts[((h - h_start) * width + w) * OC + g * OC_per_G + k] = sum;
            if (rowOffset != nullptr) {
              rowOffset[((h - h_start) * width + w) * G_together + g] = rowSum;
            }
          }
        }
      }
    }
  }
}
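
// Minimal usage sketch for kernel_compute (hedged: buffer names and sizes
// are assumptions mirroring the non-x86 fallback call sites below, not a
// public API):
//
//   conv_param_t<2> conv_p(...); // a shape accepted by fbgemmOptimizedGConv
//   int G_tog = PackWeightMatrixForGConv<int8_t, int32_t, 2>::
//       numOfGroupsTogether(conv_p);
//   std::vector<int32_t> rowOffset(
//       conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * G_tog);
//   kernel_compute(
//       conv_p, in_acts, wghts, out_acts, /*a_zero_pt=*/0, /*h_start=*/0,
//       /*h_end=*/conv_p.OUT_DIM[0], /*width=*/conv_p.OUT_DIM[1],
//       rowOffset.data(), /*accum=*/false);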

template <typename processOutputType, typename outT, typename inT>
void dispatchOutputProcessing(
    const processOutputType& outProcess,
    int32_t* rowOffsetBuf,
    outT* out,
    const inT* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    int groups,
    int C_per_G,
    true_type) {
  constexpr QuantizationGranularity Q_GRAN = processOutputType::QGRANType;
  constexpr int FUSE_RELU = processOutputType::RELU_FUSED;
  bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
                      outProcess.getBZeroPoint()[0] == 0) ||
      rowOffsetBuf == nullptr;
  int32_t a_zero_point = outProcess.getAZeroPoint();

  // Requantization
  requantizationParams_t<typename processOutputType::BIAS_T> r = {
      a_zero_point,
      outProcess.getBZeroPoint(),
      outProcess.getCZeroPoint(),
      outProcess.getCMultiplier(),
      rowOffsetBuf,
      outProcess.getColOffsets(),
      outProcess.getBias(),
      outProcess.getNCols(),
      groups,
      outProcess.getActWScale()};

#define REQUANTIZE_BASE(ISA, C_PER_G, A_SYM, B_SYM, BIAS) \
  requantizeOutputProcessingGConv##ISA<                   \
      A_SYM,                                              \
      B_SYM,                                              \
      Q_GRAN,                                             \
      BIAS,                                               \
      FUSE_RELU,                                          \
      C_PER_G>(out, inp, block, ld_out, ld_in, r);

#define REQUANTIZE_BIAS(ISA, C_PER_G, A_SYM, B_SYM)              \
  if (outProcess.getBias() == nullptr) {                         \
    REQUANTIZE_BASE(ISA, C_PER_G, A_SYM, B_SYM, /*bias=*/false); \
  } else {                                                       \
    REQUANTIZE_BASE(ISA, C_PER_G, A_SYM, B_SYM, /*bias=*/true);  \
  }

#define REQUANTIZE_BSYM(ISA, C_PER_G, A_SYM)     \
  if (b_symmetric) {                             \
    REQUANTIZE_BIAS(ISA, C_PER_G, A_SYM, true);  \
  } else {                                       \
    REQUANTIZE_BIAS(ISA, C_PER_G, A_SYM, false); \
  }

#define REQUANTIZE_ASYM(ISA, C_PER_G)    \
  if (a_zero_point == 0) {               \
    REQUANTIZE_BSYM(ISA, C_PER_G, true); \
  } else {                               \
    REQUANTIZE_BSYM(ISA, C_PER_G, false); \
  }

#define REQUANTIZE_C_PER_G(ISA)  \
  if (C_per_G == 2) {            \
    REQUANTIZE_ASYM(ISA, 2);     \
  } else if (C_per_G == 4) {     \
    REQUANTIZE_ASYM(ISA, 4);     \
  } else if (C_per_G == 8) {     \
    REQUANTIZE_ASYM(ISA, 8);     \
  } else {                       \
    REQUANTIZE_ASYM(ISA, 16);    \
  }

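  // For reference (a hand-written expansion under assumed conditions): with
  // C_per_G == 4, a_zero_point != 0, b_symmetric == false, and a non-null
  // bias, REQUANTIZE_C_PER_G(Avx512) ultimately selects a call equivalent to
  //   requantizeOutputProcessingGConvAvx512<
  //       /*A_SYM=*/false, /*B_SYM=*/false, Q_GRAN,
  //       /*BIAS=*/true, FUSE_RELU, /*C_PER_G=*/4>(
  //       out, inp, block, ld_out, ld_in, r);
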
  if (cpuinfo_initialize()) {
    if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) {
      REQUANTIZE_C_PER_G(Avx512);
    } else if (fbgemmHasAvx2Support() || fbgemmHasArmNeonSupport()) {
      REQUANTIZE_C_PER_G(Avx2);
    } else {
      assert(0 && "unsupported architecture");
    }
  } else {
    throw runtime_error("Failed to initialize cpuinfo!");
  }
}

#undef REQUANTIZE_C_PER_G
#undef REQUANTIZE_ASYM
#undef REQUANTIZE_BSYM
#undef REQUANTIZE_BIAS
#undef REQUANTIZE_BASE

template <
    typename packed_W,
    typename outType,
    bool FUSE_RELU,
    QuantizationGranularity Q_GRAN,
    int SPATIAL_DIM,
    typename BIAS_TYPE>
void fbgemmGroupwiseConv(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    const uint8_t* activations,
    int32_t a_zero_point,
    int32_t* rowOffsetBuf,
    packed_W& packed_weights,
    outType* out,
    int32_t* outBuffer,
    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
    int thread_id,
    int num_threads) {
  using processOutputType = ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>;

  if (!cpuinfo_initialize()) {
    throw runtime_error("Failed to initialize cpuinfo!");
  }

  int MB = conv_param.MB;
  int OT = SPATIAL_DIM <= 2 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 3];
  int OH = SPATIAL_DIM == 1 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 2];
  int OW = conv_param.OUT_DIM[SPATIAL_DIM - 1];
  int T = SPATIAL_DIM <= 2 ? 1 : conv_param.K[SPATIAL_DIM - 3];
  int R = SPATIAL_DIM == 1 ? 1 : conv_param.K[SPATIAL_DIM - 2];
  int S = conv_param.K[SPATIAL_DIM - 1];
  int G = conv_param.G;
  int OC = conv_param.OC;
  int IC = conv_param.IC;
  int K_per_G = conv_param.OC / G;
  int C_per_G = conv_param.IC / G;
  int OH_OW = OH * OW;
  int OT_OH_OW = OT * OH * OW;
  int IT = SPATIAL_DIM <= 2 ? 1 : conv_param.IN_DIM[SPATIAL_DIM - 3];
  int IH = SPATIAL_DIM == 1 ? 1 : conv_param.IN_DIM[SPATIAL_DIM - 2];
  int IW = conv_param.IN_DIM[SPATIAL_DIM - 1];
  int IH_IW = IH * IW;
  int IT_IH_IW = IT * IH * IW;
  int paddedCPerG = (C_per_G + 3) / 4 * 4;

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
  bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
                      outProcess.getBZeroPoint()[0] == 0) ||
      rowOffsetBuf == nullptr;
#endif
  int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
      numOfGroupsTogether(conv_param);

  if (SPATIAL_DIM == 1) {
    throw std::runtime_error("Groupwise 1D not implemented!");
  }
  if (SPATIAL_DIM == 2) {
    // Parallelization:
    int64_t batch_start = 0;
    int64_t batch_end = MB;
    int64_t oh_start = 0;
    int64_t oh_end = OH;
    if (MB >= num_threads) {
      fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
    } else {
      fbgemmPartition1D(thread_id, num_threads, OH, oh_start, oh_end);
    }

    if (batch_start >= batch_end || oh_start >= oh_end) {
      // There is no work for this thread
      return;
    }

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
    // generate convolution + rowOffset kernel
    bool calculateRowOffset = !b_symmetric;
    bool isTopEdgeIncluded = oh_start == 0;
    bool isBottomEdgeIncluded = oh_end == OH;
    bool isTopBottomEdgeSame =
        isTopEdgeIncluded && isBottomEdgeIncluded && oh_end == oh_start + 1;
    jit_conv_kernel_fp fpConv = getOrCreateConvKernel<SPATIAL_DIM>(
        conv_param,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        false);
#endif

    int ih_start = 0;
    if (oh_start > 0) {
      ih_start = -conv_param.pad[SPATIAL_DIM - 2] +
          oh_start * conv_param.stride[SPATIAL_DIM - 2];
    }
    int32_t* out_start = outBuffer + oh_start * OW * OC;
    const uint8_t* in_start = activations + ih_start * IW * IC;
    int32_t* rowOffsetBuf_start =
        rowOffsetBuf ? rowOffsetBuf + oh_start * OW * G_together : nullptr;
    for (int i = batch_start; i < batch_end; ++i) {
      const uint8_t* in_start_batch = in_start + i * IH_IW * conv_param.IC;
      int32_t* out_start_batch = out_start + i * OH_OW * OC;
      int32_t* rowOffsetBuf_start_batch =
          rowOffsetBuf ? rowOffsetBuf_start + i * OH_OW * G_together : nullptr;
      for (int g = 0; g < G; g += G_together) {
        const uint8_t* in_start_group = in_start_batch + g * C_per_G;
        int8_t* weight_start =
            packed_weights.getBuf() + g * R * S * K_per_G * paddedCPerG;
        int32_t* out_start_group = out_start_batch;
        int32_t* rowOffsetBuf_start_group = rowOffsetBuf_start_batch;
        // Uncomment the following two lines to stop
        // reuse of the output and rowoffset buffers
        // out_start_group = out_start_batch + g * K_per_G;
        // rowOffsetBuf_start_group = rowOffsetBuf_start_batch + g * MB * OH_OW;

        // exactly the same compute as the JIT'ed kernel below
        // kernel_compute(
        //     conv_param,
        //     in_start_group,
        //     weight_start,
        //     out_start_group,
        //     a_zero_point,
        //     oh_start,
        //     oh_end,
        //     OW,
        //     rowOffsetBuf_start_group,
        //     false);
#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
        fpConv(
            in_start_group,
            weight_start,
            out_start_group,
            a_zero_point,
            oh_start,
            oh_end,
            OW,
            rowOffsetBuf_start_group);
#else
        kernel_compute(
            conv_param,
            in_start_group,
            weight_start,
            out_start_group,
            a_zero_point,
            oh_start,
            oh_end,
            OW,
            rowOffsetBuf_start_group,
            false);
#endif

        const int32_t* inp = out_start_group;
        block_type_t block{
            static_cast<int>(i * OT_OH_OW + oh_start * OW),
            static_cast<int>((oh_end - oh_start) * OW),
            g * K_per_G,
            G_together * K_per_G};
        int ld_out = G * K_per_G;
        int ld_in = G * K_per_G;

        dispatchOutputProcessing(
            outProcess,
            rowOffsetBuf_start_group,
            out,
            inp,
            block,
            ld_out,
            ld_in,
            G,
            C_per_G,
            is_requantization<processOutputType>());
      } // for each g
    } // for each i
  } else {
    assert(SPATIAL_DIM == 3 && "Unsupported SPATIAL_DIM");

    conv_param_t<> conv_p_2d(
        conv_param.MB,
        conv_param.IC,
        conv_param.OC,
        {conv_param.IN_DIM[SPATIAL_DIM - 2],
         conv_param.IN_DIM[SPATIAL_DIM - 1]},
        conv_param.G,
        {conv_param.K[SPATIAL_DIM - 2], conv_param.K[SPATIAL_DIM - 1]},
        {conv_param.stride[SPATIAL_DIM - 2],
         conv_param.stride[SPATIAL_DIM - 1]},
        {conv_param.pad[1],
         conv_param.pad[2],
         conv_param.pad[4],
         conv_param.pad[5]});
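
    // Note on the pad indexing above (documented as an assumption for
    // clarity): for SPATIAL_DIM == 3, conv_param.pad holds six entries
    // ordered {T_before, H_before, W_before, T_after, H_after, W_after}, so
    // indices 1, 2, 4, 5 select the height/width pads for the 2D slice while
    // pad[0] remains the temporal pad used in the t_in computation below.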

    // Parallelization:
    int64_t batch_start = 0;
    int64_t batch_end = MB;
    int64_t oh_start = 0;
    int64_t oh_end = OH;
    if (MB >= num_threads) {
      fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
    } else {
      fbgemmPartition1D(thread_id, num_threads, OH, oh_start, oh_end);
    }

    if (batch_start >= batch_end || oh_start >= oh_end) {
      // There is no work for this thread
      return;
    }

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
    // generate convolution + rowOffset kernel
    bool calculateRowOffset = !b_symmetric;
    bool isTopEdgeIncluded = oh_start == 0;
    bool isBottomEdgeIncluded = oh_end == OH;
    bool isTopBottomEdgeSame =
        isTopEdgeIncluded && isBottomEdgeIncluded && oh_end == oh_start + 1;
    jit_conv_kernel_fp fpConvNoAccum = getOrCreateConvKernel<2>(
        conv_p_2d,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        false);
    jit_conv_kernel_fp fpConvAccum = getOrCreateConvKernel<2>(
        conv_p_2d,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        true);
    jit_conv_kernel_fp fpConv;
#endif

    int ih_start = 0;
    if (oh_start > 0) {
      ih_start = -conv_p_2d.pad[0] + oh_start * conv_p_2d.stride[0];
    }

    vector<uint8_t> zero_points(IH * IW * IC, a_zero_point);
    int32_t* out_start = outBuffer + oh_start * OW * OC;
    const uint8_t* in_start = activations + ih_start * IW * IC;
    int32_t* rowOffsetBuf_start =
        rowOffsetBuf ? rowOffsetBuf + oh_start * OW * G_together : nullptr;
    for (int i = batch_start; i < batch_end; ++i) {
      const uint8_t* in_start_batch = in_start + i * IT_IH_IW * IC;
      int32_t* out_start_batch = out_start + i * OT_OH_OW * OC;
      int32_t* rowOffsetBuf_start_batch = rowOffsetBuf
          ? rowOffsetBuf_start + i * OT_OH_OW * G_together
          : nullptr;
      for (int g = 0; g < G; g += G_together) {
        const uint8_t* in_start_group = in_start_batch + g * C_per_G;
        int8_t* weight_start =
            packed_weights.getBuf() + g * T * R * S * K_per_G * paddedCPerG;
        int32_t* out_start_group = out_start_batch;
        int32_t* rowOffsetBuf_start_group = rowOffsetBuf_start_batch;
        // Uncomment the following two lines to stop
        // reuse of the output and rowoffset buffers
        // out_start_group = out_start_batch + g * K_per_G;
        // rowOffsetBuf_start_group = rowOffsetBuf_start_batch + g * MB *
        //     OT_OH_OW;

        for (int ot = 0; ot < OT; ++ot) {
          int32_t* out_start_t = out_start_group + ot * OH_OW * OC;
          int32_t* rowOffsetBuf_start_t = rowOffsetBuf
              ? rowOffsetBuf_start_group + ot * OH_OW * G_together
              : nullptr;
          for (int t = 0; t < T; ++t) {
            int t_in = -conv_param.pad[0] + ot * conv_param.stride[0] + t;
            const uint8_t* in_start_t = in_start_group + t_in * IH_IW * IC;
            int8_t* weight_start_t =
                weight_start + t * R * S * K_per_G * G_together * paddedCPerG;
            if (t_in < 0 || t_in >= IT) {
              in_start_t = zero_points.data();
            }
            // exactly the same compute as the JIT'ed kernel below
            // kernel_compute(
            //     conv_p_2d,
            //     in_start_t,
            //     weight_start_t,
            //     out_start_t,
            //     a_zero_point,
            //     oh_start,
            //     oh_end,
            //     OW,
            //     rowOffsetBuf_start_t,
            //     t > 0);

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
            fpConv = t > 0 ? fpConvAccum : fpConvNoAccum;
            fpConv(
                in_start_t,
                weight_start_t,
                out_start_t,
                a_zero_point,
                oh_start,
                oh_end,
                OW,
                rowOffsetBuf_start_t);
#else
            kernel_compute(
                conv_p_2d,
                in_start_t,
                weight_start_t,
                out_start_t,
                a_zero_point,
                oh_start,
                oh_end,
                OW,
                rowOffsetBuf_start_t,
                t > 0);
#endif
          }

          const int32_t* inp = out_start_t;
          block_type_t block{
              static_cast<int>(i * OT_OH_OW + oh_start * OW),
              static_cast<int>((oh_end - oh_start) * OW),
              g * K_per_G,
              G_together * K_per_G};
          int ld_out = G * K_per_G;
          int ld_in = G * K_per_G;

          dispatchOutputProcessing(
              outProcess,
              rowOffsetBuf_start_t,
              out + ot * OH_OW * OC,
              inp,
              block,
              ld_out,
              ld_in,
              G,
              C_per_G,
              is_requantization<processOutputType>());
        } // for each ot
      } // for each g
    } // for each i
  } // SPATIAL_DIM == 3
}

template <int SPATIAL_DIM>
int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
  // The row offset buffer should be able to hold row offsets for however
  // many groups we process at a time.
  if (cpuinfo_initialize()) {
    int OT = SPATIAL_DIM <= 2 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 3];
    int OH = SPATIAL_DIM == 1 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 2];
    int bufferSize = OT * OH * conv_param.OUT_DIM[SPATIAL_DIM - 1];
    if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support() ||
        fbgemmHasArmNeonSupport()) {
      return conv_param.MB * bufferSize * conv_param.G;
    } else {
      // TODO: Have default slower path
      assert(0 && "unsupported architecture");
      return -1;
    }
  } else {
    throw runtime_error("Failed to initialize cpuinfo!");
  }
}
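
// Hedged worked example: for a 2D convolution with MB == 1, G == 8, and
// OUT_DIM == {56, 56}, rowOffsetBufferSizeGConv returns
// 1 * (56 * 56) * 8 = 25088 int32_t entries, i.e., one row offset per output
// pixel per group.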

template FBGEMM_API int rowOffsetBufferSizeGConv<1>(
    const conv_param_t<1>& conv_param);
template FBGEMM_API int rowOffsetBufferSizeGConv<2>(
    const conv_param_t<2>& conv_param);
template FBGEMM_API int rowOffsetBufferSizeGConv<3>(
    const conv_param_t<3>& conv_param);

#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, BIAS_TYPE)                \
  template FBGEMM_API void fbgemmGroupwiseConv(                               \
      const conv_param_t<SPATIAL_DIM>& conv_param,                            \
      const uint8_t* activations,                                             \
      int32_t a_zero_point,                                                   \
      int32_t* rowOffsetBuf,                                                  \
      PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>& packed_weights, \
      uint8_t* out,                                                           \
      int32_t* outBuffer,                                                     \
      const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess,            \
      int thread_id,                                                          \
      int num_threads);

#define INSTANTIATE_BIAS_T(RELU, Q_GRAN, SPATIAL_DIM) \
  INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, float)  \
  INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, int32_t)

#define INSTANTIATE_SPATIAL_DIM(RELU, Q_GRAN) \
  INSTANTIATE_BIAS_T(RELU, Q_GRAN, 1)         \
  INSTANTIATE_BIAS_T(RELU, Q_GRAN, 2)         \
  INSTANTIATE_BIAS_T(RELU, Q_GRAN, 3)

#define INSTANTIATE_Q_GRANS(RELU)                                \
  INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::TENSOR) \
  INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::GROUP)  \
  INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::OUT_CHANNEL)

INSTANTIATE_Q_GRANS(false);
INSTANTIATE_Q_GRANS(true);

#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_SPATIAL_DIM
#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE

/*
template void fbgemmGroupwiseConv(
    const conv_param_t<2>& conv_param,
    const uint8_t* activations,
    int32_t a_zero_point,
    int32_t* rowOffsetBuf,
    PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights,
    int32_t* out,
    int32_t* outBuffer,
    const DoNothing<int32_t, int32_t>& outProcess,
    int thread_id,
    int num_threads);
*/

} // namespace fbgemm