1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#include <iostream>
8#include "./CodeGenHelpers.h"
9#include "./GenerateKernel.h"
10
11namespace fbgemm {
12
13namespace x86 = asmjit::x86;
14
15/**
16 * Generate AVX512 instructions for computing block in the rank-k update of
17 * 32-bit Accumulation kernel.
18 */
19template <>
20template <inst_set_t instSet>
21void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock(
22 x86::Emitter* a,
23 x86::Gp buffer_A,
24 x86::Gp buffer_B,
25 x86::Gp B_pf,
26 int rowRegs,
27 int colRegs,
28 int lda) {
29 static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
30 using VecRegT = typename simd_info<instSet>::vec_reg_t;
31 constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
32
33 // used for matrix A
34 VecRegT AReg(numRegs - 1);
35
36 // used for matrix B
37 VecRegT BReg(numRegs - 2);
38
39 // Contains 16-bit 1s
40 VecRegT oneReg(numRegs - 3);
41
42 // temporary register
43 VecRegT res1(numRegs - 4);
44
45 for (int j = 0; j < colRegs; ++j) {
46 // load B
47 emitLoadDWord<instSet, VecRegT>(
48 a, BReg, x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t)));
49 // load A, broadcast and fmas
50 for (int i = 0; i < rowRegs; ++i) {
51 a->vpbroadcastd(
52 AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
53 a->vpmaddubsw(res1, AReg, BReg);
54 a->vpmaddwd(res1, oneReg, res1);
55 a->vpaddd(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j));
56 }
57 a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int8_t)));
58 }
59}
60
61/**
62 * Generate AVX512 instructions for storing the C registers back to the memory
63 * in 32-bit Accumulation kernel.
64 */
65template <>
66template <inst_set_t instSet>
67void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs(
68 x86::Emitter* a,
69 int rowRegs,
70 int colRegs,
71 x86::Gp C_Offset,
72 x86::Gp ldcReg,
73 bool accum) {
74 using VecT = typename simd_info<instSet>::vec_reg_t;
75 static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
76
77 for (int i = 0; i < rowRegs; ++i) {
78 if (i != 0) {
79 a->add(C_Offset, ldcReg);
80 } else {
81 a->xor_(C_Offset.r32(), C_Offset.r32());
82 }
83 for (int j = 0; j < colRegs; ++j) {
84 if (accum) {
85 a->vpaddd(
86 VecT(i * colRegs + j),
87 VecT(i * colRegs + j),
88 x86::dword_ptr(
89 a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)));
90 }
91 a->vmovups(
92 x86::dword_ptr(a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)),
93 VecT(i * colRegs + j));
94 }
95 }
96}
97
/**
 * Get (from the code cache) or create the JIT-ed 32-bit accumulation
 * macro-kernel for this instruction set and blocking configuration.
 *
 * The generated function has the signature
 *   void(uint8_t* A, int8_t* B, int8_t* B_pf, int32_t* C, int kSize, int ldc)
 * and computes an mc x nc tile of C from packed A and B buffers.
 *
 * @param accum if true the kernel accumulates into C; otherwise it overwrites
 * @param mc number of rows of the output tile
 * @param nc number of columns of the output tile
 * @param kc depth of the reduction block (only used in the assert below)
 * @return pointer to the generated kernel, or nullptr if code emission failed
 */
template <>
template <inst_set_t instSet>
CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate(
    bool accum,
    int32_t mc,
    int32_t nc,
    int32_t kc) {
  (void)kc; // Suppress unused variable warning
  using VecRegT = typename simd_info<instSet>::vec_reg_t;
  constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

  // Blocking parameters: taken from user-supplied blocking_params when set,
  // otherwise from the compile-time defaults in PackingTraits.
  std::tuple<bool, int, int, int, int, int, int> kernelSig;
  int kBlock;
  int nBlock;
  int mRegBlockSize;
  int nRegBlockSize;
  int nRegBlockSizeMin;
  int row_interleave;

  if (blocking_params) {
    kBlock = blocking_params->KCB;
    nBlock = blocking_params->NCB;
    mRegBlockSize = blocking_params->MR;
    nRegBlockSize = blocking_params->NR;
    nRegBlockSizeMin = blocking_params->NR_MIN;
    row_interleave = blocking_params->ROW_INTERLEAVE;
  } else {
    kBlock = PackingTraits<uint8_t, int32_t, instSet>::KCB;
    nBlock = PackingTraits<uint8_t, int32_t, instSet>::NCB;
    mRegBlockSize = PackingTraits<uint8_t, int32_t, instSet>::MR;
    nRegBlockSize = PackingTraits<uint8_t, int32_t, instSet>::NR;
    nRegBlockSizeMin = PackingTraits<uint8_t, int32_t, instSet>::NR_MIN;
    row_interleave = PackingTraits<uint8_t, int32_t, instSet>::ROW_INTERLEAVE;
  }
  (void)nRegBlockSizeMin; // Suppress unused variable warning

  // Cache key: kernels are shared across calls with identical parameters.
  kernelSig = std::make_tuple(
      accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
    // generated code logging
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<instSet>(
            accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    assert(
        kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
    assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
    const int maxMRegs = mRegBlockSize;
    (void)maxMRegs; // Suppress unused variable warning
    const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
    // Four vector registers are reserved as scratch by genComputeBlock;
    // the remainder must be able to hold the whole C accumulator tile.
    assert(
        maxMRegs * maxNRegs <= numRegs - 4 &&
        "MRegs x NRegs is above available registers (MAX_REGS - 4)");

    int mRegBlocks = mc / mRegBlockSize;
    int mRegBlocksRem = mc % mRegBlockSize;

    // arguments to the function created
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp B_pf = a->zdx();
    x86::Gp CBase = a->zcx();
    x86::Gp kSize = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            uint8_t*,
            int8_t*,
            int8_t*,
            int32_t*,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    // Mark all vector registers as clobbered so the prologue saves/restores
    // any callee-saved ones required by the host calling convention.
    // NOTE(review): numRegs >= 16 also holds for AVX2 (16 vec regs), so the
    // upper bank (16-31) is marked dirty there too; presumably the guard was
    // meant for AVX512 only - confirm this is harmless with asmjit.
    auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15);
    if (numRegs >= 16) {
      dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
          asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31);
    }

    frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs);
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label LoopMBlocks = a->newLabel();
    asmjit::Label LoopNBlocks = a->newLabel();

    // scratch general-purpose registers (declared dirty above)
    x86::Gp buffer_B_saved = a->gpz(10);
    x86::Gp C_Offset = a->gpz(11);
    x86::Gp B_pf_saved = a->gpz(12);
    x86::Gp iIdx = a->gpz(13);
    x86::Gp jIdx = a->gpz(14);
    x86::Gp kIdx = a->gpz(15);
    // x86::Gp B_pf = a->gpz(8);

    // vector of 16-bit 1s consumed by genComputeBlock's vpmaddwd widening
    VecRegT oneReg(numRegs - 3);

    gen16BitVectorOne<instSet, VecRegT>(a, oneReg);
    // scale ldc from elements to bytes
    a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
    // a->xor_(C_Offset.r32(), C_Offset.r32());

    // save B_buffer address (restored at the start of every N/M block)
    a->mov(buffer_B_saved, buffer_B);
    a->mov(B_pf_saved, B_pf);

    int currColRegs = nc * row_interleave / vectorLen;
    int colRegs = std::min(currColRegs, maxNRegs);

    // Emit the inner loop over K for a block of `rowRegs` rows: zero the C
    // registers, accumulate over kSize, then store C back to memory.
    auto issueLoopOverK = [&](int rowRegs) {
      asmjit::Label LoopKLabel = a->newLabel();

      // Init C (result) vector registers
      initCRegs(a, rowRegs, colRegs);

      // Loops over K
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopKLabel);

      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      genComputeBlock<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(
          buffer_B,
          static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
      a->add(
          B_pf,
          static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));

      a->cmp(kIdx, kSize);
      a->jl(LoopKLabel);

      // store C matrix
      storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
    };

    // Main loops: full MR-row blocks, each iterating over the N blocks.
    if (mRegBlocks > 0) {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      a->bind(LoopMBlocks);
      a->inc(iIdx);

      a->xor_(jIdx.r32(), jIdx.r32());

      a->bind(LoopNBlocks);
      a->inc(jIdx);

      issueLoopOverK(mRegBlockSize);

      int rowRegs = mRegBlockSize;

      // reset A
      a->sub(buffer_A, kSize);

      // B for next block
      a->mov(buffer_B, buffer_B_saved);
      // using C_Offset as temp reg
      a->imul(
          C_Offset,
          jIdx,
          static_cast<asmjit::Imm>(
              nRegBlockSize * row_interleave * sizeof(int8_t)));
      a->add(buffer_B, C_Offset);
      a->mov(B_pf, B_pf_saved);
      a->add(B_pf, C_Offset);

      // increment C for next B block
      a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));

      int jLoopTrips = currColRegs / maxNRegs;
      // jLoopTrips should be at least 1
      jLoopTrips = jLoopTrips ? jLoopTrips : 1;
      a->cmp(jIdx, jLoopTrips);
      a->jl(LoopNBlocks);

      // increment A for next block
      a->add(
          buffer_A,
          static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));

      // increment C for next A block
      a->sub(
          CBase,
          static_cast<asmjit::Imm>(
              jLoopTrips * nRegBlockSize * sizeof(int32_t)));
      a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
      a->add(CBase, C_Offset);

      // reset B
      a->mov(buffer_B, buffer_B_saved);
      a->mov(B_pf, B_pf_saved);
      a->cmp(iIdx, mRegBlocks);
      a->jl(LoopMBlocks);
    }
    // generate code for remainder (mc not a multiple of mRegBlockSize)
    if (mRegBlocksRem > 0) {
      asmjit::Label LoopNRem = a->newLabel();

      a->xor_(jIdx.r32(), jIdx.r32());
      a->bind(LoopNRem);
      a->inc(jIdx);

      issueLoopOverK(mRegBlocksRem);

      // increment C for next B block
      a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));

      int jLoopTrips = currColRegs / maxNRegs;
      // jLoopTrips should be at least 1
      jLoopTrips = jLoopTrips ? jLoopTrips : 1;
      a->cmp(jIdx, jLoopTrips);
      a->jl(LoopNRem);
    }

    a->emitEpilog(frame);

    // Commit the generated code to the shared JIT runtime; runtime().add is
    // not thread-safe, hence the rtMutex_ guard.
    jit_micro_kernel_fp fn;
    asmjit::Error err;
    {
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

#if defined(FBGEMM_LOG_CODE)
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}
377
/**
 * Instantiate the inst_set_t::avx512 32-bit accumulation macro-kernel.
 */
template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
    bool accum,
    int32_t mc,
    int32_t nc,
    int32_t kc);

/**
 * Instantiate the inst_set_t::avx512_ymm (AVX512 with 256-bit vectors)
 * 32-bit accumulation macro-kernel.
 */
template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
    inst_set_t::avx512_ymm>(bool accum, int32_t mc, int32_t nc, int32_t kc);

/**
 * Instantiate the inst_set_t::avx2 32-bit accumulation macro-kernel.
 */
template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
    bool accum,
    int32_t mc,
    int32_t nc,
    int32_t kc);
407
/**
 * Instantiate the inst_set_t::avx512 store kernel.
 */
template void
CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<inst_set_t::avx512>(
    x86::Emitter* a,
    int rowRegs,
    int colRegs,
    x86::Gp C_Offset,
    x86::Gp ldcReg,
    bool accum);

/**
 * Instantiate the inst_set_t::avx512_ymm store kernel.
 */
template void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
    inst_set_t::avx512_ymm>(
    x86::Emitter* a,
    int rowRegs,
    int colRegs,
    x86::Gp C_Offset,
    x86::Gp ldcReg,
    bool accum);

/**
 * Instantiate the inst_set_t::avx2 store kernel.
 */
template void
CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<inst_set_t::avx2>(
    x86::Emitter* a,
    int rowRegs,
    int colRegs,
    x86::Gp C_Offset,
    x86::Gp ldcReg,
    bool accum);
446
447} // namespace fbgemm
448