/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include <iostream>
#include "./CodeGenHelpers.h"
#include "./GenerateKernel.h"

namespace fbgemm {

namespace x86 = asmjit::x86;

/**
 * Generate AVX512 instructions for computing block in the rank-k update of
 * 16-bit Accumulation kernel.
 */
template <>
template <inst_set_t instSet>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock(
    x86::Emitter* a,
    x86::Gp buffer_A,
    x86::Gp buffer_B,
    x86::Gp /* unused (reserved for prefetching) */,
    int rowRegs,
    int colRegs,
    int lda) {
  using VecRegT = typename simd_info<instSet>::vec_reg_t;
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

  // used for matrix A
  VecRegT AReg(29);

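  // temporary register holding the 16-bit products from vpmaddubsw before
  // they are accumulated into the C registers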
  VecRegT tmpReg(30);

  // We start allocating BRegs from zmm28 and then allocate zmm27 and so on.
  for (int j = 0; j < colRegs; ++j) {
    a->vmovups(
        VecRegT(28 - j),
        x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t)));
  }

  for (int i = 0; i < rowRegs; ++i) {
    // broadcast A
    a->vpbroadcastw(
        AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
    for (int j = 0; j < colRegs; ++j) {
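      // vpmaddubsw multiplies unsigned bytes of A with signed bytes of B and
      // horizontally adds adjacent pairs into 16-bit lanes; vpaddsw then
      // accumulates the result into the C register with signed saturation.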
      a->vpmaddubsw(tmpReg, AReg, VecRegT(28 - j));
      a->vpaddsw(VecRegT(i * colRegs + j), tmpReg, VecRegT(i * colRegs + j));
      // Prefetching hurts performance in some cases because the prefetch
      // instruction itself consumes an issue slot in the pipeline and thus
      // slows down the kernel.
      // if ((i == rowRegs - 1) && j % 2 == 0) {
      //   a->prefetcht0(x86::dword_ptr(B_pf, j * VecLen * sizeof(int8_t)));
      // }
    }
  }
}

/**
 * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
 *
 */
template <>
template <inst_set_t instSet>
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate(
    bool accum,
    int32_t mc,
    int32_t nc,
    int32_t kc) {
  (void)kc; // Suppress unused variable warning
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

  std::tuple<bool, int, int, int, int, int, int> kernelSig;
  int kBlock;
  int nBlock;
  int mRegBlockSize;
  int nRegBlockSize;
  int nRegBlockSizeMin;
  int row_interleave;

  if (blocking_params) {
    kBlock = blocking_params->KCB;
    nBlock = blocking_params->NCB;
    mRegBlockSize = blocking_params->MR;
    nRegBlockSize = blocking_params->NR;
    nRegBlockSizeMin = blocking_params->NR_MIN;
    row_interleave = blocking_params->ROW_INTERLEAVE;
  } else {
    kBlock = PackingTraits<uint8_t, int16_t, instSet>::KCB;
    nBlock = PackingTraits<uint8_t, int16_t, instSet>::NCB;
    mRegBlockSize = PackingTraits<uint8_t, int16_t, instSet>::MR;
    nRegBlockSize = PackingTraits<uint8_t, int16_t, instSet>::NR;
    nRegBlockSizeMin = PackingTraits<uint8_t, int16_t, instSet>::NR_MIN;
    row_interleave = PackingTraits<uint8_t, int16_t, instSet>::ROW_INTERLEAVE;
  }
  (void)nRegBlockSizeMin; // Suppress unused variable warning

  kernelSig = std::make_tuple(
      accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();

#if defined(FBGEMM_LOG_CODE)
    // generated code logging
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<instSet>(
            accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    assert(
        kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
    assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
    const int maxMRegs = mRegBlockSize;
    (void)maxMRegs; // Suppress unused variable warning
    const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
    assert(
        (maxMRegs + 1) * maxNRegs <= 29 &&
        "number of zmm registers for C + one row for loading B: \
        MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \
        must be <= 29(available registers constraint)");
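    // mc is tiled into blocks of mRegBlockSize rows; any leftover rows are
    // handled by the remainder code path generated further below.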
    int mRegBlocks = mc / mRegBlockSize;
    int mRegBlocksRem = mc % mRegBlockSize;

    // arguments to the function created
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp B_pf = a->zdx();
    x86::Gp CBase = a->zcx();
    x86::Gp kSize = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);
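    // Incoming arguments (see the FuncSignatureT below):
    //   buffer_A: uint8_t* A panel, buffer_B: int8_t* B panel,
    //   B_pf: int8_t* B prefetch pointer, CBase: int32_t* C,
    //   kSize: k loop bound, ldcReg: leading dimension of C.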

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            uint8_t*,
            int8_t*,
            int8_t*,
            int32_t*,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    frame.setDirtyRegs(
        asmjit::RegGroup::kVec,
        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
            asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
            asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
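    // All vector registers and r8-r15 are declared clobbered; the FuncFrame
    // uses this to decide which callee-saved registers the generated
    // prolog/epilog has to preserve for the target calling convention.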

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label LoopMBlocks = a->newLabel();
    asmjit::Label LoopNBlocks = a->newLabel();
    asmjit::Label Loopk = a->newLabel();

    x86::Gp buffer_B_saved = a->gpz(10);
    x86::Gp C_Offset = a->gpz(11);
    // x86::Gp B_pf_saved = a->gpz(12);
    x86::Gp iIdx = a->gpz(13);
    x86::Gp jIdx = a->gpz(14);
    x86::Gp kIdx = a->gpz(15);

    // save B_buffer address
    a->mov(buffer_B_saved, buffer_B);
    // a->mov(B_pf_saved, B_pf);

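    // Number of vector registers needed to cover all nc columns of the
    // interleaved B panel; each register holds vectorLen / row_interleave
    // columns. colRegs is capped at maxNRegs, the per-tile register budget.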
    int currColRegs = nc * row_interleave / vectorLen;
    int colRegs = std::min(currColRegs, maxNRegs);
    if (mRegBlocks > 0) {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      a->bind(LoopMBlocks);
      a->inc(iIdx);
      a->xor_(jIdx.r32(), jIdx.r32());

      a->bind(LoopNBlocks);
      a->inc(jIdx);

      int rowRegs = mRegBlockSize;

      // init C registers
      initCRegs(a, rowRegs, colRegs);

      // init k loop index
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(Loopk);
      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      genComputeBlock<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(
          buffer_B,
          static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
      // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
      // sizeof(int8_t)));

      a->cmp(kIdx, kSize);
      a->jl(Loopk);

      // store C matrix
      storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);

      // reset A
      a->sub(buffer_A, kSize);
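      // buffer_A was advanced by row_interleave bytes per k iteration, kSize
      // bytes in total, so subtracting kSize rewinds it to the block start.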

      // B for next block
      a->mov(buffer_B, buffer_B_saved);
      // using C_Offset as temp reg
      a->imul(
          C_Offset,
          jIdx,
          static_cast<asmjit::Imm>(
              nRegBlockSize * row_interleave * sizeof(int8_t)));
      a->add(buffer_B, C_Offset);

      // increment C for next block
      a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));

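      // Number of nRegBlockSize-wide column tiles LoopNBlocks has to visit.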
      int jLoopTrips = currColRegs / maxNRegs;
      // jLoopTrips should be at least 1
      jLoopTrips = jLoopTrips ? jLoopTrips : 1;
      a->cmp(jIdx, jLoopTrips);
      a->jl(LoopNBlocks);

      // increment A for next block
      a->add(
          buffer_A,
          static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));

      // increment C for next A block
      a->sub(
          CBase,
          static_cast<asmjit::Imm>(
              jLoopTrips * nRegBlockSize * sizeof(int32_t)));
      a->imul(
          C_Offset,
          ldcReg,
          static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
      a->add(CBase, C_Offset);

      // reset B
      a->mov(buffer_B, buffer_B_saved);
      // a->mov(B_pf, B_pf_saved);

      a->cmp(iIdx, mRegBlocks);
      a->jl(LoopMBlocks);
    }
    // generate code for remainder
    if (mRegBlocksRem > 0) {
      asmjit::Label LoopNRem = a->newLabel();
      asmjit::Label LoopkRem = a->newLabel();
      int rowRegs = mRegBlocksRem;

      a->xor_(jIdx.r32(), jIdx.r32());
      a->bind(LoopNRem);
      a->inc(jIdx);

      // init C registers
      initCRegs(a, rowRegs, colRegs);

      // init k loop index
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopkRem);

      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      genComputeBlock<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(
          buffer_B,
          static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
      // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
      // sizeof(int8_t)));

      a->cmp(kIdx, kSize);
      a->jl(LoopkRem);

      // reset A
      a->sub(buffer_A, kSize);

      // B for next block
      a->mov(buffer_B, buffer_B_saved);
      // using C_Offset as temp reg
      a->imul(
          C_Offset,
          jIdx,
          static_cast<asmjit::Imm>(
              nRegBlockSize * row_interleave * sizeof(int8_t)));
      a->add(buffer_B, C_Offset);

      // store C matrix
      storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);

      // increment C for next block
      a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));

      int jLoopTrips = currColRegs / maxNRegs;
      // jLoopTrips should be at least 1
      jLoopTrips = jLoopTrips ? jLoopTrips : 1;
      a->cmp(jIdx, jLoopTrips);
      a->jl(LoopNRem);
    }

    a->emitEpilog(frame);

    jit_micro_kernel_fp fn;
    asmjit::Error err;
    {
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

#if defined(FBGEMM_LOG_CODE)
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}

/**
 * Instantiate the AVX512 instructions for 16-bit Accumulation macro-kernel.
 *
 */
template CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
    bool accum,
    int32_t mc,
    int32_t nc,
    int32_t kc);

/**
 * Instantiate the AVX512_256 instructions for 16-bit Accumulation macro-kernel.
 *
 */
template CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<
    inst_set_t::avx512_ymm>(bool accum, int32_t mc, int32_t nc, int32_t kc);

} // namespace fbgemm