1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#include <iostream>
8#include "./GenerateKernel.h"
9
10namespace fbgemm {
11
12namespace x86 = asmjit::x86;
13
14/**
15 * Generate AVX512 instructions for computing block in the rank-k update of
16 * 32-bit Accmulation kernel.
17 */
18template <>
19template <inst_set_t instSet>
20void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock(
21 x86::Emitter* a,
22 x86::Gp buffer_A,
23 x86::Gp buffer_B,
24 x86::Gp /*B_pf*/,
25 int rowRegs,
26 int colRegs,
27 int lda) {
28 assert(colRegs * (rowRegs + 1) <= 31);
29 using VecRegT = typename simd_info<instSet>::vec_reg_t;
30 static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
31
32 // used for matrix A
33 VecRegT AReg(31);
34 for (int j = 0; j < colRegs; ++j) {
35 a->vmovdqa32(
36 VecRegT(30 - j),
37 x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t)));
38 }
39
40 for (int i = 0; i < rowRegs; i++) {
41 a->vpbroadcastd(
42 AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
43 for (int j = 0; j < colRegs; ++j) {
44 a->vpdpbusd(VecRegT(i * colRegs + j), AReg, VecRegT(30 - j));
45 }
46 }
47}
48
49/**
50 * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
51 *
52 */
53template <>
54template <inst_set_t instSet>
55CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
56CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate(
57 bool accum,
58 int32_t mc,
59 int32_t nc,
60 int32_t kc) {
61 (void)kc; // Suppress unused variable warning
62 static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
63 static constexpr inst_set_t storeInstType =
64 simd_info<instSet>::WIDTH_BITS == 512 ? inst_set_t::avx512
65 : inst_set_t::avx512_ymm;
66
67 std::tuple<bool, int, int, int, int, int, int> kernelSig;
68 int kBlock;
69 int nBlock;
70 int mRegBlockSize;
71 int nRegBlockSize;
72 int nRegBlockSizeMin;
73 int row_interleave;
74
75 if (blocking_params) {
76 kBlock = blocking_params->KCB;
77 nBlock = blocking_params->NCB;
78 mRegBlockSize = blocking_params->MR;
79 nRegBlockSize = blocking_params->NR;
80 nRegBlockSizeMin = blocking_params->NR_MIN;
81 row_interleave = blocking_params->ROW_INTERLEAVE;
82 } else {
83 kBlock = PackingTraits<uint8_t, int32_t, instSet>::KCB;
84 nBlock = PackingTraits<uint8_t, int32_t, instSet>::NCB;
85 mRegBlockSize = PackingTraits<uint8_t, int32_t, instSet>::MR;
86 nRegBlockSize = PackingTraits<uint8_t, int32_t, instSet>::NR;
87 nRegBlockSizeMin = PackingTraits<uint8_t, int32_t, instSet>::NR_MIN;
88 row_interleave = PackingTraits<uint8_t, int32_t, instSet>::ROW_INTERLEAVE;
89 }
90 (void)nRegBlockSizeMin; // Suppress unused variable warning
91
92 kernelSig = std::make_tuple(
93 accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize);
94
95 return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
96 asmjit::CodeHolder code;
97 code.init(runtime().environment());
98 x86::Assembler assembler(&code);
99 x86::Emitter* a = assembler.as<x86::Emitter>();
100
101#if defined(FBGEMM_LOG_CODE)
102 // generated code logging
103 FILE* codeLogfile = fopen(
104 getCodeLoggingFile<instSet>(
105 accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize)
106 .c_str(),
107 "w");
108 asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
109 if (codeLogger) {
110 code.setLogger(codeLogger);
111 }
112#endif
113
114 assert(
115 kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
116 assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
117 const int maxMRegs = mRegBlockSize;
118 const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
119 (void)maxMRegs; // Suppress unused variable warning
120 assert(
121 maxMRegs * maxNRegs <= 30 &&
122 "MR*(NR*ROW_INTERLEAVE*8/512) \
123 must be <= 30(available registers constraint)");
124
125 int mRegBlocks = mc / mRegBlockSize;
126 int mRegBlocksRem = mc % mRegBlockSize;
127
128 // arguments to the function created
129 x86::Gp buffer_A = a->zdi();
130 x86::Gp buffer_B = a->zsi();
131 x86::Gp B_pf = a->zdx();
132 x86::Gp CBase = a->zcx();
133 x86::Gp kSize = a->gpz(8);
134 x86::Gp ldcReg = a->gpz(9);
135
136 asmjit::FuncDetail func;
137 func.init(
138 asmjit::FuncSignatureT<
139 void,
140 uint8_t*,
141 int8_t*,
142 int8_t*,
143 int32_t*,
144 int,
145 int>(asmjit::CallConvId::kHost),
146 a->environment());
147
148 asmjit::FuncFrame frame;
149 frame.init(func);
150
151 frame.setDirtyRegs(
152 asmjit::RegGroup::kVec,
153 asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
154 asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
155 asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
156 asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
157 frame.setDirtyRegs(
158 asmjit::RegGroup::kGp,
159 asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
160
161 asmjit::FuncArgsAssignment args(&func);
162 args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
163
164 args.updateFuncFrame(frame);
165 frame.finalize();
166
167 a->emitProlog(frame);
168 a->emitArgsAssignment(frame, args);
169
170 asmjit::Label LoopMBlocks = a->newLabel();
171 asmjit::Label LoopNBlocks = a->newLabel();
172 asmjit::Label Loopk = a->newLabel();
173
174 x86::Gp buffer_B_saved = a->gpz(10);
175 x86::Gp C_Offset = a->gpz(11);
176 x86::Gp B_pf_saved = a->gpz(12);
177 x86::Gp iIdx = a->gpz(13);
178 x86::Gp jIdx = a->gpz(14);
179 x86::Gp kIdx = a->gpz(15);
180 // x86::Gp B_pf = a->gpz(8);
181
182 x86::Zmm oneReg = x86::zmm29;
183 // create 16-bit 1s
184 // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
185 // and so on
186 // a->vpcmpeqw(oneReg, oneReg, oneReg);
187 a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
188 a->vpsrlw(oneReg, oneReg, 15);
189 a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
190
191 // save B_buffer address
192 a->mov(buffer_B_saved, buffer_B);
193 a->mov(B_pf_saved, B_pf);
194
195 int currColRegs = nc * row_interleave / vectorLen;
196 int colRegs = std::min(currColRegs, maxNRegs);
197 if (mRegBlocks > 0) {
198 // move 0 to iteration variables
199 a->xor_(iIdx.r32(), iIdx.r32());
200
201 a->bind(LoopMBlocks);
202 a->inc(iIdx);
203 a->xor_(jIdx.r32(), jIdx.r32());
204
205 a->bind(LoopNBlocks);
206 a->inc(jIdx);
207
208 int rowRegs = mRegBlockSize;
209
210 // init C registers
211 initCRegs(a, rowRegs, colRegs);
212
213 // init k loop index
214 a->xor_(kIdx.r32(), kIdx.r32());
215 a->bind(Loopk);
216
217 // k is incremented by row_interleave
218 a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
219
220 genComputeBlock<instSet>(
221 a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
222
223 // update buffer_A address for next k iteration
224 a->add(
225 buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
226
227 // update buffer_B address for next k iteration
228 a->add(
229 buffer_B,
230 static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
231 a->add(
232 B_pf,
233 static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
234
235 // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
236
237 a->cmp(kIdx, kSize);
238 a->jl(Loopk);
239
240 // store C matrix
241 storeCRegs<storeInstType>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
242
243 // reset A
244 a->sub(buffer_A, kSize);
245
246 // B for next block
247 a->mov(buffer_B, buffer_B_saved);
248 // using C_Offset as temp reg
249 a->imul(
250 C_Offset,
251 jIdx,
252 static_cast<asmjit::Imm>(
253 nRegBlockSize * row_interleave * sizeof(int8_t)));
254 a->add(buffer_B, C_Offset);
255 a->mov(B_pf, B_pf_saved);
256 a->add(B_pf, C_Offset);
257
258 // increment C for next B block
259 a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
260
261 int jLoopTrips = currColRegs / maxNRegs;
262 // jLoopTrips should be at least 1
263 jLoopTrips = jLoopTrips ? jLoopTrips : 1;
264 a->cmp(jIdx, jLoopTrips);
265 a->jl(LoopNBlocks);
266
267 // increment A for next block
268 a->add(
269 buffer_A,
270 static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
271
272 // increment C for next A block
273 a->sub(
274 CBase,
275 static_cast<asmjit::Imm>(
276 jLoopTrips * nRegBlockSize * sizeof(int32_t)));
277 a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
278 a->add(CBase, C_Offset);
279
280 // reset B
281 a->mov(buffer_B, buffer_B_saved);
282 a->mov(B_pf, B_pf_saved);
283 a->cmp(iIdx, mRegBlocks);
284 a->jl(LoopMBlocks);
285 }
286 // generate code for remainder
287 if (mRegBlocksRem > 0) {
288 asmjit::Label LoopNRem = a->newLabel();
289 asmjit::Label LoopkRem = a->newLabel();
290 int rowRegs = mRegBlocksRem;
291
292 a->xor_(jIdx.r32(), jIdx.r32());
293 a->bind(LoopNRem);
294 a->inc(jIdx);
295
296 // init C registers
297 initCRegs(a, rowRegs, colRegs);
298
299 // init k loop index
300 a->xor_(kIdx.r32(), kIdx.r32());
301 a->bind(LoopkRem);
302
303 // k is incremented by row_interleave
304 a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
305
306 genComputeBlock<instSet>(
307 a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
308
309 // update buffer_A address for next k iteration
310 a->add(
311 buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
312
313 // update buffer_B address for next k iteration
314 a->add(
315 buffer_B,
316 static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
317 a->add(
318 B_pf,
319 static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
320
321 a->cmp(kIdx, kSize);
322 a->jl(LoopkRem);
323
324 // reset A
325 a->sub(buffer_A, kSize);
326 // B for next block
327 // using C_Offset as temp reg
328 a->imul(
329 C_Offset,
330 jIdx,
331 static_cast<asmjit::Imm>(
332 nRegBlockSize * row_interleave * sizeof(int8_t)));
333 a->mov(buffer_B, buffer_B_saved);
334 a->add(buffer_B, C_Offset);
335 a->mov(B_pf, B_pf_saved);
336 a->add(B_pf, C_Offset);
337
338 // store C matrix
339 storeCRegs<storeInstType>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
340
341 // increment C for next B block
342 a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
343
344 int jLoopTrips = currColRegs / maxNRegs;
345 // jLoopTrips should be at least 1
346 jLoopTrips = jLoopTrips ? jLoopTrips : 1;
347 a->cmp(jIdx, jLoopTrips);
348 a->jl(LoopNRem);
349 }
350
351 a->emitEpilog(frame);
352
353 jit_micro_kernel_fp fn;
354 asmjit::Error err;
355 {
356 std::unique_lock<std::mutex> lock(rtMutex_);
357 err = runtime().add(&fn, &code);
358 }
359 if (err) {
360 std::cout << "Error: in fn add" << std::endl;
361 return nullptr;
362 }
363
364#if defined(FBGEMM_LOG_CODE)
365 fclose(codeLogfile);
366 delete codeLogger;
367#endif
368
369 return fn;
370 });
371}
372
373/**
374 * Instatiate the AVX512_VNNI instructions for 32-bit Accumulation macro-kernel.
375 *
376 */
377template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
378CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
379 inst_set_t::avx512_vnni>(bool accum, int32_t mc, int32_t nc, int32_t kc);
380
381/**
382 * Instatiate the AVX512_VNNI_256 instructions for 32-bit Accumulation
383 * macro-kernel.
384 *
385 */
386template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
387CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
388 inst_set_t::avx512_vnni_ymm>(
389 bool accum,
390 int32_t mc,
391 int32_t nc,
392 int32_t kc);
393
394} // namespace fbgemm
395