1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#include <iostream>
8#include "./CodeGenHelpers.h"
9#include "./DirectConv.h"
10
11namespace fbgemm {
12
13namespace x86 = asmjit::x86;
14/**
15 * Generate AVX256 instructions for computing block in the rank-k update of
16 * 32-bit Accumulation kernel.
17 *
18 * this compute block implements the following register blocking
19 // register blocking:
20 // leverage vpmaddubsw instructions
21 for (int _icb = icb; _icb < icb + row_interleave; _icb ++ ) {
     for (int _oc = oc; _oc < oc + mRegBlockSize; _oc ++) {
23 for (int _ow = ow; _ow < std::min(ow + 12, OUT_DIM[1]); _ow ++) {
24 out[_oc + _ow * OC] +=
25 input[_ich + (_ow + s * stride[1]) * IC + r * IC * IN_DIM[1]]
26 *
27 weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s)
28 *8 + (_oc % 8)) * 4 + (_icb % 4)];
29 }
30 }
31 }
32 *
33 */
34
35/**
36 * Generate AVX256 instructions for storing the C registers back to the memory
37 * in 32-bit Accumulation kernel.
38 */
39template <>
40template <inst_set_t instSet>
41void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs(
42 x86::Emitter* a,
43 int rowRegs,
44 int colRegs,
45 x86::Gp C_Offset,
46 x86::Gp ldcReg,
47 bool accum) {
48 using VecT = typename simd_info<instSet>::vec_reg_t;
49 static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
50
51 for (int i = 0; i < rowRegs; ++i) {
52 if (i != 0) {
53 a->add(C_Offset, ldcReg);
54 } else {
55 a->xor_(C_Offset.r32(), C_Offset.r32());
56 }
57 for (int j = 0; j < colRegs; ++j) {
58 if (accum) {
59 a->vpaddd(
60 VecT(i * colRegs + j),
61 VecT(i * colRegs + j),
62 x86::dword_ptr(
63 a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)));
64 }
65 a->vmovups(
66 x86::dword_ptr(a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)),
67 VecT(i * colRegs + j));
68 }
69 }
70}
71
72template <>
73template <inst_set_t instSet>
74void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
75 genComputeBlockDirectConv(
76 x86::Emitter* a,
77 x86::Gp buffer_A,
78 x86::Gp buffer_B,
79 x86::Gp /*B_pf*/,
80 int rowRegs,
81 int colRegs,
82 int strideXich) {
83 static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
84 using VecRegT = typename simd_info<instSet>::vec_reg_t;
85 constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
86
87 // used for matrix A
88 VecRegT AReg(numRegs - 1);
89
90 // used for matrix B
91 VecRegT BReg(numRegs - 2);
92
93 // Contains 16-bit 1s
94 VecRegT oneReg(numRegs - 3);
95
96 // temporary register
97 VecRegT res1(numRegs - 4);
98
99 for (int j = 0; j < colRegs; ++j) {
100 // load B
101 emitLoadDWord<instSet, VecRegT>(
102 a, BReg, x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t)));
103 // load A, broadcast and fmas
104 for (int i = 0; i < rowRegs; ++i) {
105 a->vpbroadcastd(
106 AReg, x86::dword_ptr(buffer_A, (i * strideXich) * sizeof(uint8_t)));
107 a->vpmaddubsw(res1, AReg, BReg);
108 a->vpmaddwd(res1, oneReg, res1);
109 a->vpaddd(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j));
110 }
111 // a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int8_t)));
112 }
113}
114
115/**
116 * Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel.
117 *
118 * This function implements a direct convolution kernel that is specialized
119 * for kernel size (2, 6) and input_height (IN_DIM[0]) = 2.
120 *
121 * More specifically the implementation has the following requirements:
122 * * Weights has layout {OC/8, KH, KW, IC/4, 8, 4}
123 * * kernel size (2, 6), IN_DIM[0] = 2, therefore: OUT_DIM[0] = 1
124 * * Features are in channel last format
125 *
126 * mRegBlockSize = 12: the number of avx2 registers for output
127 * nRegBlockSize = 8: the # of output elements in one avx2 register
128 * row_interleave = 4: the horizontal reduction size for vpmaddubsw instruction
129 * O1: output_width: OUT_DIM[1]
130 * i1Xich: input_width multiply input_channel: IN_DIM[1] x IC
131 * strideXich: stride multiply input_channel: stride[1] x input_channel
132 *
133 *
134 * The kernel implements the following algorithm:
135
136for (int ow = 0; ow < OUT_DIM[1]; ow+=12) {
137 L1 blocking: following weights are in L1 cache
138 for (int s = 0; s < K[1]; ++s) {
139 for (int r = 0; r < K[0]; ++r) {
140 for (int icb = 0; icb < IC; icb+=row_interleave) {
141
142 // register blocking:
143 // leverage vpmaddubsw instructions
144 for (int _icb = icb; _icb < icb + row_interleave; _icb ++ ) {
145 for (int _oc = oc; _oc < oc + mRegBLockSize; _oc ++) {
146 for (int _ow = ow; _ow < std::min(ow + 12, OUT_DIM[1]); _ow ++) {
147 out[_oc + _ow * OC] +=
148 input[_ich + (_ow + s * stride[1]) * IC + r * IC * IN_DIM[1]]
149 *
150 weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s)
151 *8 + (_oc % 8)) * 4 + (_icb % 4)];
152
          // If we get rid of the brackets, and substitute corresponding
          // variables:
155 //
156 // input[_ich + _ow * IC + s * strideXich + r * i1Xich]
157 // *
158 // weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s)
159 // *8 + (_oc % 8)) * 4 + (_icb % 4)];
160 }
161 }
162 }
163
164 }
165 }
166 }
167 *
168 */
template <>
template <inst_set_t instSet>
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreateDirectConv(
    bool accum,
    int32_t O1,
    int32_t i1Xich,
    int32_t strideXich) {
  using VecRegT = typename simd_info<instSet>::vec_reg_t;
  constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

  std::tuple<bool, int, int, int, int, int, int> kernelSig;
  // int ichSize = 32;
  int mRegBlockSize = 12;
  int nRegBlockSize = 8;
  // int nRegBlockSizeMin;
  int row_interleave = 4;

  // Cache key for the generated kernel.
  // NOTE(review): i1Xich appears twice in this tuple; the key only needs to
  // be unique so this is harmless, but confirm it is intentional.
  kernelSig = std::make_tuple(
      accum, O1, i1Xich, strideXich, i1Xich, mRegBlockSize, nRegBlockSize);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
    // generated code logging
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<instSet>(
            accum, O1, i1Xich, strideXich, i1Xich, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    const int maxMRegs = mRegBlockSize;
    (void)maxMRegs; // Suppress unused variable warning
    const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
    assert(
        maxMRegs * maxNRegs <= numRegs - 4 &&
        "MRegs x NRegs is above available registers (MAX_REGS - 4)");

    // Number of full 12-wide output blocks, plus the remainder block size.
    int O1RegBlocks = O1 / mRegBlockSize;
    int O1RegBlocksRem = O1 % mRegBlockSize;

    // arguments to the function created
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp B_pf = a->zdx();
    x86::Gp CBase = a->zcx();
    x86::Gp ichXk1 = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            uint8_t*,
            int8_t*,
            int8_t*,
            int32_t*,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    // All vector registers may be clobbered by the generated kernel.
    auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15);
    if (numRegs >= 16) {
      dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
          asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31);
    }

    frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs);
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, B_pf, CBase, ichXk1, ldcReg);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label LoopMBlocks = a->newLabel();
    // asmjit::Label LoopOBlocks = a->newLabel();
    // asmjit::Label LoopNBlocks = a->newLabel();

    // Scratch GP registers used by the generated code.
    x86::Gp buffer_B_saved = a->gpz(10);
    x86::Gp C_Offset = a->gpz(11);
    // x86::Gp B_pf_saved = a->gpz(12);
    x86::Gp iIdx = a->gpz(13);
    // x86::Gp jIdx = a->gpz(14);
    x86::Gp kIdx = a->gpz(15);
    // x86::Gp B_pf = a->gpz(8);

    VecRegT oneReg(numRegs - 3);

    gen16BitVectorOne<instSet, VecRegT>(a, oneReg);
    // Convert ldc from elements to bytes once, up front.
    a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
    // a->xor_(C_Offset.r32(), C_Offset.r32());

    // a->mov(B_pf_saved, B_pf);

    int colRegs = maxNRegs;

    auto issueLoopOverK = [&](int rowRegs) {
      // loopKLabel: corresponds to loop "r" where r = 0
      // loopK0Label: corresponds to loop "r" where r = 1
      asmjit::Label LoopKLabel = a->newLabel();
      asmjit::Label LoopK0Label = a->newLabel();

      // Init C (result) vector registers
      initCRegs(a, rowRegs, colRegs);

      // Loops over K: input channel
      // a.k.a this issueLoopOverK code block generates code
      // corresponding to the "ich" loop of the pseudo-code
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopKLabel);

      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      // this ComputeBlock generates code corresponding to
      // the kernel_height loop (loop "r") of the pseudo-code above.
      // Because K[0] == 2 and IN_DIM[0] == 2 (see requirements), loop "r"
      // can be fully unrolled; this first genComputeBlockDirectConv
      // generates the code for loop "r" = 0.
      genComputeBlockDirectConv<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, strideXich);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(buffer_B, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
      a->add(B_pf, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

      a->cmp(kIdx, ichXk1);
      a->jl(LoopKLabel);

      // Rewind A to the start of the current block, then advance one
      // input row (i1Xich = IN_DIM[1] * IC) for the r = 1 pass.
      a->sub(buffer_A, ichXk1);

      a->add(buffer_A, static_cast<asmjit::Imm>(i1Xich));

      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopK0Label);

      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      // this ComputeBlock generates code that corresponds
      // to the kernel_height loop (loop "r") in the pseudo-code above.
      // This second genComputeBlockDirectConv generates the code for
      // loop "r" = 1.
      genComputeBlockDirectConv<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, strideXich);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(buffer_B, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
      a->add(B_pf, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

      a->cmp(kIdx, ichXk1);
      a->jl(LoopK0Label);

      a->sub(buffer_A, ichXk1);

      // store C matrix
      storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
    };

    if (O1RegBlocks > 0) {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      // iIdx loop corresponds to the output-width blocking loop
      // (the "ow" loop, in mRegBlockSize = 12 chunks)
      a->bind(LoopMBlocks);
      a->inc(iIdx);

      // save B_buffer address
      a->mov(buffer_B_saved, buffer_B);

      issueLoopOverK(mRegBlockSize);

      int rowRegs = mRegBlockSize;

      // reset A
      a->sub(buffer_A, static_cast<asmjit::Imm>(i1Xich));

      // increment A for next block
      a->add(
          buffer_A,
          static_cast<asmjit::Imm>(rowRegs * strideXich * sizeof(uint8_t)));

      // B for next block
      a->mov(buffer_B, buffer_B_saved);

      // increment C for next B block
      // ldcReg already multiplied with 4 (sizeof(int32_t))
      a->imul(
          C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int8_t)));
      a->add(CBase, C_Offset);

      // a->add(CBase, static_cast<asmjit::Imm>(12*16*4));
      // storeCRegs<instSet>(a, 12, 1, C_Offset, ldcReg, accum);

      a->cmp(iIdx, O1RegBlocks);
      a->jl(LoopMBlocks);
    }

    // generate code for remainder
    if (O1RegBlocksRem > 0) {
      issueLoopOverK(O1RegBlocksRem);
    }

    a->emitEpilog(frame);

    // Commit the generated code to the JIT runtime (guarded by rtMutex_).
    jit_micro_kernel_fp fn;
    asmjit::Error err;
    {
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

#if defined(FBGEMM_LOG_CODE)
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}
422
423/**
424 * Generate AVX256 instructions for storing the C registers back to the memory
425 * in 32-bit Accumulation kernel.
426 */
427template <>
428template <inst_set_t instSet>
429void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegsTrans(
430 x86::Emitter* a,
431 int rowRegs,
432 int colRegs,
433 x86::Gp C_offset,
434 x86::Gp o1XocReg,
435 x86::Gp ldcReg,
436 bool accum) {
437 using VecT = typename simd_info<instSet>::vec_reg_t;
438 // static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
439
440 a->xor_(C_offset.r32(), C_offset.r32());
441 for (int i = 0; i < rowRegs; ++i) {
442 for (int j = 0; j < colRegs; ++j) {
443 if (accum) {
444 a->vpaddd(
445 VecT(i * colRegs + j),
446 VecT(i * colRegs + j),
447 x86::dword_ptr(a->zcx(), C_offset));
448 }
449 a->vmovups(x86::dword_ptr(a->zcx(), C_offset), VecT(i * colRegs + j));
450 a->add(C_offset, ldcReg);
451 }
452 a->add(C_offset, o1XocReg);
453 }
454}
455
456/**
457 * Generate AVX256 instructions for computing block in the rank-k update of
458 * 32-bit Accumulation kernel.
459
460The function generates the register blocking code for transposed
461direct convolution
462 // register blocking for transposed direct convolution:
463 // K[0] x K[1] = 12, the corresponding 12 x 8 output elements will
464 // be kept in the twelve avx2 registers, which is:
465 // out[ih + 0..r][iw + 0..s][8]
466 for (int r = 0; r < K[0]; ++r) {
467 for (int s = 0; s < K[1]; ++s) {
468 int oh = ih * conv_p.stride[0] + r;
469 int ow = iw * conv_p.stride[1] + s;
470
471 int a = input[((ih)*IN_DIM[1] + iw) * IC + icb];
472 int b = weight
473 [(((((oc / 8) * K[0] + r) * K[1] + s) * (IC / 4) + icb / 4) *
474 8 +
475 (oc % 8)) *
476 4 +
477 (icb % 4)];
478 out[((oh)*OUT_DIM[1] + ow) * OC + oc] += a * b;
479 }
480 }
481 }
482 *
483 */
484template <>
485template <inst_set_t instSet>
486void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
487 genComputeBlockDirectConvTrans(
488 x86::Emitter* a,
489 x86::Gp buffer_A,
490 x86::Gp buffer_B,
491 x86::Gp icReg,
492 x86::Gp C_offset,
493 int rowRegs,
494 int colRegs) {
495 // static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
496 using VecRegT = typename simd_info<instSet>::vec_reg_t;
497 constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
498
499 // used for matrix A
500 VecRegT AReg(numRegs - 1);
501
502 // used for matrix B
503 VecRegT BReg(numRegs - 2);
504
505 // Contains 16-bit 1s
506 VecRegT oneReg(numRegs - 3);
507
508 // temporary register
509 VecRegT res1(numRegs - 4);
510
511 // load A
512 a->vpbroadcastd(AReg, x86::dword_ptr(buffer_A));
513
514 a->xor_(C_offset.r32(), C_offset.r32());
515 for (int i = 0; i < rowRegs; ++i) {
516 for (int j = 0; j < colRegs; ++j) {
517 // load B, broadcast and fmas
518 emitLoadDWord<instSet, VecRegT>(
519 a, BReg, x86::dword_ptr(buffer_B, C_offset, 3, 0));
520 a->vpmaddubsw(res1, AReg, BReg);
521 a->vpmaddwd(res1, oneReg, res1);
522 a->vpaddd(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j));
523 a->add(C_offset, icReg);
524 }
525 // a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int8_t)));
526 }
527}
528
529/**
530 * Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel.
531 *
532 * This function implements a direct convolution kernel that is specialized
533 * for kernel size (2, 6)
534 *
535 * More specifically the implementation has the following prerequisites:
536 * * Weights has layout {OC/8, KH, KW, IC/4, 8, 4}
537 * * kernel size (2, 6)
538 * * Features are in channel last format
539 * * Stride[0] = 1, Stride[1] = 1 or 2
540 * * Padding = 0
541 *
542 * mRegBlockSize = 12: the number of avx2 registers for output
543 * nRegBlockSize = 8: the # of output elements in one avx2 register
544 * row_interleave = 4: the horizontal reduction size for vpmaddubsw instruction
545 * stride: stride[1], 1 or 2. we stride[0] = 1
546 * ic: input channel
547 * i1: input_width: IN_DIM[1]
548 * ldcReg: leading dimension of output, a.k.a OC
549 * o1Xoc: output width multiply output channel: OUT_DIM[1] x OC
550 *
551 * The kernel implements the following algorithm:
552
553 for (int oc = 0; oc < OC; oc++) {
554 for (int ih = 0; ih < IN_DIM[0]; ++ih) {
555 for (int iw = 0; iw < IN_DIM[1]; iw++) {
556 // L1 blocking
557 for (int icb = 0; icb < IC; icb+=4) {
558 // register blocking:
559 // K[0] x K[1] = 12, the corresponding 12 x 8 output elements will
560 // be kept in the twelve avx2 registers, which is:
561 // out[ih + 0..r][iw + 0..s][8]
562 for (int r = 0; r < K[0]; ++r) {
563 for (int s = 0; s < K[1]; ++s) {
564 for (int _icb = icb ; _icb < icb + 4; _icb ++) {
565 int oh = ih * conv_p.stride[0] + r;
566 int ow = iw * conv_p.stride[1] + s;
567
568 int a = input[((ih)*IN_DIM[1] + iw) * IC + icb];
569 int b = weight
570 [(((((oc / 8) * K[0] + r) * K[1] + s) * (IC / 4) + icb / 4) *
571 8 +
572 (oc % 8)) *
573 4 +
574 (icb % 4)];
575 out[((oh)*OUT_DIM[1] + ow) * OC + oc] += a * b;
576
                // if we get rid of the brackets, and substitute corresponding
                // variables:
579 // out[ih * stride0 * o1Xoc + r * o1Xoc + iw * ldcReg + oc]
580 // input[ih * i1 * ic + iw * ic + icb]
581 }
582 }
583 }
584 }
585 } // for each ic
586 } // for each s
587 } // for each r
588
589 *
590 */
591
592/**
593 * Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel.
594 *
595 */
template <>
template <inst_set_t instSet>
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
    jit_micro_kernel_fp_convT
    DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
        getOrCreateDirectConvTrans(
            bool accum,
            int32_t stride,
            int32_t numColRegs) {
  using VecRegT = typename simd_info<instSet>::vec_reg_t;
  constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

  std::tuple<bool, int, int, int> kernelSig;
  // int ichSize = 32;
  // Register blocking: mRowRegBlockSize x mColRegBlockSize accumulators.
  int mRowRegBlockSize = 2;
  int mColRegBlockSize = numColRegs;
  int mRegBlockSize = mRowRegBlockSize * mColRegBlockSize;
  int nRegBlockSize = 8;
  // int nRegBlockSizeMin;
  int row_interleave = 4;

  // Cache key for the generated kernel.
  kernelSig = std::make_tuple(accum, stride, mRegBlockSize, nRegBlockSize);

  return codeCacheT_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp_convT {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
    // generated code logging
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<instSet>(accum, stride, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    const int maxMRegs = mRegBlockSize;
    (void)maxMRegs; // Suppress unused variable warning
    const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
    assert(
        maxMRegs * maxNRegs <= numRegs - 4 &&
        "MRegs x NRegs is above available registers (MAX_REGS - 4)");

    // arguments to the function created
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp CBase = a->zcx();
    x86::Gp ic = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);
    x86::Gp o1Xoc = a->gpz(10);
    x86::Gp i1 = a->gpz(11);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            uint8_t*,
            int8_t*,
            int32_t*,
            int,
            int,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    // All vector registers may be clobbered by the generated kernel.
    auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15);
    if (numRegs >= 16) {
      dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
          asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31);
    }

    frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs);
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, CBase, ic, ldcReg, o1Xoc, i1);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label LoopMBlocks = a->newLabel();

    // Scratch GP registers used by the generated code.
    x86::Gp C_offset = a->gpz(12);
    x86::Gp buffer_B_saved = a->gpz(13);
    x86::Gp iIdx = a->gpz(14);
    x86::Gp kIdx = a->gpz(15);

    VecRegT oneReg(numRegs - 3);

    gen16BitVectorOne<instSet, VecRegT>(a, oneReg);
    // Convert ldc from elements to bytes once, up front.
    a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));

    int colRegs = maxNRegs;

    auto issueLoopOverK = [&](int rowRegs) {
      asmjit::Label LoopKLabel = a->newLabel();

      // Init C (result) vector registers
      initCRegs(a, rowRegs, colRegs);

      // Loops over K: input channel
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopKLabel);

      // k is incremented by row_interleave (4)
      a->add(kIdx, 4);
      genComputeBlockDirectConvTrans<instSet>(
          a,
          buffer_A,
          buffer_B,
          ic,
          C_offset,
          mRowRegBlockSize,
          mColRegBlockSize);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(buffer_B, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

      a->cmp(kIdx, ic);
      a->jl(LoopKLabel);

      // store C matrix
      storeCRegsTrans<instSet>(
          a,
          mRowRegBlockSize,
          mColRegBlockSize,
          C_offset,
          o1Xoc,
          ldcReg,
          accum);
    };

    {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      // iIdx loop runs once per input-width position (iIdx < i1).
      a->bind(LoopMBlocks);
      a->inc(iIdx);

      // save B_buffer address
      a->mov(buffer_B_saved, buffer_B);

      issueLoopOverK(mRegBlockSize);

      // B for next block
      a->mov(buffer_B, buffer_B_saved);
      // increment C for next B block
      a->imul(
          C_offset,
          ldcReg,
          static_cast<asmjit::Imm>(stride)); // ldcReg already multiplied by 4
      a->add(CBase, C_offset);

      a->cmp(iIdx, i1);
      a->jl(LoopMBlocks);
    }

    a->emitEpilog(frame);

    // Commit the generated code to the JIT runtime (guarded by rtMutex_).
    jit_micro_kernel_fp_convT fn;
    asmjit::Error err;
    {
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

#if defined(FBGEMM_LOG_CODE)
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}
792
793/**
794 * Instantiate the inst_set_t::avx512 instructions for store kernel.
795 *
796 */
797template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
798 storeCRegs<inst_set_t::avx512>(
799 x86::Emitter* a,
800 int rowRegs,
801 int colRegs,
802 x86::Gp C_Offset,
803 x86::Gp ldcReg,
804 bool accum);
805
806/**
807 * Instantiate the inst_set_t::avx512_ymm instructions for store kernel.
808 *
809 */
810template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
811 storeCRegs<inst_set_t::avx512_ymm>(
812 x86::Emitter* a,
813 int rowRegs,
814 int colRegs,
815 x86::Gp C_Offset,
816 x86::Gp ldcReg,
817 bool accum);
818
819/**
820 * Instantiate the inst_set_t::avx2 instructions for store kernel.
821 *
822 */
823template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
824 storeCRegs<inst_set_t::avx2>(
825 x86::Emitter* a,
826 int rowRegs,
827 int colRegs,
828 x86::Gp C_Offset,
829 x86::Gp ldcReg,
830 bool accum);
831
832/**
833 * Instantiate the inst_set_t::avx2 instructions for store kernel.
834 *
835 */
836template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
837 storeCRegsTrans<inst_set_t::avx512>(
838 x86::Emitter* a,
839 int rowRegs,
840 int colRegs,
841 x86::Gp C_offset,
842 x86::Gp o1XocReg,
843 x86::Gp ldcReg,
844 bool accum);
845
846/**
847 * Instantiate the inst_set_t::avx2 instructions for store kernel.
848 *
849 */
850template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
851 storeCRegsTrans<inst_set_t::avx512_ymm>(
852 x86::Emitter* a,
853 int rowRegs,
854 int colRegs,
855 x86::Gp C_offset,
856 x86::Gp o1XocReg,
857 x86::Gp ldcReg,
858 bool accum);
859
860/**
861 * Instantiate the inst_set_t::avx2 instructions for store kernel.
862 *
863 */
864template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
865 storeCRegsTrans<inst_set_t::avx2>(
866 x86::Emitter* a,
867 int rowRegs,
868 int colRegs,
869 x86::Gp C_offset,
870 x86::Gp o1XocReg,
871 x86::Gp ldcReg,
872 bool accum);
873
874/**
875 * Instantiate the AVX2 instructions for 32-bit Accumulation macro-kernel.
876 *
877 */
878template DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
879 jit_micro_kernel_fp
880 DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
881 getOrCreateDirectConv<inst_set_t::avx2>(
882 bool accum,
883 int32_t O1,
884 int32_t i1Xich,
885 int32_t strideXich);
886
887/**
888 * Instantiate the AVX2 instructions for 32-bit Accumulation macro-kernel.
889 *
890 */
891template DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
892 jit_micro_kernel_fp_convT
893 DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
894 getOrCreateDirectConvTrans<inst_set_t::avx2>(
895 bool accum,
896 int32_t stride,
897 int32_t numColRegs);
898
899} // namespace fbgemm
900