1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #include <iostream> |
8 | #include "./CodeGenHelpers.h" |
9 | #include "./DirectConv.h" |
10 | |
11 | namespace fbgemm { |
12 | |
13 | namespace x86 = asmjit::x86; |
14 | /** |
15 | * Generate AVX256 instructions for computing block in the rank-k update of |
16 | * 32-bit Accumulation kernel. |
17 | * |
18 | * this compute block implements the following register blocking |
19 | // register blocking: |
20 | // leverage vpmaddubsw instructions |
 for (int _icb = icb; _icb < icb + row_interleave; _icb ++ ) {
   for (int _oc = oc; _oc < oc + mRegBlockSize; _oc ++) {
     for (int _ow = ow; _ow < std::min(ow + 12, OUT_DIM[1]); _ow ++) {
       out[_oc + _ow * OC] +=
         input[_icb + (_ow + s * stride[1]) * IC + r * IC * IN_DIM[1]]
         *
         weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s)
           *8 + (_oc % 8)) * 4 + (_icb % 4)];
29 | } |
30 | } |
31 | } |
32 | * |
33 | */ |
34 | |
35 | /** |
36 | * Generate AVX256 instructions for storing the C registers back to the memory |
37 | * in 32-bit Accumulation kernel. |
38 | */ |
39 | template <> |
40 | template <inst_set_t instSet> |
41 | void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs( |
42 | x86::Emitter* a, |
43 | int rowRegs, |
44 | int colRegs, |
45 | x86::Gp C_Offset, |
46 | x86::Gp ldcReg, |
47 | bool accum) { |
48 | using VecT = typename simd_info<instSet>::vec_reg_t; |
49 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
50 | |
51 | for (int i = 0; i < rowRegs; ++i) { |
52 | if (i != 0) { |
53 | a->add(C_Offset, ldcReg); |
54 | } else { |
55 | a->xor_(C_Offset.r32(), C_Offset.r32()); |
56 | } |
57 | for (int j = 0; j < colRegs; ++j) { |
58 | if (accum) { |
59 | a->vpaddd( |
60 | VecT(i * colRegs + j), |
61 | VecT(i * colRegs + j), |
62 | x86::dword_ptr( |
63 | a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t))); |
64 | } |
65 | a->vmovups( |
66 | x86::dword_ptr(a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)), |
67 | VecT(i * colRegs + j)); |
68 | } |
69 | } |
70 | } |
71 | |
72 | template <> |
73 | template <inst_set_t instSet> |
74 | void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
75 | genComputeBlockDirectConv( |
76 | x86::Emitter* a, |
77 | x86::Gp buffer_A, |
78 | x86::Gp buffer_B, |
79 | x86::Gp /*B_pf*/, |
80 | int rowRegs, |
81 | int colRegs, |
82 | int strideXich) { |
83 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
84 | using VecRegT = typename simd_info<instSet>::vec_reg_t; |
85 | constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS; |
86 | |
87 | // used for matrix A |
88 | VecRegT AReg(numRegs - 1); |
89 | |
90 | // used for matrix B |
91 | VecRegT BReg(numRegs - 2); |
92 | |
93 | // Contains 16-bit 1s |
94 | VecRegT oneReg(numRegs - 3); |
95 | |
96 | // temporary register |
97 | VecRegT res1(numRegs - 4); |
98 | |
99 | for (int j = 0; j < colRegs; ++j) { |
100 | // load B |
101 | emitLoadDWord<instSet, VecRegT>( |
102 | a, BReg, x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); |
103 | // load A, broadcast and fmas |
104 | for (int i = 0; i < rowRegs; ++i) { |
105 | a->vpbroadcastd( |
106 | AReg, x86::dword_ptr(buffer_A, (i * strideXich) * sizeof(uint8_t))); |
107 | a->vpmaddubsw(res1, AReg, BReg); |
108 | a->vpmaddwd(res1, oneReg, res1); |
109 | a->vpaddd(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j)); |
110 | } |
111 | // a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int8_t))); |
112 | } |
113 | } |
114 | |
115 | /** |
116 | * Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel. |
117 | * |
118 | * This function implements a direct convolution kernel that is specialized |
119 | * for kernel size (2, 6) and input_height (IN_DIM[0]) = 2. |
120 | * |
121 | * More specifically the implementation has the following requirements: |
122 | * * Weights has layout {OC/8, KH, KW, IC/4, 8, 4} |
123 | * * kernel size (2, 6), IN_DIM[0] = 2, therefore: OUT_DIM[0] = 1 |
124 | * * Features are in channel last format |
125 | * |
126 | * mRegBlockSize = 12: the number of avx2 registers for output |
127 | * nRegBlockSize = 8: the # of output elements in one avx2 register |
128 | * row_interleave = 4: the horizontal reduction size for vpmaddubsw instruction |
129 | * O1: output_width: OUT_DIM[1] |
130 | * i1Xich: input_width multiply input_channel: IN_DIM[1] x IC |
131 | * strideXich: stride multiply input_channel: stride[1] x input_channel |
132 | * |
133 | * |
134 | * The kernel implements the following algorithm: |
135 | |
136 | for (int ow = 0; ow < OUT_DIM[1]; ow+=12) { |
137 | L1 blocking: following weights are in L1 cache |
138 | for (int s = 0; s < K[1]; ++s) { |
139 | for (int r = 0; r < K[0]; ++r) { |
140 | for (int icb = 0; icb < IC; icb+=row_interleave) { |
141 | |
142 | // register blocking: |
143 | // leverage vpmaddubsw instructions |
      for (int _icb = icb; _icb < icb + row_interleave; _icb ++ ) {
        for (int _oc = oc; _oc < oc + mRegBlockSize; _oc ++) {
          for (int _ow = ow; _ow < std::min(ow + 12, OUT_DIM[1]); _ow ++) {
            out[_oc + _ow * OC] +=
              input[_icb + (_ow + s * stride[1]) * IC + r * IC * IN_DIM[1]]
              *
              weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s)
                *8 + (_oc % 8)) * 4 + (_icb % 4)];

            // If we get rid of the brackets and substitute the corresponding
            // variables, the indexing becomes:
            //
            // input[_icb + _ow * IC + s * strideXich + r * i1Xich]
            // *
            // weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s)
            // *8 + (_oc % 8)) * 4 + (_icb % 4)];
160 | } |
161 | } |
162 | } |
163 | |
164 | } |
165 | } |
166 | } |
167 | * |
168 | */ |
template <>
template <inst_set_t instSet>
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreateDirectConv(
    bool accum,
    int32_t O1,
    int32_t i1Xich,
    int32_t strideXich) {
  using VecRegT = typename simd_info<instSet>::vec_reg_t;
  constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

  std::tuple<bool, int, int, int, int, int, int> kernelSig;
  // int ichSize = 32;
  int mRegBlockSize = 12; // # of vector registers used for the output tile
  int nRegBlockSize = 8; // # of output elements per vector register
  // int nRegBlockSizeMin;
  int row_interleave = 4; // horizontal reduction size of vpmaddubsw

  // Cache key for the generated kernel.
  // NOTE(review): i1Xich appears twice in this key (mirroring the
  // getCodeLoggingFile call below) — presumably intentional, but confirm a
  // parameter is not missing here.
  kernelSig = std::make_tuple(
      accum, O1, i1Xich, strideXich, i1Xich, mRegBlockSize, nRegBlockSize);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
    // generated code logging
    // NOTE(review): codeLogfile/codeLogger are leaked on the early
    // "Error: in fn add" return below — only the success path frees them.
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<instSet>(
            accum, O1, i1Xich, strideXich, i1Xich, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    const int maxMRegs = mRegBlockSize;
    (void)maxMRegs; // Suppress unused variable warning
    const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
    // 4 registers are reserved for A, B, the ones vector, and a temporary.
    assert(
        maxMRegs * maxNRegs <= numRegs - 4 &&
        "MRegs x NRegs is above available registers (MAX_REGS - 4)");

    // The output width O1 is tiled by mRegBlockSize; any remainder gets its
    // own, smaller kernel body after the main loop.
    int O1RegBlocks = O1 / mRegBlockSize;
    int O1RegBlocksRem = O1 % mRegBlockSize;

    // arguments to the function created
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp B_pf = a->zdx();
    x86::Gp CBase = a->zcx();
    x86::Gp ichXk1 = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            uint8_t*,
            int8_t*,
            int8_t*,
            int32_t*,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    // Mark all vector registers as clobbered (the upper 16 only exist when
    // the instruction set provides at least 32 of them).
    auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15);
    if (numRegs >= 16) {
      dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
          asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31);
    }

    frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs);
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, B_pf, CBase, ichXk1, ldcReg);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label LoopMBlocks = a->newLabel();
    // asmjit::Label LoopOBlocks = a->newLabel();
    // asmjit::Label LoopNBlocks = a->newLabel();

    x86::Gp buffer_B_saved = a->gpz(10);
    x86::Gp C_Offset = a->gpz(11);
    // x86::Gp B_pf_saved = a->gpz(12);
    x86::Gp iIdx = a->gpz(13);
    // x86::Gp jIdx = a->gpz(14);
    x86::Gp kIdx = a->gpz(15);
    // x86::Gp B_pf = a->gpz(8);

    VecRegT oneReg(numRegs - 3);

    gen16BitVectorOne<instSet, VecRegT>(a, oneReg);
    // Convert ldc from elements to bytes once, up front.
    a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
    // a->xor_(C_Offset.r32(), C_Offset.r32());

    // a->mov(B_pf_saved, B_pf);

    int colRegs = maxNRegs;

    // Emits the input-channel loop for one output-width block of `rowRegs`
    // elements, with the kernel-height loop ("r") unrolled into two
    // back-to-back K loops.
    auto issueLoopOverK = [&](int rowRegs) {
      // LoopKLabel: corresponds to loop "r" where r = 0
      // LoopK0Label: corresponds to loop "r" where r = 1
      asmjit::Label LoopKLabel = a->newLabel();
      asmjit::Label LoopK0Label = a->newLabel();

      // Init C (result) vector registers
      initCRegs(a, rowRegs, colRegs);

      // Loops over K: input channel.
      // I.e. this issueLoopOverK block generates the code corresponding to
      // the "icb" loop of the pseudo-code in the comment above.
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopKLabel);

      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      // This compute block corresponds to the body of the kernel-height
      // loop (loop "r") in the pseudo-code above. Because K[0] == 2 (see the
      // kernel requirements), loop "r" is fully unrolled; this call
      // generates the code for "r" = 0.
      genComputeBlockDirectConv<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, strideXich);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(buffer_B, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
      a->add(B_pf, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

      a->cmp(kIdx, ichXk1);
      a->jl(LoopKLabel);

      // Rewind A to the start of the channel block...
      a->sub(buffer_A, ichXk1);

      // ...then advance it one input row (i1Xich = IN_DIM[1] * IC) so the
      // second pass reads the "r" = 1 input row.
      a->add(buffer_A, static_cast<asmjit::Imm>(i1Xich));

      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopK0Label);

      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      // Second unrolled pass of the kernel-height loop: generates the code
      // for "r" = 1.
      genComputeBlockDirectConv<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, strideXich);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(buffer_B, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
      a->add(B_pf, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

      a->cmp(kIdx, ichXk1);
      a->jl(LoopK0Label);

      a->sub(buffer_A, ichXk1);

      // store C matrix
      storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
    };

    if (O1RegBlocks > 0) {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      // iIdx loop: per iteration, A advances by mRegBlockSize * strideXich
      // and C by mRegBlockSize rows, i.e. it tiles the output width.
      // NOTE(review): the original comment called this the kernel_width
      // ("s") loop — verify which loop this corresponds to in the caller.
      a->bind(LoopMBlocks);
      a->inc(iIdx);

      // save B_buffer address
      a->mov(buffer_B_saved, buffer_B);

      issueLoopOverK(mRegBlockSize);

      int rowRegs = mRegBlockSize;

      // reset A (undo the i1Xich advance performed inside issueLoopOverK)
      a->sub(buffer_A, static_cast<asmjit::Imm>(i1Xich));

      // increment A for next block
      a->add(
          buffer_A,
          static_cast<asmjit::Imm>(rowRegs * strideXich * sizeof(uint8_t)));

      // B for next block
      a->mov(buffer_B, buffer_B_saved);

      // increment C for next B block
      // ldcReg already multiplied with 4 (sizeof(int32_t))
      a->imul(
          C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int8_t)));
      a->add(CBase, C_Offset);

      // a->add(CBase, static_cast<asmjit::Imm>(12*16*4));
      // storeCRegs<instSet>(a, 12, 1, C_Offset, ldcReg, accum);

      a->cmp(iIdx, O1RegBlocks);
      a->jl(LoopMBlocks);
    }

    // generate code for remainder
    if (O1RegBlocksRem > 0) {
      issueLoopOverK(O1RegBlocksRem);
    }

    a->emitEpilog(frame);

    jit_micro_kernel_fp fn;
    asmjit::Error err;
    {
      // runtime().add is not thread-safe; serialize JIT registration.
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

#if defined(FBGEMM_LOG_CODE)
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}
422 | |
423 | /** |
424 | * Generate AVX256 instructions for storing the C registers back to the memory |
425 | * in 32-bit Accumulation kernel. |
426 | */ |
427 | template <> |
428 | template <inst_set_t instSet> |
429 | void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegsTrans( |
430 | x86::Emitter* a, |
431 | int rowRegs, |
432 | int colRegs, |
433 | x86::Gp C_offset, |
434 | x86::Gp o1XocReg, |
435 | x86::Gp ldcReg, |
436 | bool accum) { |
437 | using VecT = typename simd_info<instSet>::vec_reg_t; |
438 | // static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
439 | |
440 | a->xor_(C_offset.r32(), C_offset.r32()); |
441 | for (int i = 0; i < rowRegs; ++i) { |
442 | for (int j = 0; j < colRegs; ++j) { |
443 | if (accum) { |
444 | a->vpaddd( |
445 | VecT(i * colRegs + j), |
446 | VecT(i * colRegs + j), |
447 | x86::dword_ptr(a->zcx(), C_offset)); |
448 | } |
449 | a->vmovups(x86::dword_ptr(a->zcx(), C_offset), VecT(i * colRegs + j)); |
450 | a->add(C_offset, ldcReg); |
451 | } |
452 | a->add(C_offset, o1XocReg); |
453 | } |
454 | } |
455 | |
456 | /** |
457 | * Generate AVX256 instructions for computing block in the rank-k update of |
458 | * 32-bit Accumulation kernel. |
459 | |
460 | The function generates the register blocking code for transposed |
461 | direct convolution |
462 | // register blocking for transposed direct convolution: |
463 | // K[0] x K[1] = 12, the corresponding 12 x 8 output elements will |
464 | // be kept in the twelve avx2 registers, which is: |
465 | // out[ih + 0..r][iw + 0..s][8] |
466 | for (int r = 0; r < K[0]; ++r) { |
467 | for (int s = 0; s < K[1]; ++s) { |
468 | int oh = ih * conv_p.stride[0] + r; |
469 | int ow = iw * conv_p.stride[1] + s; |
470 | |
471 | int a = input[((ih)*IN_DIM[1] + iw) * IC + icb]; |
472 | int b = weight |
473 | [(((((oc / 8) * K[0] + r) * K[1] + s) * (IC / 4) + icb / 4) * |
474 | 8 + |
475 | (oc % 8)) * |
476 | 4 + |
477 | (icb % 4)]; |
478 | out[((oh)*OUT_DIM[1] + ow) * OC + oc] += a * b; |
479 | } |
480 | } |
481 | } |
482 | * |
483 | */ |
484 | template <> |
485 | template <inst_set_t instSet> |
486 | void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
487 | genComputeBlockDirectConvTrans( |
488 | x86::Emitter* a, |
489 | x86::Gp buffer_A, |
490 | x86::Gp buffer_B, |
491 | x86::Gp icReg, |
492 | x86::Gp C_offset, |
493 | int rowRegs, |
494 | int colRegs) { |
495 | // static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
496 | using VecRegT = typename simd_info<instSet>::vec_reg_t; |
497 | constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS; |
498 | |
499 | // used for matrix A |
500 | VecRegT AReg(numRegs - 1); |
501 | |
502 | // used for matrix B |
503 | VecRegT BReg(numRegs - 2); |
504 | |
505 | // Contains 16-bit 1s |
506 | VecRegT oneReg(numRegs - 3); |
507 | |
508 | // temporary register |
509 | VecRegT res1(numRegs - 4); |
510 | |
511 | // load A |
512 | a->vpbroadcastd(AReg, x86::dword_ptr(buffer_A)); |
513 | |
514 | a->xor_(C_offset.r32(), C_offset.r32()); |
515 | for (int i = 0; i < rowRegs; ++i) { |
516 | for (int j = 0; j < colRegs; ++j) { |
517 | // load B, broadcast and fmas |
518 | emitLoadDWord<instSet, VecRegT>( |
519 | a, BReg, x86::dword_ptr(buffer_B, C_offset, 3, 0)); |
520 | a->vpmaddubsw(res1, AReg, BReg); |
521 | a->vpmaddwd(res1, oneReg, res1); |
522 | a->vpaddd(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j)); |
523 | a->add(C_offset, icReg); |
524 | } |
525 | // a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int8_t))); |
526 | } |
527 | } |
528 | |
529 | /** |
530 | * Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel. |
531 | * |
532 | * This function implements a direct convolution kernel that is specialized |
533 | * for kernel size (2, 6) |
534 | * |
535 | * More specifically the implementation has the following prerequisites: |
536 | * * Weights has layout {OC/8, KH, KW, IC/4, 8, 4} |
537 | * * kernel size (2, 6) |
538 | * * Features are in channel last format |
539 | * * Stride[0] = 1, Stride[1] = 1 or 2 |
540 | * * Padding = 0 |
541 | * |
542 | * mRegBlockSize = 12: the number of avx2 registers for output |
543 | * nRegBlockSize = 8: the # of output elements in one avx2 register |
544 | * row_interleave = 4: the horizontal reduction size for vpmaddubsw instruction |
 * stride: stride[1], 1 or 2. We assume stride[0] = 1
546 | * ic: input channel |
547 | * i1: input_width: IN_DIM[1] |
548 | * ldcReg: leading dimension of output, a.k.a OC |
549 | * o1Xoc: output width multiply output channel: OUT_DIM[1] x OC |
550 | * |
551 | * The kernel implements the following algorithm: |
552 | |
553 | for (int oc = 0; oc < OC; oc++) { |
554 | for (int ih = 0; ih < IN_DIM[0]; ++ih) { |
555 | for (int iw = 0; iw < IN_DIM[1]; iw++) { |
556 | // L1 blocking |
557 | for (int icb = 0; icb < IC; icb+=4) { |
558 | // register blocking: |
559 | // K[0] x K[1] = 12, the corresponding 12 x 8 output elements will |
560 | // be kept in the twelve avx2 registers, which is: |
561 | // out[ih + 0..r][iw + 0..s][8] |
562 | for (int r = 0; r < K[0]; ++r) { |
563 | for (int s = 0; s < K[1]; ++s) { |
564 | for (int _icb = icb ; _icb < icb + 4; _icb ++) { |
565 | int oh = ih * conv_p.stride[0] + r; |
566 | int ow = iw * conv_p.stride[1] + s; |
567 | |
568 | int a = input[((ih)*IN_DIM[1] + iw) * IC + icb]; |
569 | int b = weight |
570 | [(((((oc / 8) * K[0] + r) * K[1] + s) * (IC / 4) + icb / 4) * |
571 | 8 + |
572 | (oc % 8)) * |
573 | 4 + |
574 | (icb % 4)]; |
575 | out[((oh)*OUT_DIM[1] + ow) * OC + oc] += a * b; |
576 | |
            // if we get rid of the brackets and substitute the corresponding
            // variables:
579 | // out[ih * stride0 * o1Xoc + r * o1Xoc + iw * ldcReg + oc] |
580 | // input[ih * i1 * ic + iw * ic + icb] |
581 | } |
582 | } |
583 | } |
584 | } |
      } // for each iw
    } // for each ih
  } // for each oc
588 | |
589 | * |
590 | */ |
591 | |
592 | /** |
593 | * Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel. |
594 | * |
595 | */ |
template <>
template <inst_set_t instSet>
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
    jit_micro_kernel_fp_convT
    DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
        getOrCreateDirectConvTrans(
            bool accum,
            int32_t stride,
            int32_t numColRegs) {
  using VecRegT = typename simd_info<instSet>::vec_reg_t;
  constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

  std::tuple<bool, int, int, int> kernelSig;
  // int ichSize = 32;
  int mRowRegBlockSize = 2; // rows of the register tile (matches K[0] == 2)
  int mColRegBlockSize = numColRegs; // columns of the register tile
  int mRegBlockSize = mRowRegBlockSize * mColRegBlockSize;
  int nRegBlockSize = 8; // # of output elements per vector register
  // int nRegBlockSizeMin;
  int row_interleave = 4; // horizontal reduction size of vpmaddubsw

  kernelSig = std::make_tuple(accum, stride, mRegBlockSize, nRegBlockSize);

  return codeCacheT_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp_convT {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
    // generated code logging
    // NOTE(review): codeLogfile/codeLogger are leaked on the early
    // "Error: in fn add" return below — only the success path frees them.
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<instSet>(accum, stride, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    const int maxMRegs = mRegBlockSize;
    (void)maxMRegs; // Suppress unused variable warning
    const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
    // 4 registers are reserved for A, B, the ones vector, and a temporary.
    assert(
        maxMRegs * maxNRegs <= numRegs - 4 &&
        "MRegs x NRegs is above available registers (MAX_REGS - 4)");

    // arguments to the function created
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp CBase = a->zcx();
    x86::Gp ic = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);
    x86::Gp o1Xoc = a->gpz(10);
    x86::Gp i1 = a->gpz(11);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            uint8_t*,
            int8_t*,
            int32_t*,
            int,
            int,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    // Mark all vector registers as clobbered (the upper 16 only exist when
    // the instruction set provides at least 32 of them).
    auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15);
    if (numRegs >= 16) {
      dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
          asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31);
    }

    frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs);
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, CBase, ic, ldcReg, o1Xoc, i1);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label LoopMBlocks = a->newLabel();

    x86::Gp C_offset = a->gpz(12);
    x86::Gp buffer_B_saved = a->gpz(13);
    x86::Gp iIdx = a->gpz(14);
    x86::Gp kIdx = a->gpz(15);

    VecRegT oneReg(numRegs - 3);

    gen16BitVectorOne<instSet, VecRegT>(a, oneReg);
    // Convert ldc from elements to bytes once, up front.
    a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));

    int colRegs = maxNRegs;

    // Emits the input-channel loop for one input pixel; the whole
    // mRowRegBlockSize x mColRegBlockSize output tile it contributes to
    // stays in registers across the loop.
    auto issueLoopOverK = [&](int rowRegs) {
      asmjit::Label LoopKLabel = a->newLabel();

      // Init C (result) vector registers
      initCRegs(a, rowRegs, colRegs);

      // Loops over K: input channel
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopKLabel);

      // k is incremented by row_interleave (== 4)
      a->add(kIdx, 4);
      genComputeBlockDirectConvTrans<instSet>(
          a,
          buffer_A,
          buffer_B,
          ic,
          C_offset,
          mRowRegBlockSize,
          mColRegBlockSize);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(buffer_B, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));

      a->cmp(kIdx, ic);
      a->jl(LoopKLabel);

      // store C matrix
      storeCRegsTrans<instSet>(
          a,
          mRowRegBlockSize,
          mColRegBlockSize,
          C_offset,
          o1Xoc,
          ldcReg,
          accum);
    };

    {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      // iIdx runs i1 (input width) iterations — one tile per input pixel.
      a->bind(LoopMBlocks);
      a->inc(iIdx);

      // save B_buffer address
      a->mov(buffer_B_saved, buffer_B);

      issueLoopOverK(mRegBlockSize);

      // B for next block
      a->mov(buffer_B, buffer_B_saved);
      // increment C for next B block: each input pixel advances the output
      // by stride (i.e. stride[1]) elements along the output width
      a->imul(
          C_offset,
          ldcReg,
          static_cast<asmjit::Imm>(stride)); // ldcReg already multiplied by 4
      a->add(CBase, C_offset);

      a->cmp(iIdx, i1);
      a->jl(LoopMBlocks);
    }

    a->emitEpilog(frame);

    jit_micro_kernel_fp_convT fn;
    asmjit::Error err;
    {
      // runtime().add is not thread-safe; serialize JIT registration.
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

#if defined(FBGEMM_LOG_CODE)
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}
792 | |
793 | /** |
794 | * Instantiate the inst_set_t::avx512 instructions for store kernel. |
795 | * |
796 | */ |
797 | template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
798 | storeCRegs<inst_set_t::avx512>( |
799 | x86::Emitter* a, |
800 | int rowRegs, |
801 | int colRegs, |
802 | x86::Gp C_Offset, |
803 | x86::Gp ldcReg, |
804 | bool accum); |
805 | |
806 | /** |
807 | * Instantiate the inst_set_t::avx512_ymm instructions for store kernel. |
808 | * |
809 | */ |
810 | template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
811 | storeCRegs<inst_set_t::avx512_ymm>( |
812 | x86::Emitter* a, |
813 | int rowRegs, |
814 | int colRegs, |
815 | x86::Gp C_Offset, |
816 | x86::Gp ldcReg, |
817 | bool accum); |
818 | |
819 | /** |
820 | * Instantiate the inst_set_t::avx2 instructions for store kernel. |
821 | * |
822 | */ |
823 | template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
824 | storeCRegs<inst_set_t::avx2>( |
825 | x86::Emitter* a, |
826 | int rowRegs, |
827 | int colRegs, |
828 | x86::Gp C_Offset, |
829 | x86::Gp ldcReg, |
830 | bool accum); |
831 | |
832 | /** |
833 | * Instantiate the inst_set_t::avx2 instructions for store kernel. |
834 | * |
835 | */ |
836 | template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
837 | storeCRegsTrans<inst_set_t::avx512>( |
838 | x86::Emitter* a, |
839 | int rowRegs, |
840 | int colRegs, |
841 | x86::Gp C_offset, |
842 | x86::Gp o1XocReg, |
843 | x86::Gp ldcReg, |
844 | bool accum); |
845 | |
846 | /** |
847 | * Instantiate the inst_set_t::avx2 instructions for store kernel. |
848 | * |
849 | */ |
850 | template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
851 | storeCRegsTrans<inst_set_t::avx512_ymm>( |
852 | x86::Emitter* a, |
853 | int rowRegs, |
854 | int colRegs, |
855 | x86::Gp C_offset, |
856 | x86::Gp o1XocReg, |
857 | x86::Gp ldcReg, |
858 | bool accum); |
859 | |
860 | /** |
861 | * Instantiate the inst_set_t::avx2 instructions for store kernel. |
862 | * |
863 | */ |
864 | template void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
865 | storeCRegsTrans<inst_set_t::avx2>( |
866 | x86::Emitter* a, |
867 | int rowRegs, |
868 | int colRegs, |
869 | x86::Gp C_offset, |
870 | x86::Gp o1XocReg, |
871 | x86::Gp ldcReg, |
872 | bool accum); |
873 | |
874 | /** |
875 | * Instantiate the AVX2 instructions for 32-bit Accumulation macro-kernel. |
876 | * |
877 | */ |
878 | template DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
879 | jit_micro_kernel_fp |
880 | DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
881 | getOrCreateDirectConv<inst_set_t::avx2>( |
882 | bool accum, |
883 | int32_t O1, |
884 | int32_t i1Xich, |
885 | int32_t strideXich); |
886 | |
887 | /** |
888 | * Instantiate the AVX2 instructions for 32-bit Accumulation macro-kernel. |
889 | * |
890 | */ |
891 | template DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
892 | jit_micro_kernel_fp_convT |
893 | DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>:: |
894 | getOrCreateDirectConvTrans<inst_set_t::avx2>( |
895 | bool accum, |
896 | int32_t stride, |
897 | int32_t numColRegs); |
898 | |
899 | } // namespace fbgemm |
900 | |