1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "./GroupwiseConv.h" |
9 | #include <asmjit/asmjit.h> |
10 | #include <cpuinfo.h> |
11 | #include <array> |
12 | #include <iostream> |
13 | #include <map> |
14 | #include <stdexcept> |
15 | #include <tuple> |
16 | #include <type_traits> |
17 | #include "./CodeGenHelpers.h" |
18 | #include "./RefImplementations.h" |
19 | #include "./TransposeUtils.h" |
20 | #include "fbgemm/Fbgemm.h" |
21 | #include "fbgemm/QuantUtilsAvx512.h" |
22 | #include "fbgemm/SimdUtils.h" |
23 | |
24 | namespace fbgemm { |
25 | |
26 | using namespace std; |
27 | |
28 | template <int SPATIAL_DIM> |
29 | void calculateRowOffsets( |
30 | const conv_param_t<SPATIAL_DIM>& conv_param, |
31 | const uint8_t* activations, |
32 | int32_t* rowOffsetBuf, |
33 | int32_t a_zero_point, |
34 | int groupNum) { |
35 | int OH = conv_param.OUT_DIM[0]; |
36 | int OW = conv_param.OUT_DIM[1]; |
37 | int IH = conv_param.IN_DIM[0]; |
38 | int IW = conv_param.IN_DIM[1]; |
39 | int G = conv_param.G; |
40 | int C_per_G = conv_param.IC / conv_param.G; |
41 | int H_PAD = conv_param.pad[0]; |
42 | int W_PAD = conv_param.pad[1]; |
43 | // calculate row offset |
44 | for (int h = 0; h < OH; ++h) { |
45 | for (int w = 0; w < OW; ++w) { |
46 | int32_t sum = 0; |
47 | for (int r = 0; r < conv_param.K[0]; ++r) { |
48 | int h_in = -H_PAD + h * conv_param.stride[0] + r; |
49 | for (int s = 0; s < conv_param.K[1]; ++s) { |
50 | int w_in = -W_PAD + w * conv_param.stride[1] + s; |
51 | for (int c = 0; c < C_per_G; ++c) { |
52 | int a_val; |
53 | if (h_in < 0 || h_in >= IH || w_in < 0 || w_in >= IW) { |
54 | a_val = a_zero_point; |
55 | } else { |
56 | a_val = activations |
57 | [((h_in * IW + w_in) * G + groupNum) * C_per_G + c]; |
58 | } |
59 | sum += a_val; |
60 | } |
61 | } |
62 | } |
63 | rowOffsetBuf[h * OW + w] = sum; |
64 | } |
65 | } |
66 | } |
67 | |
68 | template <int SPATIAL_DIM = 2> |
69 | kernel_sig_t getKernelSig( |
70 | const conv_param_t<SPATIAL_DIM>& conv_param, |
71 | bool isAZeroPointZero, |
72 | bool needRowOffset, |
73 | bool isTopEdgeIncluded, |
74 | bool isBottomEdgeIncluded, |
75 | bool isTopBottomEdgeSame, |
76 | bool accum) { |
77 | // kernel is specialized on number of input channels per group, number of |
78 | // output channels per group, whether stride is 1 or stride is 2, whether or |
79 | // not zero point for activations is 0 or not, whether or not row offset |
80 | // calculations are needed, whether or not top edge is included and whether or |
81 | // not bottom edge is included. |
82 | // use_padding_: If false, the right padding on the width side and bottom |
83 | // padding on height side are not used for the case of stride = 2 |
84 | // accum: accumulate results for output and rowoffset |
85 | int C_per_G = conv_param.IC / conv_param.G; |
86 | int K_per_G = conv_param.OC / conv_param.G; |
87 | auto kernelSig = make_tuple( |
88 | isAZeroPointZero, |
89 | needRowOffset, |
90 | isTopEdgeIncluded, |
91 | isBottomEdgeIncluded, |
92 | isTopBottomEdgeSame, |
93 | !(conv_param.stride[SPATIAL_DIM - 2] > 1 && |
94 | conv_param.IN_DIM[SPATIAL_DIM - 2] % 2 == 0), |
95 | !(conv_param.stride[SPATIAL_DIM - 1] > 1 && |
96 | conv_param.IN_DIM[SPATIAL_DIM - 1] % 2 == 0), |
97 | accum, |
98 | conv_param.G, |
99 | conv_param.stride[0], |
100 | C_per_G, |
101 | K_per_G); |
102 | return kernelSig; |
103 | } |
104 | |
// Returns the JIT'ed groupwise-convolution kernel for conv_param, compiling
// and caching it on first use. The kernel is keyed by the signature from
// getKernelSig() and specialized for the best ISA reported by cpuinfo
// (AVX512-VNNI preferred over AVX512 over AVX2). Throws if cpuinfo cannot be
// initialized; asserts (and would generate wrong code) if conv_param is not
// one of the optimized configurations.
template <int SPATIAL_DIM = 2>
jit_conv_kernel_fp getOrCreateConvKernel(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    int a_zero_point,
    bool needRowOffset,
    bool isTopEdgeIncluded,
    bool isBottomEdgeIncluded,
    bool isTopBottomEdgeSame,
    bool accum) {
  // Note: Wrong code is generated if it's not one of the supported convolution
  assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
  auto kernelSig = getKernelSig(
      conv_param,
      a_zero_point == 0,
      needRowOffset,
      isTopEdgeIncluded,
      isBottomEdgeIncluded,
      isTopBottomEdgeSame,
      accum);

  if (cpuinfo_initialize()) {
    if (fbgemmHasAvx512VnniSupport()) {
      // codeCache_ runs the lambda only on a cache miss; the generated
      // function pointer is cached per kernel signature.
      return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>::codeCache_
          .getOrCreate(kernelSig, [&]() {
            auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>(
                conv_param,
                a_zero_point,
                needRowOffset,
                isTopEdgeIncluded,
                isBottomEdgeIncluded,
                isTopBottomEdgeSame,
                accum);
            return genObj.getOrCreate();
          });
    } else if (fbgemmHasAvx512Support()) {
      return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512>::codeCache_
          .getOrCreate(kernelSig, [&]() {
            auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx512>(
                conv_param,
                a_zero_point,
                needRowOffset,
                isTopEdgeIncluded,
                isBottomEdgeIncluded,
                isTopBottomEdgeSame,
                accum);
            return genObj.getOrCreate();
          });
    } else if (fbgemmHasAvx2Support()) {
      return GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>::codeCache_
          .getOrCreate(kernelSig, [&]() {
            auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>(
                conv_param,
                a_zero_point,
                needRowOffset,
                isTopEdgeIncluded,
                isBottomEdgeIncluded,
                isTopBottomEdgeSame,
                accum);
            return genObj.getOrCreate();
          });
    } else {
      // TODO: Have default slower path
      assert(0 && "unsupported architecture" );
    }
  } else {
    throw runtime_error("Failed to initialize cpuinfo!" );
  }
  // Unreachable in practice (all branches above return, assert, or throw);
  // keeps non-assert builds well-formed.
  return nullptr;
}
174 | |
175 | template <int SPATIAL_DIM, inst_set_t INST_SET> |
176 | jit_conv_kernel_fp GenConvKernel<SPATIAL_DIM, INST_SET>::getOrCreate() { |
177 | asmjit::CodeHolder code; |
178 | code.init(this->runtime().environment()); |
179 | x86::Assembler assembler(&code); |
180 | x86::Emitter* a = assembler.as<x86::Emitter>(); |
181 | |
182 | typedef typename simd_info<INST_SET>::vec_reg_t vec_reg_t; |
183 | #if defined(FBGEMM_LOG_CODE) |
184 | auto kernelSig = make_tuple( |
185 | this->isAZeroPointZero_, |
186 | this->needRowOffset_, |
187 | this->isTopEdgeIncluded_, |
188 | this->isBottomEdgeIncluded_, |
189 | this->use_bottom_padding_, |
190 | this->use_right_padding_, |
191 | this->accum_, |
192 | this->G_, |
193 | this->STRIDE_, |
194 | this->C_per_G_, |
195 | this->K_per_G_); |
196 | // log code to a file |
197 | FILE* codeLogfile = fopen(this->getCodeLoggingFile(kernelSig).c_str(), "w" ); |
198 | asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); |
199 | if (codeLogger) { |
200 | code.setLogger(codeLogger); |
201 | } |
202 | #endif |
203 | |
204 | // arguments to the function created |
205 | in_acts_R_ = a->zdi(); |
206 | wghts_R_ = a->zsi(); |
207 | out_acts_R_ = a->zdx(); |
208 | a_zero_pt_R_ = a->zcx(); |
209 | H_start_R_ = a->gpz(8); |
210 | H_end_R_ = a->gpz(9); |
211 | W_R_ = a->gpz(10); |
212 | row_offset_R_ = a->gpz(11); |
213 | |
214 | // register for temporary use |
215 | scratchReg1_ = a->gpz(12); |
216 | scratchReg2_ = a->gpz(13); |
217 | |
218 | func_.init( |
219 | asmjit::FuncSignatureT< |
220 | void, |
221 | uint8_t*, |
222 | int8_t*, |
223 | int32_t*, |
224 | int32_t, |
225 | int32_t, |
226 | int32_t, |
227 | int32_t, |
228 | int32_t*>(asmjit::CallConvId::kHost), |
229 | a->environment()); |
230 | |
231 | frame_.init(func_); |
232 | |
233 | frame_.setDirtyRegs( |
234 | asmjit::RegGroup::kVec, |
235 | asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | |
236 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); |
237 | frame_.setDirtyRegs( |
238 | asmjit::RegGroup::kGp, |
239 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); |
240 | |
241 | asmjit::FuncArgsAssignment args(&func_); |
242 | args.assignAll( |
243 | in_acts_R_, |
244 | wghts_R_, |
245 | out_acts_R_, |
246 | a_zero_pt_R_, |
247 | H_start_R_, |
248 | H_end_R_, |
249 | W_R_, |
250 | row_offset_R_); |
251 | |
252 | args.updateFuncFrame(frame_); |
253 | frame_.finalize(); |
254 | |
255 | a->emitProlog(frame_); |
256 | a->emitArgsAssignment(frame_, args); |
257 | |
258 | // We have run out of register so can't keep |
259 | // this in a register. It's generated again at |
260 | // each use. Only used for the case of C_per_G == 2 or 4 |
261 | // gen8BitVectorOne(a, oneReg8Bit_V_); |
262 | gen16BitVectorOne<INST_SET, vec_reg_t>(a, oneReg16Bit_V_); |
263 | |
264 | loopR1_ = a->gpz(14); |
265 | loopR2_ = a->gpz(15); |
266 | |
267 | if (!this->isAZeroPointZero_) { |
268 | broadcast8Bit<vec_reg_t>(a, a_zero_pt_R_, zeroPTReg_V_); |
269 | } |
270 | |
271 | genConstForPermutations(a); |
272 | |
273 | genForLoadingWeights(a); |
274 | |
275 | // W_R_ is an input to the JIT'ed kernel and is output image width. |
276 | // W_R_ is passed in using stack. We reload it inside kernel where we need it. |
277 | // The following logic calculates the input image width in the same register. |
278 | // Only works for stride == 2 |
279 | if (this->STRIDE_ > 1) { |
280 | a->imul(W_R_, W_R_, static_cast<asmjit::Imm>(this->STRIDE_)); |
281 | if (!this->use_right_padding_) { |
282 | a->inc(W_R_); |
283 | } |
284 | a->sub(W_R_, static_cast<asmjit::Imm>(this->STRIDE_ - 1)); |
285 | } |
286 | |
287 | if (this->isTopEdgeIncluded_) { |
288 | genForTopOrBottomEdge( |
289 | a, |
290 | true /* isTopEdge */, |
291 | this->isTopBottomEdgeSame_ && this->use_bottom_padding_); |
292 | } |
293 | genCoreInsts(a); |
294 | if (this->isBottomEdgeIncluded_ && !this->isTopBottomEdgeSame_) { |
295 | genForTopOrBottomEdge( |
296 | a, false /* isTopEdge */, this->use_bottom_padding_ /* isBottomEdge */); |
297 | } |
298 | |
299 | a->emitEpilog(frame_); |
300 | |
301 | jit_conv_kernel_fp fn; |
302 | asmjit::Error err; |
303 | { |
304 | unique_lock<mutex> lock(this->rtMutex_); |
305 | err = this->runtime().add(&fn, &code); |
306 | } |
307 | |
308 | if (err) { |
309 | cout << "Error: in fn add" << endl; |
310 | return nullptr; |
311 | } |
312 | |
313 | #if defined(FBGEMM_LOG_CODE) |
314 | fclose(codeLogfile); |
315 | delete codeLogger; |
316 | #endif |
317 | |
318 | return fn; |
319 | } |
320 | |
// Emits code that computes one output pixel: all output channels for the
// GTogether_ groups processed together, plus (optionally) the fused row
// offset. The isLeft/isRight/isTop/isBottom flags say which image borders
// this pixel's filter window touches, which determines statically which
// filter taps fall in padding. After the result is stored, the input and
// output pointers are advanced so the next call continues at the following
// output pixel.
template <int SPATIAL_DIM, inst_set_t INST_SET>
void GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleOutput(
    x86::Emitter* a,
    bool isLeft,
    bool isRight,
    bool isTop,
    bool isBottom) {
  // init result regs
  initResultRegs(a);

  // row offset
  if (this->needRowOffset_) {
    // Zero the row-offset accumulator (xmm clear zeroes the full register).
    a->vpxor(
        rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm());
  }

  bool isWidthMiddle = !isLeft && !isRight;
  bool isHeightMiddle = !isTop && !isBottom;
  int num_rows_advanced = 0;
  for (int r = 0; r < this->R_; ++r) {
    // For a top-edge pixel, filter row r maps to input row (r - H_PAD_);
    // elsewhere the caller has positioned in_acts_R_ so h_in == r.
    int h_in = r;
    if (isTop) {
      h_in = -this->H_PAD_ + r;
    }
    // Statically decide whether this filter row lands inside the image for
    // the border combination being generated.
    bool in_image_H = (isTop && !isBottom && h_in >= 0) ||
        (!isTop && isBottom && h_in < (this->R_ - this->H_PAD_)) ||
        (isTop && isBottom && h_in >= 0 &&
         h_in < (this->R_ - 2 * this->H_PAD_)) ||
        isHeightMiddle;
    for (int s = 0; s < this->S_; ++s) {
      int w_in = s;
      if (isLeft) {
        w_in = -this->W_PAD_ + s;
      }
      // Same static in-image test along the width dimension.
      bool in_image_W = (isLeft && !isRight && w_in >= 0) ||
          (!isLeft && isRight && w_in < (this->S_ - this->W_PAD_)) ||
          (isLeft && isRight && w_in >= 0 &&
           w_in < (this->S_ - 2 * this->W_PAD_)) ||
          isWidthMiddle;
      if (in_image_H && in_image_W) {
        genForSingleFilterPoint(a, r, s, w_in, false);
      } else {
        // Padding tap: contributes only when the activation zero point is
        // nonzero; otherwise the tap is skipped entirely.
        if (!this->isAZeroPointZero_) {
          genForSingleFilterPoint(a, r, s, w_in, true);
        }
      }
    }
    if (in_image_H) {
      // advance input pointer by one row
      a->imul(
          scratchReg2_,
          W_R_,
          static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
      a->add(in_acts_R_, scratchReg2_);
      ++num_rows_advanced;
    }
  }

  storeResult(a);

  // row offset
  if (this->needRowOffset_) {
    storeOffset(a);
    // One int32 row-offset entry per group processed together.
    a->add(
        row_offset_R_, static_cast<asmjit::Imm>(GTogether_ * sizeof(int32_t)));
  }

  // rewind input ptr
  a->imul(
      scratchReg2_,
      W_R_,
      static_cast<asmjit::Imm>(num_rows_advanced * this->C_ * sizeof(uint8_t)));
  a->sub(in_acts_R_, scratchReg2_);

  // advance output pointer
  a->add(out_acts_R_, static_cast<asmjit::Imm>(this->K_ * sizeof(int32_t)));

  // advance input ptr
  if (!isLeft) {
    // Interior/right pixel: step a full stride.
    a->add(
        in_acts_R_,
        static_cast<asmjit::Imm>(this->STRIDE_ * this->C_ * sizeof(uint8_t)));
  } else if (this->STRIDE_ - this->W_PAD_) {
    // Left-edge pixel: the window started W_PAD_ columns before the image,
    // so only the remaining (STRIDE_ - W_PAD_) columns are advanced.
    a->add(
        in_acts_R_,
        static_cast<asmjit::Imm>(
            (this->STRIDE_ - this->W_PAD_) * this->C_ * sizeof(uint8_t)));
  }
}
410 | |
// Emits the code that produces one full output row lying on the top or
// bottom edge of the image: left-corner pixel, a runtime loop over the
// middle pixels, then the right-corner pixel. A separate path handles
// images so narrow that a single output pixel touches both side paddings.
// Afterwards the input pointer is repositioned for the next row.
template <int SPATIAL_DIM, inst_set_t INST_SET>
void GenConvKernel<SPATIAL_DIM, INST_SET>::genForTopOrBottomEdge(
    x86::Emitter* a,
    bool isTopEdge,
    bool isBottomEdge) {
  // Output width was passed in as the 7th argument (i.e., using stack).
  // Reload it from the same location.
  a->movsxd(
      loopR1_,
      x86::dword_ptr(
          x86::rsp, frame_.saOffsetFromSP() + func_.arg(6).stackOffset()));
  asmjit::Label LoopWStart = a->newLabel();
  asmjit::Label LoopWEnd = a->newLabel();
  asmjit::Label skipRightEdge = a->newLabel();
  asmjit::Label skipRightEdgeTemp = a->newLabel();
  // Narrow image (width <= W_PAD_): one pixel covers both side paddings.
  a->cmp(loopR1_, static_cast<asmjit::Imm>(this->W_PAD_));
  a->jle(skipRightEdgeTemp);

  // left corner code
  genForSingleOutput(
      a,
      true, // isLeft
      false, // isRight
      isTopEdge, // isTop
      isBottomEdge // isBottom
  );
  a->jmp(LoopWStart);

  a->bind(skipRightEdgeTemp);
  // narrow-width case: single output pixel with both left and right padding
  genForSingleOutput(
      a,
      true, // isLeft
      this->use_right_padding_, // isRight
      isTopEdge, // isTop
      isBottomEdge // isBottom
  );
  a->jmp(skipRightEdge);

  // edge excluding corners
  a->bind(LoopWStart);

  a->cmp(loopR1_, static_cast<asmjit::Imm>(2 * this->W_PAD_));
  a->jle(LoopWEnd);

  genForSingleOutput(
      a,
      false, // isLeft
      false, // isRight
      isTopEdge, // isTop
      isBottomEdge // isBottom
  );

  a->dec(loopR1_);
  a->jmp(LoopWStart);
  a->bind(LoopWEnd);

  // right corner code
  genForSingleOutput(
      a,
      false, // isLeft
      this->use_right_padding_, // isRight
      isTopEdge, // isTop
      isBottomEdge // isBottom
  );

  a->bind(skipRightEdge);

  if (this->STRIDE_ > 1) {
    // STRIDE_ == 2 and even widths,
    // We increase it by C_;
    // STRIDE_ == 2 and odd widths, nothing to do
    // input ptr is already at the right position
    if (!this->use_right_padding_) {
      a->add(in_acts_R_, static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
    }
  } else {
    // reset input activation pointer by (W_R_ - W_PAD_) * C_
    a->mov(scratchReg2_, W_R_);
    a->imul(scratchReg2_, static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
    a->sub(
        scratchReg2_,
        static_cast<asmjit::Imm>(this->W_PAD_ * this->C_ * sizeof(uint8_t)));
    a->sub(in_acts_R_, scratchReg2_);
  }
}
497 | |
// Emits the main H loop covering the interior (non-edge) output rows from
// H_start to H_end. Each row is produced like an edge row — left corner,
// middle-pixel W loop, right corner, with a special narrow-width path — but
// with isTop/isBottom fixed to false. After each row the input pointer is
// stepped to the next input row.
template <int SPATIAL_DIM, inst_set_t INST_SET>
void GenConvKernel<SPATIAL_DIM, INST_SET>::genCoreInsts(x86::Emitter* a) {
  // Top edge and bottom edge calculations are done separately
  // so start from next and leave out the last
  if (this->isTopEdgeIncluded_) {
    a->inc(H_start_R_);
  }
  if (this->isBottomEdgeIncluded_) {
    a->dec(H_end_R_);
  }
  // main compute
  asmjit::Label LoopHStart = a->newLabel();
  asmjit::Label LoopHEnd = a->newLabel();
  asmjit::Label LoopWStart = a->newLabel();
  asmjit::Label LoopWEnd = a->newLabel();

  // H loop
  a->mov(loopR1_, H_start_R_);
  // Jump straight to the loop condition: handles the H_start >= H_end case.
  a->jmp(LoopHEnd);
  a->bind(LoopHStart);
  a->inc(loopR1_);

  // Reload output width (7th argument, passed on the stack) into loopR2_.
  a->movsxd(
      loopR2_,
      x86::dword_ptr(
          x86::rsp, frame_.saOffsetFromSP() + func_.arg(6).stackOffset()));
  asmjit::Label skipRightEdge = a->newLabel();
  asmjit::Label skipRightEdgeTemp = a->newLabel();
  // Narrow image (width <= W_PAD_): one pixel covers both side paddings.
  a->cmp(loopR2_, static_cast<asmjit::Imm>(this->W_PAD_));
  a->jle(skipRightEdgeTemp);

  // Left-corner pixel of an interior row.
  genForSingleOutput(
      a,
      true, // isLeft,
      false, // isRight
      false, // isTop
      false // isBottom
  );
  a->jmp(LoopWStart);

  a->bind(skipRightEdgeTemp);
  // Narrow-width case: single output pixel with both side paddings.
  genForSingleOutput(
      a,
      true, // isLeft,
      this->use_right_padding_, // isRight
      false, // isTop
      false // isBottom
  );
  a->jmp(skipRightEdge);

  // W loop
  a->bind(LoopWStart);

  a->cmp(loopR2_, static_cast<asmjit::Imm>(2 * this->W_PAD_));
  a->jle(LoopWEnd);

  // Middle pixel: no padding taps.
  genForSingleOutput(
      a,
      false, // isLeft,
      false, // isRight
      false, // isTop
      false // isBottom
  );

  a->dec(loopR2_);
  a->jmp(LoopWStart);
  a->bind(LoopWEnd);

  // Right-corner pixel of an interior row.
  genForSingleOutput(
      a,
      false, // isLeft
      this->use_right_padding_, // isRight
      false, // isTop
      false // isBottom
  );

  a->bind(skipRightEdge);

  if (this->STRIDE_ > 1) {
    // STRIDE_ == 2 and even widths,
    // We increase it by extra C_;
    // STRIDE_ == 2 and odd widths, no extra C_
    assert(this->STRIDE_ == 2 && "Not supported case" );
    a->mov(scratchReg2_, W_R_);
    if (!this->use_right_padding_) {
      a->add(scratchReg2_, static_cast<asmjit::Imm>(1));
    }
    a->imul(scratchReg2_, static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
    a->add(in_acts_R_, scratchReg2_);
  } else {
    // Stride 1: advance by one input pixel to reach the next row's start.
    a->add(in_acts_R_, static_cast<asmjit::Imm>(this->C_ * sizeof(uint8_t)));
  }

  a->bind(LoopHEnd);
  a->cmp(loopR1_, H_end_R_);
  a->jl(LoopHStart);
}
595 | |
596 | template <int SPATIAL_DIM, inst_set_t INST_SET> |
597 | void GenConvKernel<SPATIAL_DIM, INST_SET>::initResultRegs(x86::Emitter* a) { |
598 | if (kLoopIters_ > 0) { |
599 | // Take advantage of implicit zeroing out |
600 | // i.e., zero out xmm and ymm and zmm will be zeroed out too |
601 | for (int k = 0; k < kLoopIters_; ++k) { |
602 | a->vpxor(x86::Xmm(9 - k), x86::Xmm(9 - k), x86::Xmm(9 - k)); |
603 | } |
604 | } else { |
605 | a->vpxor(x86::Xmm(9), x86::Xmm(9), x86::Xmm(9)); |
606 | } |
607 | } |
608 | |
609 | /* |
610 | namespace { |
611 | |
612 | template < |
613 | typename packed_W, |
614 | typename outType, |
615 | typename processOutputType, |
616 | int SPATIAL_DIM> |
617 | void fbgemmGroupwiseConvBase_( |
618 | const conv_param_t<SPATIAL_DIM>& conv_param, |
619 | const uint8_t* activations, |
620 | int32_t a_zero_point, |
621 | int32_t* rowOffsetBuf, |
622 | packed_W& packed_weights, |
623 | outType* out, |
624 | int32_t* outBuffer, |
625 | const processOutputType& outProcess, |
626 | int thread_id, |
627 | int num_threads) { |
628 | int MB = conv_param.MB; |
629 | int H = conv_param.OUT_DIM[0]; |
630 | int W = conv_param.OUT_DIM[1]; |
631 | int G = conv_param.G; |
632 | int K_per_G = conv_param.OC / G; |
633 | int C_per_G = conv_param.IC / G; |
634 | int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; |
635 | int ih_iw = conv_param.IN_DIM[0] * conv_param.IN_DIM[1]; |
636 | |
637 | assert(SPATIAL_DIM == 2 && "3D groupwise conv not supported yet"); |
638 | |
639 | int32_t* rowOffsetTrDest = rowOffsetBuf ? rowOffsetBuf + 8 * ih_iw : nullptr; |
640 | if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) { |
641 | assert(G % 8 == 0); |
642 | // generate convolution kernel |
643 | jit_conv_kernel_fp fpConv = |
644 | getOrCreateConvKernel<SPATIAL_DIM>(conv_param, a_zero_point); |
645 | // generate row offset kernel |
646 | jit_rowoffset_kernel_fp fpRowoffset = |
647 | getOrCreateRowOffsetKernel(conv_param, a_zero_point); |
648 | for (int i = 0; i < MB; ++i) { |
649 | const uint8_t* actStartBatch = activations + i * ih_iw * conv_param.IC; |
650 | for (int gOuter = 0; gOuter < G; gOuter += 8) { |
        // row offset is calculated for 8 groups at a time.
652 | // The result is row offsets in the format IH*IW x G |
653 | if (rowOffsetBuf) { |
654 | fpRowoffset( |
655 | actStartBatch + gOuter * C_per_G, |
656 | a_zero_point, |
657 | H, |
658 | W, |
659 | rowOffsetBuf); |
660 | // Transpose to get row offsets in the format G x IH*IW |
661 | internal::transpose_avx2( |
662 | ih_iw, |
663 | 8, |
664 | reinterpret_cast<const float*>(rowOffsetBuf), |
665 | 8, |
666 | reinterpret_cast<float*>(rowOffsetTrDest), |
667 | ih_iw); |
668 | } |
669 | int gLimit = gOuter + 8; |
670 | // Work on 8 output channels at a time (8 * sizeof(int32_t) == 32B VLEN |
671 | // of AVX2), and we need multiple groups if a group has not enough |
672 | // number of channels. |
673 | int gDelta = std::max(8 / C_per_G, 1); |
674 | for (int g = gOuter; g < gLimit; g += gDelta) { |
675 | int32_t* currOutBuf = |
676 | outBuffer + i * oh_ow * conv_param.OC + g * K_per_G; |
677 | const uint8_t* actStartGroup = actStartBatch + g * C_per_G; |
678 | for (int k = 0; k < K_per_G; k += 8) { |
679 | // Don't be confused with k above which refers to output channels. |
680 | // k0 and k1 are filter dimensions (commonly 3 and 3) |
681 | int k0 = conv_param.K[0]; |
682 | int k1 = conv_param.K[1]; |
683 | fpConv( |
684 | actStartGroup, |
685 | // packed weight is in G (C/4) R S K 4 layout for IC_per_G >= 8 |
686 | // in (G/2) R S K (2C) for IC_per_G == 4 |
687 | packed_weights.getBuf() + |
688 | (g * (C_per_G / 4) * k0 * k1 * K_per_G + k) * 4, |
689 | currOutBuf + k, |
690 | a_zero_point, |
691 | H, |
692 | W); |
693 | } // k loop |
694 | |
695 | // Output processing should be called for each group |
696 | for (int j = 0; j < gDelta; ++j) { |
697 | // calculateRowOffsets( |
698 | // conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j); |
699 | int32_t* rowOffsetForCurG = rowOffsetTrDest |
700 | ? rowOffsetTrDest + ((g - gOuter) + j) * ih_iw |
701 | : nullptr; |
702 | // compare_buffers(rowOffsetBuf, rowOffsetForCurG, |
703 | // conv_param.IN_DIM[0]*conv_param.IN_DIM[1], 1, 1, 100); |
704 | |
705 | // outProcess expects rowOffsetBuf to contain row offsets for the |
706 | // current group |
707 | memcpy(rowOffsetBuf, rowOffsetForCurG, ih_iw * sizeof(int32_t)); |
708 | |
709 | if (fbgemmHasAvx512Support()) { |
710 | // Currently use avx2 code |
711 | outProcess.template f<inst_set_t::avx2>( |
712 | out, |
713 | currOutBuf + j * K_per_G, |
714 | {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G}, |
715 | K_per_G * G, |
716 | K_per_G * G); |
717 | } else if (fbgemmHasAvx2Support()) { |
718 | outProcess.template f<inst_set_t::avx2>( |
719 | out, |
720 | currOutBuf + j * K_per_G, |
721 | {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G}, |
722 | K_per_G * G, |
723 | K_per_G * G); |
724 | } else { |
725 | // TODO: Have default slower path |
            assert(0 && "unsupported architecture");
727 | } |
728 | } // j loop |
729 | } // g loop |
730 | } // gOuter loop |
731 | } // i loop |
732 | } else { |
733 | // for the not supported cases, just execute the naive C implementation |
734 | conv_ref( |
735 | conv_param, |
736 | activations, |
737 | a_zero_point, |
738 | packed_weights.getBuf(), |
739 | outBuffer); |
740 | for (int i = 0; i < conv_param.MB; ++i) { |
741 | for (int g = 0; g < conv_param.G; ++g) { |
742 | if (rowOffsetBuf) { |
743 | calculateRowOffsets( |
744 | conv_param, |
745 | activations + |
746 | i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * |
747 | conv_param.IC, |
748 | rowOffsetBuf, |
749 | a_zero_point, |
750 | g); |
751 | } |
752 | outProcess.template f<inst_set_t::anyarch>( |
753 | out, |
754 | outBuffer + i * oh_ow * conv_param.OC + g * K_per_G, |
755 | {i * oh_ow, oh_ow, g * K_per_G, K_per_G}, |
756 | K_per_G * G, |
757 | K_per_G * G); |
758 | } |
759 | } |
760 | } |
761 | } |
762 | |
763 | } // namespace |
764 | |
765 | template < |
766 | typename packed_W, |
767 | typename outType, |
768 | typename processOutputType, |
769 | int SPATIAL_DIM> |
770 | void fbgemmGroupwiseConv( |
771 | const conv_param_t<SPATIAL_DIM>& conv_param, |
772 | const uint8_t* activations, |
773 | int32_t a_zero_point, |
774 | int32_t* rowOffsetBuf, |
775 | packed_W& packed_weights, |
776 | outType* out, |
777 | int32_t* outBuffer, |
778 | const processOutputType& outProcess, |
779 | int thread_id, |
780 | int num_threads) { |
781 | // TODO: Remove this when threading is supported. |
782 | if (thread_id > 0) { |
783 | return; |
784 | } |
785 | |
786 | return fbgemmGroupwiseConvBase_< |
787 | packed_W, |
788 | outType, |
789 | processOutputType, |
790 | SPATIAL_DIM>( |
791 | conv_param, |
792 | activations, |
793 | a_zero_point, |
794 | rowOffsetBuf, |
795 | packed_weights, |
796 | out, |
797 | outBuffer, |
798 | outProcess, |
799 | thread_id, |
800 | num_threads); |
801 | } |
802 | |
803 | */ |
804 | |
805 | /* |
806 | * |
807 | * This function does exactly the same compute as the JIT'ed kernel |
808 | */ |
809 | template <int SPATIAL_DIM> |
810 | void kernel_compute( |
811 | const conv_param_t<SPATIAL_DIM>& conv_p, |
812 | const uint8_t* in_acts, |
813 | int8_t* wghts, |
814 | int32_t* out_acts, |
815 | int32_t a_zero_pt, |
816 | int32_t h_start, |
817 | int32_t h_end, |
818 | int32_t width, |
819 | int32_t* rowOffset, |
820 | bool accum) { |
821 | int IW = conv_p.IN_DIM[1]; |
822 | int IC = conv_p.IC; |
823 | int OC = conv_p.OC; |
824 | int G = conv_p.G; |
825 | int R = conv_p.K[0]; |
826 | int S = conv_p.K[1]; |
827 | int IC_per_G = conv_p.IC / G; |
828 | int OC_per_G = conv_p.OC / G; |
829 | int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>:: |
830 | numOfGroupsTogether(conv_p); |
831 | int paddedICPerG = (IC_per_G + 3) / 4 * 4; |
832 | for (int h = h_start; h < h_end; ++h) { |
833 | for (int w = 0; w < width; ++w) { |
834 | for (int g = 0; g < G_together; ++g) { |
835 | for (int k = 0; k < OC_per_G; ++k) { |
836 | int sum = 0; |
837 | int rowSum = 0; |
838 | for (int r = 0; r < R; ++r) { |
839 | int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; |
840 | for (int s = 0; s < S; ++s) { |
841 | int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; |
842 | for (int c = 0; c < IC_per_G; ++c) { |
843 | bool out_of_image = h_in < 0 || h_in >= conv_p.IN_DIM[0] || |
844 | w_in < 0 || w_in >= conv_p.IN_DIM[1]; |
845 | int h_index = h_in; |
846 | if (h_start > 0) { |
847 | h_index = (h - h_start) * conv_p.stride[1] + r; |
848 | } |
849 | int a = out_of_image |
850 | ? a_zero_pt |
851 | : in_acts[(h_index * IW + w_in) * IC + g * IC_per_G + c]; |
852 | int idx = (((r * S + s) * OC_per_G + k) * G_together + g) * |
853 | paddedICPerG + |
854 | c; |
855 | int b = wghts[idx]; |
856 | sum += a * b; |
857 | rowSum += a; |
858 | } |
859 | } |
860 | } |
861 | if (accum) { |
862 | out_acts[((h - h_start) * width + w) * OC + g * OC_per_G + k] += |
863 | sum; |
864 | if (k == 0 && rowOffset != nullptr) { |
865 | // only accumulate for k == 0 |
866 | rowOffset[((h - h_start) * width + w) * G_together + g] += rowSum; |
867 | } |
868 | } else { |
869 | out_acts[((h - h_start) * width + w) * OC + g * OC_per_G + k] = sum; |
870 | if (rowOffset != nullptr) { |
871 | rowOffset[((h - h_start) * width + w) * G_together + g] = rowSum; |
872 | } |
873 | } |
874 | } |
875 | } |
876 | } |
877 | } |
878 | } |
879 | |
// Dispatches groupwise-conv output requantization to the right template
// instantiation of requantizeOutputProcessingGConv{Avx512,Avx2}. The nested
// REQUANTIZE_* macros below expand, layer by layer, the runtime flags
// (C_per_G in {2,4,8,16}, a_zero_point == 0, b symmetric, bias present)
// into compile-time template arguments. The trailing unnamed true_type
// parameter is a tag used for overload selection by the caller.
template <typename processOutputType, typename outT, typename inT>
void dispatchOutputProcessing(
    const processOutputType& outProcess,
    int32_t* rowOffsetBuf,
    outT* out,
    const inT* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    int groups,
    int C_per_G,
    true_type) {
  constexpr QuantizationGranularity Q_GRAN = processOutputType::QGRANType;
  constexpr int FUSE_RELU = processOutputType::RELU_FUSED;
  // Weight (B) zero point is symmetric only with tensor granularity and a
  // zero B zero point; a null row-offset buffer also implies symmetry.
  bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
                      outProcess.getBZeroPoint()[0] == 0) ||
      rowOffsetBuf == nullptr;
  int32_t a_zero_point = outProcess.getAZeroPoint();

  // Requantization parameters gathered from the output-processing object.
  requantizationParams_t<typename processOutputType::BIAS_T> r = {
      a_zero_point,
      outProcess.getBZeroPoint(),
      outProcess.getCZeroPoint(),
      outProcess.getCMultiplier(),
      rowOffsetBuf,
      outProcess.getColOffsets(),
      outProcess.getBias(),
      outProcess.getNCols(),
      groups,
      outProcess.getActWScale()};

// Innermost layer: the actual templated requantization call.
#define REQUANTIZE_BASE(ISA, C_PER_G, A_SYM, B_SYM, BIAS) \
  requantizeOutputProcessingGConv##ISA<                   \
      A_SYM,                                              \
      B_SYM,                                              \
      Q_GRAN,                                             \
      BIAS,                                               \
      FUSE_RELU,                                          \
      C_PER_G>(out, inp, block, ld_out, ld_in, r);

// Select the BIAS template argument from whether a bias pointer is set.
#define REQUANTIZE_BIAS(ISA, C_PER_G, A_SYM, B_SYM)              \
  if (outProcess.getBias() == nullptr) {                         \
    REQUANTIZE_BASE(ISA, C_PER_G, A_SYM, B_SYM, /*bias=*/false); \
  } else {                                                       \
    REQUANTIZE_BASE(ISA, C_PER_G, A_SYM, B_SYM, /*bias=*/true);  \
  }

// Select the B_SYM template argument from the runtime b_symmetric flag.
#define REQUANTIZE_BSYM(ISA, C_PER_G, A_SYM)      \
  if (b_symmetric) {                              \
    REQUANTIZE_BIAS(ISA, C_PER_G, A_SYM, true);   \
  } else {                                        \
    REQUANTIZE_BIAS(ISA, C_PER_G, A_SYM, false);  \
  }

// Select the A_SYM template argument from the activation zero point.
#define REQUANTIZE_ASYM(ISA, C_PER_G)    \
  if (a_zero_point == 0) {               \
    REQUANTIZE_BSYM(ISA, C_PER_G, true); \
  } else {                               \
    REQUANTIZE_BSYM(ISA, C_PER_G, false);\
  }

// Select the C_PER_G template argument (2, 4, 8, or 16 channels per group).
#define REQUANTIZE_C_PER_G(ISA) \
  if (C_per_G == 2) {           \
    REQUANTIZE_ASYM(ISA, 2);    \
  } else if (C_per_G == 4) {    \
    REQUANTIZE_ASYM(ISA, 4);    \
  } else if (C_per_G == 8) {    \
    REQUANTIZE_ASYM(ISA, 8);    \
  } else {                      \
    REQUANTIZE_ASYM(ISA, 16);   \
  }

  // Outermost dispatch: pick the ISA-specific implementation.
  if (cpuinfo_initialize()) {
    if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) {
      REQUANTIZE_C_PER_G(Avx512);
    } else if (fbgemmHasAvx2Support() || fbgemmHasArmNeonSupport()) {
      REQUANTIZE_C_PER_G(Avx2);
    } else {
      assert(0 && "unsupported architecture" );
    }
  } else {
    throw runtime_error("Failed to initialize cpuinfo!" );
  }
}
965 | |
966 | #undef REQUANTIZE_C_PER_G |
967 | #undef REQUANTIZE_ASYM |
968 | #undef REQUANTIZE_BSYM |
969 | #undef REQUANTIZE_BIAS |
970 | #undef REQUANTIZE_BASE |
971 | |
// Groupwise (grouped) quantized convolution driver for 2D and 3D inputs.
//
// activations   : uint8 input tensor in channels-last layout.
// a_zero_point  : activation zero point; also used as the value for padded
//                 input positions (see the zero_points buffer in the 3D path).
// rowOffsetBuf  : buffer for per-output-pixel activation row sums, consumed
//                 by requantization when B is not symmetric; may be nullptr,
//                 which disables row-offset computation.
// packed_weights: weights packed by PackWeightMatrixForGConv.
// out           : final output written by the output-processing stage.
// outBuffer     : int32 scratch buffer that receives the raw convolution
//                 accumulations before requantization.
// outProcess    : requantization parameters applied via
//                 dispatchOutputProcessing.
// thread_id / num_threads: work is partitioned across threads over the batch
//                 dimension when MB >= num_threads, otherwise over output
//                 rows (OH).
//
// On x86 a JIT-generated kernel (getOrCreateConvKernel) performs the inner
// compute; on other architectures the reference kernel_compute is used.
// The 3D case is decomposed into T accumulated 2D convolutions along the
// temporal kernel dimension.
template <
    typename packed_W,
    typename outType,
    bool FUSE_RELU,
    QuantizationGranularity Q_GRAN,
    int SPATIAL_DIM,
    typename BIAS_TYPE>
void fbgemmGroupwiseConv(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    const uint8_t* activations,
    int32_t a_zero_point,
    int32_t* rowOffsetBuf,
    packed_W& packed_weights,
    outType* out,
    int32_t* outBuffer,
    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
    int thread_id,
    int num_threads) {
  using processOutputType = ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>;

  if (!cpuinfo_initialize()) {
    throw runtime_error("Failed to initialize cpuinfo!" );
  }

  // Normalize 1D/2D/3D shapes: missing leading dimensions (temporal T,
  // height H) default to 1 so the same index math works for all ranks.
  int MB = conv_param.MB;
  int OT = SPATIAL_DIM <= 2 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 3];
  int OH = SPATIAL_DIM == 1 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 2];
  int OW = conv_param.OUT_DIM[SPATIAL_DIM - 1];
  int T = SPATIAL_DIM <= 2 ? 1 : conv_param.K[SPATIAL_DIM - 3];
  int R = SPATIAL_DIM == 1 ? 1 : conv_param.K[SPATIAL_DIM - 2];
  int S = conv_param.K[SPATIAL_DIM - 1];
  int G = conv_param.G;
  int OC = conv_param.OC;
  int IC = conv_param.IC;
  int K_per_G = conv_param.OC / G;
  int C_per_G = conv_param.IC / G;
  int OH_OW = OH * OW;
  int OT_OH_OW = OT * OH * OW;
  int IT = SPATIAL_DIM <= 2 ? 1 : conv_param.IN_DIM[SPATIAL_DIM - 3];
  int IH = SPATIAL_DIM == 1 ? 1 : conv_param.IN_DIM[SPATIAL_DIM - 2];
  int IW = conv_param.IN_DIM[SPATIAL_DIM - 1];
  int IH_IW = IH * IW;
  int IT_IH_IW = IT * IH * IW;
  // Per-group input channels are zero-padded to a multiple of 4 in the
  // packed weight layout.
  int paddedCPerG = (C_per_G + 3) / 4 * 4;

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
  // B is symmetric when all weight zero points are zero (only checkable
  // here for TENSOR granularity) or when the caller opted out of row
  // offsets entirely; in that case row-offset computation is skipped.
  bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
      outProcess.getBZeroPoint()[0] == 0) ||
      rowOffsetBuf == nullptr;
#endif
  // Number of groups the packed-weight layout interleaves, i.e. how many
  // groups each kernel invocation processes at once.
  int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
      numOfGroupsTogether(conv_param);

  if (SPATIAL_DIM == 1) {
    throw std::runtime_error("Groupwise 1D not implemented!" );
  }
  if (SPATIAL_DIM == 2) {
    // Parallelization:
    int64_t batch_start = 0;
    int64_t batch_end = MB;
    int64_t oh_start = 0;
    int64_t oh_end = OH;
    // Prefer partitioning over the batch; fall back to output rows when
    // there are fewer images than threads.
    if (MB >= num_threads) {
      fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
    } else {
      fbgemmPartition1D(thread_id, num_threads, OH, oh_start, oh_end);
    }

    if (batch_start >= batch_end || oh_start >= oh_end) {
      // There is no work for this thread
      return;
    }

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
    // generate convolution + rowOffset kernel
    // Edge flags select a kernel variant that handles top/bottom padding
    // only when this thread's row range actually touches those edges.
    bool calculateRowOffset = !b_symmetric;
    bool isTopEdgeIncluded = oh_start == 0;
    bool isBottomEdgeIncluded = oh_end == OH;
    bool isTopBottomEdgeSame =
        isTopEdgeIncluded && isBottomEdgeIncluded && oh_end == oh_start + 1;
    jit_conv_kernel_fp fpConv = getOrCreateConvKernel<SPATIAL_DIM>(
        conv_param,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        false);
#endif

    // First input row needed for oh_start (accounts for top padding); only
    // offset when we are not starting at the image top.
    int ih_start = 0;
    if (oh_start > 0) {
      ih_start = -conv_param.pad[SPATIAL_DIM - 2] +
          oh_start * conv_param.stride[SPATIAL_DIM - 2];
    }
    int32_t* out_start = outBuffer + oh_start * OW * OC;
    const uint8_t* in_start = activations + ih_start * IW * IC;
    int32_t* rowOffsetBuf_start =
        rowOffsetBuf ? rowOffsetBuf + oh_start * OW * G_together : nullptr;
    for (int i = batch_start; i < batch_end; ++i) {
      const uint8_t* in_start_batch = in_start + i * IH_IW * conv_param.IC;
      int32_t* out_start_batch = out_start + i * OH_OW * OC;
      int32_t* rowOffsetBuf_start_batch =
          rowOffsetBuf ? rowOffsetBuf_start + i * OH_OW * G_together : nullptr;
      // Process G_together groups per iteration; output and row-offset
      // buffers are reused across group iterations (see note below).
      for (int g = 0; g < G; g += G_together) {
        const uint8_t* in_start_group = in_start_batch + g * C_per_G;
        int8_t* weight_start =
            packed_weights.getBuf() + g * R * S * K_per_G * paddedCPerG;
        int32_t* out_start_group = out_start_batch;
        int32_t* rowOffsetBuf_start_group = rowOffsetBuf_start_batch;
        // Uncomment the following two lines to stop
        // reuse of output and rowoffset buffer
        // out_start_group = out_start_batch + g * K_per_G;
        // rowOffsetBuf_start_group = rowOffsetBuf_start_batch + g * MB * OH_OW;

        // exactly the same compute as the JIT'ed below
        // kernel_compute(
        //     conv_param,
        //     in_start_group,
        //     weight_start,
        //     out_start_group,
        //     a_zero_point,
        //     oh_start,
        //     oh_end,
        //     OW,
        //     rowOffsetBuf_start_group);
#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
        fpConv(
            in_start_group,
            weight_start,
            out_start_group,
            a_zero_point,
            oh_start,
            oh_end,
            OW,
            rowOffsetBuf_start_group);
#else
        kernel_compute(
            conv_param,
            in_start_group,
            weight_start,
            out_start_group,
            a_zero_point,
            oh_start,
            oh_end,
            OW,
            rowOffsetBuf_start_group,
            false);
#endif

        // Requantize the int32 accumulations for this group slice into the
        // final output. The block describes the rows/columns of the output
        // matrix this thread+group iteration produced.
        const int32_t* inp = out_start_group;
        block_type_t block{
            static_cast<int>(i * OT_OH_OW + oh_start * OW),
            static_cast<int>((oh_end - oh_start) * OW),
            g * K_per_G,
            G_together * K_per_G};
        int ld_out = G * K_per_G;
        int ld_in = G * K_per_G;

        dispatchOutputProcessing(
            outProcess,
            rowOffsetBuf_start_group,
            out,
            inp,
            block,
            ld_out,
            ld_in,
            G,
            C_per_G,
            is_requantization<processOutputType>());
      } // for each g
    } // for each i
  } else {
    assert(SPATIAL_DIM == 3 && "Unsupported SPATIAL_DIM" );

    // Build the 2D sub-problem used to implement 3D convolution as T
    // accumulated 2D convolutions over the temporal kernel dimension.
    conv_param_t<> conv_p_2d(
        conv_param.MB,
        conv_param.IC,
        conv_param.OC,
        {conv_param.IN_DIM[SPATIAL_DIM - 2],
         conv_param.IN_DIM[SPATIAL_DIM - 1]},
        conv_param.G,
        {conv_param.K[SPATIAL_DIM - 2], conv_param.K[SPATIAL_DIM - 1]},
        {conv_param.stride[SPATIAL_DIM - 2],
         conv_param.stride[SPATIAL_DIM - 1]},
        {conv_param.pad[1],
         conv_param.pad[2],
         conv_param.pad[4],
         conv_param.pad[5]});

    // Parallelization:
    int64_t batch_start = 0;
    int64_t batch_end = MB;
    int64_t oh_start = 0;
    int64_t oh_end = OH;
    if (MB >= num_threads) {
      fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
    } else {
      fbgemmPartition1D(thread_id, num_threads, OH, oh_start, oh_end);
    }

    if (batch_start >= batch_end || oh_start >= oh_end) {
      // There is no work for this thread
      return;
    }

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
    // generate convolution + rowOffset kernel
    // Two kernel variants: the first temporal tap overwrites the output
    // (no accumulation), subsequent taps accumulate into it.
    bool calculateRowOffset = !b_symmetric;
    bool isTopEdgeIncluded = oh_start == 0;
    bool isBottomEdgeIncluded = oh_end == OH;
    bool isTopBottomEdgeSame =
        isTopEdgeIncluded && isBottomEdgeIncluded && oh_end == oh_start + 1;
    jit_conv_kernel_fp fpConvNoAccum = getOrCreateConvKernel<2>(
        conv_p_2d,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        false);
    jit_conv_kernel_fp fpConvAccum = getOrCreateConvKernel<2>(
        conv_p_2d,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        true);
    jit_conv_kernel_fp fpConv;
#endif

    int ih_start = 0;
    if (oh_start > 0) {
      ih_start = -conv_p_2d.pad[0] + oh_start * conv_p_2d.stride[0];
    }

    // A full 2D input slice filled with the activation zero point, used in
    // place of the real input for temporally out-of-bounds taps.
    vector<uint8_t> zero_points(IH * IW * IC, a_zero_point);
    int32_t* out_start = outBuffer + oh_start * OW * OC;
    const uint8_t* in_start = activations + ih_start * IW * IC;
    int32_t* rowOffsetBuf_start =
        rowOffsetBuf ? rowOffsetBuf + oh_start * OW * G_together : nullptr;
    for (int i = batch_start; i < batch_end; ++i) {
      const uint8_t* in_start_batch = in_start + i * IT_IH_IW * IC;
      int32_t* out_start_batch = out_start + i * OT_OH_OW * OC;
      int32_t* rowOffsetBuf_start_batch = rowOffsetBuf
          ? rowOffsetBuf_start + i * OT_OH_OW * G_together
          : nullptr;
      for (int g = 0; g < G; g += G_together) {
        const uint8_t* in_start_group = in_start_batch + g * C_per_G;
        int8_t* weight_start =
            packed_weights.getBuf() + g * T * R * S * K_per_G * paddedCPerG;
        int32_t* out_start_group = out_start_batch;
        int32_t* rowOffsetBuf_start_group = rowOffsetBuf_start_batch;
        // Uncomment the following two lines to stop
        // reuse of output and rowoffset buffer
        // out_start_group = out_start_batch + g * K_per_G;
        // rowOffsetBuf_start_group = rowOffsetBuf_start_batch + g * MB *
        // OT_OH_OW;

        // For every output time step, accumulate the T temporal taps of the
        // kernel as 2D convolutions.
        for (int ot = 0; ot < OT; ++ot) {
          int32_t* out_start_t = out_start_group + ot * OH_OW * OC;
          int32_t* rowOffsetBuf_start_t = rowOffsetBuf
              ? rowOffsetBuf_start_group + ot * OH_OW * G_together
              : nullptr;
          for (int t = 0; t < T; ++t) {
            // Temporal input index for this tap (may fall into padding).
            int t_in = -conv_param.pad[0] + ot * conv_param.stride[0] + t;
            const uint8_t* in_start_t = in_start_group + t_in * IH_IW * IC;
            int8_t* weight_start_t =
                weight_start + t * R * S * K_per_G * G_together * paddedCPerG;
            if (t_in < 0 || t_in >= IT) {
              // Out-of-bounds temporal tap: substitute the zero-point slice.
              in_start_t = zero_points.data();
            }
            // exactly the same compute as the JIT'ed below
            // kernel_compute(
            //     conv_p_2d,
            //     in_start_t,
            //     weight_start_t,
            //     out_start_t,
            //     a_zero_point,
            //     oh_start,
            //     oh_end,
            //     OW,
            //     rowOffsetBuf_start_t,
            //     t > 0);

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
            // First tap writes, later taps accumulate.
            fpConv = t > 0 ? fpConvAccum : fpConvNoAccum;
            fpConv(
                in_start_t,
                weight_start_t,
                out_start_t,
                a_zero_point,
                oh_start,
                oh_end,
                OW,
                rowOffsetBuf_start_t);
#else
            kernel_compute(
                conv_p_2d,
                in_start_t,
                weight_start_t,
                out_start_t,
                a_zero_point,
                oh_start,
                oh_end,
                OW,
                rowOffsetBuf_start_t,
                t > 0);
#endif
          }

          // Requantize this time step's accumulations into the final output.
          const int32_t* inp = out_start_t;
          block_type_t block{
              static_cast<int>(i * OT_OH_OW + oh_start * OW),
              static_cast<int>((oh_end - oh_start) * OW),
              g * K_per_G,
              G_together * K_per_G};
          int ld_out = G * K_per_G;
          int ld_in = G * K_per_G;

          dispatchOutputProcessing(
              outProcess,
              rowOffsetBuf_start_t,
              out + ot * OH_OW * OC,
              inp,
              block,
              ld_out,
              ld_in,
              G,
              C_per_G,
              is_requantization<processOutputType>());
        } // for each ot
      } // for each g
    } // for each i
  } // SPATIAL_DIM == 3
}
1314 | |
1315 | template <int SPATIAL_DIM> |
1316 | int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { |
1317 | // row offset buffer should be a able to hold row offsets for however |
1318 | // number of groups we process at a time. |
1319 | if (cpuinfo_initialize()) { |
1320 | int OT = SPATIAL_DIM <= 2 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 3]; |
1321 | int OH = SPATIAL_DIM == 1 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 2]; |
1322 | int bufferSize = OT * OH * conv_param.OUT_DIM[SPATIAL_DIM - 1]; |
1323 | if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support() || |
1324 | fbgemmHasArmNeonSupport()) { |
1325 | return conv_param.MB * bufferSize * conv_param.G; |
1326 | } else { |
1327 | // TODO: Have default slower path |
1328 | assert(0 && "unsupported architecture" ); |
1329 | return -1; |
1330 | } |
1331 | } else { |
1332 | throw runtime_error("Failed to initialize cpuinfo!" ); |
1333 | } |
1334 | } |
1335 | |
// Explicit instantiations of rowOffsetBufferSizeGConv for the spatial
// dimensionalities (1D/2D/3D) exported from this translation unit.
template FBGEMM_API int rowOffsetBufferSizeGConv<1>(
    const conv_param_t<1>& conv_param);
template FBGEMM_API int rowOffsetBufferSizeGConv<2>(
    const conv_param_t<2>& conv_param);
template FBGEMM_API int rowOffsetBufferSizeGConv<3>(
    const conv_param_t<3>& conv_param);
1342 | |
// Explicit-instantiation machinery for fbgemmGroupwiseConv: the nested
// macros below enumerate the full cross product of
//   RELU (false/true) x quantization granularity (TENSOR/GROUP/OUT_CHANNEL)
//   x spatial dim (1/2/3) x bias type (float/int32_t)
// so every supported template combination is compiled into this TU.
#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, BIAS_TYPE)               \
  template FBGEMM_API void fbgemmGroupwiseConv(                              \
      const conv_param_t<SPATIAL_DIM>& conv_param,                           \
      const uint8_t* activations,                                            \
      int32_t a_zero_point,                                                  \
      int32_t* rowOffsetBuf,                                                 \
      PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>& packed_weights, \
      uint8_t* out,                                                          \
      int32_t* outBuffer,                                                    \
      const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess,           \
      int thread_id,                                                         \
      int num_threads);

// Expand over the two supported bias element types.
#define INSTANTIATE_BIAS_T(RELU, Q_GRAN, SPATIAL_DIM)  \
  INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, float)   \
  INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM, int32_t)

// Expand over 1D, 2D, and 3D convolutions.
#define INSTANTIATE_SPATIAL_DIM(RELU, Q_GRAN) \
  INSTANTIATE_BIAS_T(RELU, Q_GRAN, 1)         \
  INSTANTIATE_BIAS_T(RELU, Q_GRAN, 2)         \
  INSTANTIATE_BIAS_T(RELU, Q_GRAN, 3)

// Expand over the three quantization granularities.
#define INSTANTIATE_Q_GRANS(RELU)                                 \
  INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::TENSOR)  \
  INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::GROUP)   \
  INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::OUT_CHANNEL)

INSTANTIATE_Q_GRANS(false);
INSTANTIATE_Q_GRANS(true);

// Scope the helper macros to this file.
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_SPATIAL_DIM
#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
1377 | |
1378 | /* |
1379 | template void fbgemmGroupwiseConv( |
1380 | const conv_param_t<2>& conv_param, |
1381 | const uint8_t* activations, |
1382 | int32_t a_zero_point, |
1383 | int32_t* rowOffsetBuf, |
1384 | PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, |
1385 | int32_t* out, |
1386 | int32_t* outBuffer, |
1387 | const DoNothing<int32_t, int32_t>& outProcess, |
1388 | int thread_id, |
1389 | int num_threads); |
1390 | */ |
1391 | |
1392 | } // namespace fbgemm |
1393 | |