1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#include <iostream>
8#include "./CodeGenHelpers.h"
9#include "./GenerateKernel.h"
10
11namespace fbgemm {
12
13namespace x86 = asmjit::x86;
14
15/**
16 * Generate AVX2 instructions for computing block in the rank-k update of 16-bit
17 * Accmulation kernel.
18 */
19template <>
20template <>
21void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
22 inst_set_t::avx2>(
23 x86::Emitter* a,
24 x86::Gp buffer_A,
25 x86::Gp buffer_B,
26 x86::Gp /* unused (reserved for prefetching)*/,
27 int rowRegs,
28 int colRegs,
29 int lda) {
30 using CRegs = x86::Ymm;
31 static constexpr int vectorLen = simd_info<inst_set_t::avx2>::WIDTH_BYTES;
32
33 // used for matrix A
34 x86::Ymm AReg = x86::ymm13;
35 x86::Ymm tmpReg = x86::ymm14;
36 for (int i = 0; i < rowRegs; ++i) {
37 // broadcast A
38 a->vpbroadcastw(
39 AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
40 for (int j = 0; j < colRegs; ++j) {
41 a->vpmaddubsw(
42 tmpReg,
43 AReg,
44 x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t)));
45 a->vpaddsw(CRegs(i * colRegs + j), tmpReg, CRegs(i * colRegs + j));
46 // Prefetching is hurting performance in some cases
47 // because prefetch instructions itself consumes a slot
48 // in pipeline issue thus slowing down the kernel.
49 // if((i == rowRegs - 1) && j % 2 == 0){
50 // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t)));
51 //}
52 }
53 }
54}
55
56/**
57 * Generate instructions for storing the C registers back to the memory
58 * in 16-bit Accumulation kernel.
59 */
60template <>
61template <inst_set_t instSet>
62void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs(
63 x86::Emitter* a,
64 int rowRegs,
65 int colRegs,
66 x86::Gp C_Offset,
67 x86::Gp ldcReg,
68 bool accum) {
69 using VecT = typename simd_info<instSet>::vec_reg_t;
70 static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
71
72 VecT extractDestFull(simd_info<instSet>::NUM_VEC_REGS - 1);
73 auto extractDestHalf = extractDestFull.half();
74
75 for (int i = 0; i < rowRegs; ++i) {
76 a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
77 for (int j = 0; j < colRegs; ++j) {
78 for (int idx = 0; idx < 2; ++idx) {
79 emitExtractHalfVector<instSet, VecT>(
80 a, extractDestHalf, VecT(i * colRegs + j), idx);
81 a->vpmovsxwd(extractDestFull, extractDestHalf);
82 x86::Mem destAddr =
83 x86::dword_ptr(a->zcx(), C_Offset, 0, (j * 2 + idx) * vectorLen);
84 if (accum) {
85 a->vpaddd(extractDestFull, extractDestFull, destAddr);
86 }
87 a->vmovups(destAddr, extractDestFull);
88 }
89 }
90 }
91}
92
/**
 * Get or Create the AVX2 instructions for the 16-bit Accumulation
 * macro-kernel.
 *
 * Looks up the JIT code cache keyed by (accum, mc, nc, nBlock, kBlock,
 * MR, NR); on a miss, assembles the kernel with asmjit and caches it.
 * The generated function has the signature declared in FuncSignatureT
 * below: void(uint8_t* A, int8_t* B, int8_t* B_pf, int32_t* C,
 * int kSize, int ldc).
 */
template <>
template <>
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
    bool accum,
    int32_t mc,
    int32_t nc,
    int32_t kc) {
  (void)kc; // Suppress unused variable warning (kc is only read in asserts)
  constexpr int vectorLen = simd_info<inst_set_t::avx2>::WIDTH_BYTES;

  std::tuple<bool, int, int, int, int, int, int> kernelSig;
  int kBlock;
  int nBlock;
  int mRegBlockSize;
  int nRegBlockSize;
  int nRegBlockSizeMin;
  int row_interleave;

  // Blocking parameters come either from the user-supplied override or from
  // the compile-time defaults in PackingTraits for this instruction set.
  if (blocking_params) {
    kBlock = blocking_params->KCB;
    nBlock = blocking_params->NCB;
    mRegBlockSize = blocking_params->MR;
    nRegBlockSize = blocking_params->NR;
    nRegBlockSizeMin = blocking_params->NR_MIN;
    row_interleave = blocking_params->ROW_INTERLEAVE;
  } else {
    kBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::KCB;
    nBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NCB;
    mRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::MR;
    nRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NR;
    nRegBlockSizeMin =
        PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NR_MIN;
    row_interleave =
        PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE;
  }
  (void)nRegBlockSizeMin; // Suppress unused variable warning (assert-only)

  // Cache key: every parameter that affects the generated code.
  kernelSig = std::make_tuple(
      accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();

#if defined(FBGEMM_LOG_CODE)
    // generated code logging
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<inst_set_t::avx2>(
            accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    assert(
        kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
    assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
    // Register budget check: the accumulator tile (MR x NR vectors) must fit
    // in 13 ymm registers, leaving ymm13/ymm14 for genComputeBlock's
    // broadcast and scratch registers (and ymm15 free).
    const int maxMRegs = mRegBlockSize;
    const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
    (void)maxMRegs; // Suppress unused variable warning (assert-only)
    (void)maxNRegs; // Suppress unused variable warning (assert-only)
    assert(
        maxMRegs * maxNRegs <= 13 &&
        "MR*(NR*ROW_INTERLEAVE*8/256"
        "must be <= 13(available registers constraint)");

    // Number of full MR-row blocks, plus a possible partial remainder block.
    int mRegBlocks = mc / mRegBlockSize;
    int mRegBlocksRem = mc % mRegBlockSize;

    // assert((nc == nRegBlockSize) &&
    //"nc must be equal to the number of register blocks");

    // arguments to the function created (System V integer argument order).
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp B_pf = a->zdx();
    x86::Gp CBase = a->zcx();
    x86::Gp kSize = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            uint8_t*,
            int8_t*,
            int8_t*,
            int32_t*,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    asmjit::FuncFrame frame;
    frame.init(func);
    // All vector registers and GP r8-r14 are clobbered by the kernel.
    frame.setDirtyRegs(
        asmjit::RegGroup::kVec,
        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label Loopk = a->newLabel();
    asmjit::Label LoopMBlocks = a->newLabel();

    // Scratch registers used by the generated loops.
    x86::Gp buffer_B_saved = a->gpz(10);
    x86::Gp C_Offset = a->gpz(11);
    // x86::Gp B_pf_saved = a->gpz(12);
    x86::Gp iIdx = a->gpz(13);
    x86::Gp kIdx = a->gpz(14);

    // Number of vector registers needed per row of the C tile.
    int colRegs = nc * row_interleave / vectorLen;
    if (mRegBlocks > 0) {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      // save B_buffer address so it can be reset after each M block
      a->mov(buffer_B_saved, buffer_B);
      // a->mov(B_pf_saved, B_pf);

      // Outer loop over full MR-row blocks of the output tile.
      a->bind(LoopMBlocks);
      a->inc(iIdx);

      int rowRegs = mRegBlockSize;

      // init C registers (zero the accumulator tile)
      initCRegs(a, rowRegs, colRegs);

      // init k loop index
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(Loopk);
      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      genComputeBlock<inst_set_t::avx2>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(
          buffer_B,
          static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
      // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
      // sizeof(int8_t)));

      a->cmp(kIdx, kSize);
      a->jl(Loopk);

      // store C matrix
      storeCRegs<inst_set_t::avx2>(
          a, rowRegs, colRegs, C_Offset, ldcReg, accum);

      // increment A for next block: rewind by kSize, then advance one full
      // MR x kBlock panel.
      a->sub(buffer_A, kSize);
      a->add(
          buffer_A,
          static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
      // increment C for next block
      a->imul(
          C_Offset,
          ldcReg,
          static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
      a->add(CBase, C_Offset);
      // reset B
      a->mov(buffer_B, buffer_B_saved);
      // a->mov(B_pf, B_pf_saved);

      a->cmp(iIdx, mRegBlocks);
      a->jl(LoopMBlocks);
    }
    // generate code for remainder (partial M block, same structure as above
    // but with rowRegs = mc % MR and no outer loop)
    if (mRegBlocksRem > 0) {
      asmjit::Label LoopkRem = a->newLabel();
      int rowRegs = mRegBlocksRem;

      // init C registers
      initCRegs(a, rowRegs, colRegs);

      // init k loop index
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopkRem);

      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      genComputeBlock<inst_set_t::avx2>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(
          buffer_B,
          static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
      // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
      // sizeof(int8_t)));

      a->cmp(kIdx, kSize);
      a->jl(LoopkRem);

      // store C matrix
      storeCRegs<inst_set_t::avx2>(
          a, rowRegs, colRegs, C_Offset, ldcReg, accum);
    }

    a->emitEpilog(frame);

    // Finalize: hand the assembled code to the JIT runtime to get a callable
    // function pointer. runtime().add must be serialized across threads.
    jit_micro_kernel_fp fn;
    asmjit::Error err;
    {
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

#if defined(FBGEMM_LOG_CODE)
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}
343
/**
 * Explicit instantiation of the inst_set_t::avx2 store kernel
 * (storeCRegs is a member template; each supported instruction set needs an
 * explicit instantiation in this translation unit).
 */
template void
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<inst_set_t::avx2>(
    x86::Emitter* a,
    int rowRegs,
    int colRegs,
    x86::Gp C_Offset,
    x86::Gp ldcReg,
    bool accum);

/**
 * Explicit instantiation of the inst_set_t::avx512 store kernel.
 */
template void
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<inst_set_t::avx512>(
    x86::Emitter* a,
    int rowRegs,
    int colRegs,
    x86::Gp C_Offset,
    x86::Gp ldcReg,
    bool accum);

/**
 * Explicit instantiation of the inst_set_t::avx512_ymm store kernel
 * (AVX-512 instruction set restricted to 256-bit ymm registers).
 */
template void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
    inst_set_t::avx512_ymm>(
    x86::Emitter* a,
    int rowRegs,
    int colRegs,
    x86::Gp C_Offset,
    x86::Gp ldcReg,
    bool accum);
382
383} // namespace fbgemm
384