1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #include <iostream> |
8 | #include "./CodeGenHelpers.h" |
9 | #include "./GenerateKernel.h" |
10 | |
11 | namespace fbgemm { |
12 | |
13 | namespace x86 = asmjit::x86; |
14 | |
15 | /** |
16 | * Generate AVX512 instructions for computing block in the rank-k update of |
17 | * 32-bit Accumulation kernel. |
18 | */ |
19 | template <> |
20 | template <inst_set_t instSet> |
21 | void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock( |
22 | x86::Emitter* a, |
23 | x86::Gp buffer_A, |
24 | x86::Gp buffer_B, |
25 | x86::Gp B_pf, |
26 | int rowRegs, |
27 | int colRegs, |
28 | int lda) { |
29 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
30 | using VecRegT = typename simd_info<instSet>::vec_reg_t; |
31 | constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS; |
32 | |
33 | // used for matrix A |
34 | VecRegT AReg(numRegs - 1); |
35 | |
36 | // used for matrix B |
37 | VecRegT BReg(numRegs - 2); |
38 | |
39 | // Contains 16-bit 1s |
40 | VecRegT oneReg(numRegs - 3); |
41 | |
42 | // temporary register |
43 | VecRegT res1(numRegs - 4); |
44 | |
45 | for (int j = 0; j < colRegs; ++j) { |
46 | // load B |
47 | emitLoadDWord<instSet, VecRegT>( |
48 | a, BReg, x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); |
49 | // load A, broadcast and fmas |
50 | for (int i = 0; i < rowRegs; ++i) { |
51 | a->vpbroadcastd( |
52 | AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); |
53 | a->vpmaddubsw(res1, AReg, BReg); |
54 | a->vpmaddwd(res1, oneReg, res1); |
55 | a->vpaddd(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j)); |
56 | } |
57 | a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int8_t))); |
58 | } |
59 | } |
60 | |
61 | /** |
62 | * Generate AVX512 instructions for storing the C registers back to the memory |
63 | * in 32-bit Accumulation kernel. |
64 | */ |
65 | template <> |
66 | template <inst_set_t instSet> |
67 | void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs( |
68 | x86::Emitter* a, |
69 | int rowRegs, |
70 | int colRegs, |
71 | x86::Gp C_Offset, |
72 | x86::Gp ldcReg, |
73 | bool accum) { |
74 | using VecT = typename simd_info<instSet>::vec_reg_t; |
75 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
76 | |
77 | for (int i = 0; i < rowRegs; ++i) { |
78 | if (i != 0) { |
79 | a->add(C_Offset, ldcReg); |
80 | } else { |
81 | a->xor_(C_Offset.r32(), C_Offset.r32()); |
82 | } |
83 | for (int j = 0; j < colRegs; ++j) { |
84 | if (accum) { |
85 | a->vpaddd( |
86 | VecT(i * colRegs + j), |
87 | VecT(i * colRegs + j), |
88 | x86::dword_ptr( |
89 | a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t))); |
90 | } |
91 | a->vmovups( |
92 | x86::dword_ptr(a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)), |
93 | VecT(i * colRegs + j)); |
94 | } |
95 | } |
96 | } |
97 | |
98 | /** |
99 | * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. |
100 | * |
101 | */ |
102 | template <> |
103 | template <inst_set_t instSet> |
104 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp |
105 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate( |
106 | bool accum, |
107 | int32_t mc, |
108 | int32_t nc, |
109 | int32_t kc) { |
110 | (void)kc; // Suppress unused variable warning |
111 | using VecRegT = typename simd_info<instSet>::vec_reg_t; |
112 | constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS; |
113 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
114 | |
115 | std::tuple<bool, int, int, int, int, int, int> kernelSig; |
116 | int kBlock; |
117 | int nBlock; |
118 | int mRegBlockSize; |
119 | int nRegBlockSize; |
120 | int nRegBlockSizeMin; |
121 | int row_interleave; |
122 | |
123 | if (blocking_params) { |
124 | kBlock = blocking_params->KCB; |
125 | nBlock = blocking_params->NCB; |
126 | mRegBlockSize = blocking_params->MR; |
127 | nRegBlockSize = blocking_params->NR; |
128 | nRegBlockSizeMin = blocking_params->NR_MIN; |
129 | row_interleave = blocking_params->ROW_INTERLEAVE; |
130 | } else { |
131 | kBlock = PackingTraits<uint8_t, int32_t, instSet>::KCB; |
132 | nBlock = PackingTraits<uint8_t, int32_t, instSet>::NCB; |
133 | mRegBlockSize = PackingTraits<uint8_t, int32_t, instSet>::MR; |
134 | nRegBlockSize = PackingTraits<uint8_t, int32_t, instSet>::NR; |
135 | nRegBlockSizeMin = PackingTraits<uint8_t, int32_t, instSet>::NR_MIN; |
136 | row_interleave = PackingTraits<uint8_t, int32_t, instSet>::ROW_INTERLEAVE; |
137 | } |
138 | (void)nRegBlockSizeMin; // Suppress unused variable warning |
139 | |
140 | kernelSig = std::make_tuple( |
141 | accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); |
142 | |
143 | return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { |
144 | asmjit::CodeHolder code; |
145 | code.init(runtime().environment()); |
146 | x86::Assembler assembler(&code); |
147 | x86::Emitter* a = assembler.as<x86::Emitter>(); |
148 | #if defined(FBGEMM_LOG_CODE) |
149 | // generated code logging |
150 | FILE* codeLogfile = fopen( |
151 | getCodeLoggingFile<instSet>( |
152 | accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) |
153 | .c_str(), |
154 | "w" ); |
155 | asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); |
156 | if (codeLogger) { |
157 | code.setLogger(codeLogger); |
158 | } |
159 | #endif |
160 | |
161 | assert( |
162 | kc % row_interleave == 0 && "kc must be a multiple of row_interleave" ); |
163 | assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN" ); |
164 | const int maxMRegs = mRegBlockSize; |
165 | (void)maxMRegs; // Suppress unused variable warning |
166 | const int maxNRegs = nRegBlockSize * row_interleave / vectorLen; |
167 | assert( |
168 | maxMRegs * maxNRegs <= numRegs - 4 && |
169 | "MRegs x NRegs is above available registers (MAX_REGS - 4)" ); |
170 | |
171 | int mRegBlocks = mc / mRegBlockSize; |
172 | int mRegBlocksRem = mc % mRegBlockSize; |
173 | |
174 | // arguments to the function created |
175 | x86::Gp buffer_A = a->zdi(); |
176 | x86::Gp buffer_B = a->zsi(); |
177 | x86::Gp B_pf = a->zdx(); |
178 | x86::Gp CBase = a->zcx(); |
179 | x86::Gp kSize = a->gpz(8); |
180 | x86::Gp ldcReg = a->gpz(9); |
181 | |
182 | asmjit::FuncDetail func; |
183 | func.init( |
184 | asmjit::FuncSignatureT< |
185 | void, |
186 | uint8_t*, |
187 | int8_t*, |
188 | int8_t*, |
189 | int32_t*, |
190 | int, |
191 | int>(asmjit::CallConvId::kHost), |
192 | a->environment()); |
193 | |
194 | asmjit::FuncFrame frame; |
195 | frame.init(func); |
196 | |
197 | auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | |
198 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15); |
199 | if (numRegs >= 16) { |
200 | dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | |
201 | asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31); |
202 | } |
203 | |
204 | frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs); |
205 | frame.setDirtyRegs( |
206 | asmjit::RegGroup::kGp, |
207 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); |
208 | |
209 | asmjit::FuncArgsAssignment args(&func); |
210 | args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); |
211 | |
212 | args.updateFuncFrame(frame); |
213 | frame.finalize(); |
214 | |
215 | a->emitProlog(frame); |
216 | a->emitArgsAssignment(frame, args); |
217 | |
218 | asmjit::Label LoopMBlocks = a->newLabel(); |
219 | asmjit::Label LoopNBlocks = a->newLabel(); |
220 | |
221 | x86::Gp buffer_B_saved = a->gpz(10); |
222 | x86::Gp C_Offset = a->gpz(11); |
223 | x86::Gp B_pf_saved = a->gpz(12); |
224 | x86::Gp iIdx = a->gpz(13); |
225 | x86::Gp jIdx = a->gpz(14); |
226 | x86::Gp kIdx = a->gpz(15); |
227 | // x86::Gp B_pf = a->gpz(8); |
228 | |
229 | VecRegT oneReg(numRegs - 3); |
230 | |
231 | gen16BitVectorOne<instSet, VecRegT>(a, oneReg); |
232 | a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); |
233 | // a->xor_(C_Offset.r32(), C_Offset.r32()); |
234 | |
235 | // save B_buffer address |
236 | a->mov(buffer_B_saved, buffer_B); |
237 | a->mov(B_pf_saved, B_pf); |
238 | |
239 | int currColRegs = nc * row_interleave / vectorLen; |
240 | int colRegs = std::min(currColRegs, maxNRegs); |
241 | |
242 | auto issueLoopOverK = [&](int rowRegs) { |
243 | asmjit::Label LoopKLabel = a->newLabel(); |
244 | |
245 | // Init C (result) vector registers |
246 | initCRegs(a, rowRegs, colRegs); |
247 | |
248 | // Loops over K |
249 | a->xor_(kIdx.r32(), kIdx.r32()); |
250 | a->bind(LoopKLabel); |
251 | |
252 | // k is incremented by row_interleave |
253 | a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); |
254 | |
255 | genComputeBlock<instSet>( |
256 | a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); |
257 | |
258 | // update buffer_A address for next k iteration |
259 | a->add( |
260 | buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); |
261 | |
262 | // update buffer_B address for next k iteration |
263 | a->add( |
264 | buffer_B, |
265 | static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); |
266 | a->add( |
267 | B_pf, |
268 | static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); |
269 | |
270 | a->cmp(kIdx, kSize); |
271 | a->jl(LoopKLabel); |
272 | |
273 | // store C matrix |
274 | storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); |
275 | }; |
276 | |
277 | if (mRegBlocks > 0) { |
278 | // move 0 to iteration variables |
279 | a->xor_(iIdx.r32(), iIdx.r32()); |
280 | |
281 | a->bind(LoopMBlocks); |
282 | a->inc(iIdx); |
283 | |
284 | a->xor_(jIdx.r32(), jIdx.r32()); |
285 | |
286 | a->bind(LoopNBlocks); |
287 | a->inc(jIdx); |
288 | |
289 | issueLoopOverK(mRegBlockSize); |
290 | |
291 | int rowRegs = mRegBlockSize; |
292 | |
293 | // reset A |
294 | a->sub(buffer_A, kSize); |
295 | |
296 | // B for next block |
297 | a->mov(buffer_B, buffer_B_saved); |
298 | // using C_Offset as temp reg |
299 | a->imul( |
300 | C_Offset, |
301 | jIdx, |
302 | static_cast<asmjit::Imm>( |
303 | nRegBlockSize * row_interleave * sizeof(int8_t))); |
304 | a->add(buffer_B, C_Offset); |
305 | a->mov(B_pf, B_pf_saved); |
306 | a->add(B_pf, C_Offset); |
307 | |
308 | // increment C for next B block |
309 | a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); |
310 | |
311 | int jLoopTrips = currColRegs / maxNRegs; |
312 | // jLoopTrips should be at least 1 |
313 | jLoopTrips = jLoopTrips ? jLoopTrips : 1; |
314 | a->cmp(jIdx, jLoopTrips); |
315 | a->jl(LoopNBlocks); |
316 | |
317 | // increment A for next block |
318 | a->add( |
319 | buffer_A, |
320 | static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); |
321 | |
322 | // increment C for next A block |
323 | a->sub( |
324 | CBase, |
325 | static_cast<asmjit::Imm>( |
326 | jLoopTrips * nRegBlockSize * sizeof(int32_t))); |
327 | a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); |
328 | a->add(CBase, C_Offset); |
329 | |
330 | // reset B |
331 | a->mov(buffer_B, buffer_B_saved); |
332 | a->mov(B_pf, B_pf_saved); |
333 | a->cmp(iIdx, mRegBlocks); |
334 | a->jl(LoopMBlocks); |
335 | } |
336 | // generate code for remainder |
337 | if (mRegBlocksRem > 0) { |
338 | asmjit::Label LoopNRem = a->newLabel(); |
339 | |
340 | a->xor_(jIdx.r32(), jIdx.r32()); |
341 | a->bind(LoopNRem); |
342 | a->inc(jIdx); |
343 | |
344 | issueLoopOverK(mRegBlocksRem); |
345 | |
346 | // increment C for next B block |
347 | a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); |
348 | |
349 | int jLoopTrips = currColRegs / maxNRegs; |
350 | // jLoopTrips should be at least 1 |
351 | jLoopTrips = jLoopTrips ? jLoopTrips : 1; |
352 | a->cmp(jIdx, jLoopTrips); |
353 | a->jl(LoopNRem); |
354 | } |
355 | |
356 | a->emitEpilog(frame); |
357 | |
358 | jit_micro_kernel_fp fn; |
359 | asmjit::Error err; |
360 | { |
361 | std::unique_lock<std::mutex> lock(rtMutex_); |
362 | err = runtime().add(&fn, &code); |
363 | } |
364 | if (err) { |
365 | std::cout << "Error: in fn add" << std::endl; |
366 | return nullptr; |
367 | } |
368 | |
369 | #if defined(FBGEMM_LOG_CODE) |
370 | fclose(codeLogfile); |
371 | delete codeLogger; |
372 | #endif |
373 | |
374 | return fn; |
375 | }); |
376 | } |
377 | |
378 | /** |
379 | * Instantiate the AVX512 instructions for 32-bit Accumulation macro-kernel. |
380 | * |
381 | */ |
382 | template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp |
383 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( |
384 | bool accum, |
385 | int32_t mc, |
386 | int32_t nc, |
387 | int32_t kc); |
388 | |
389 | /** |
390 | * Instatiate the AVX512_256 instructions for 32-bit Accumulation macro-kernel. |
391 | * |
392 | */ |
393 | template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp |
394 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate< |
395 | inst_set_t::avx512_ymm>(bool accum, int32_t mc, int32_t nc, int32_t kc); |
396 | |
397 | /** |
398 | * Instantiate the AVX2 instructions for 32-bit Accumulation macro-kernel. |
399 | * |
400 | */ |
401 | template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp |
402 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( |
403 | bool accum, |
404 | int32_t mc, |
405 | int32_t nc, |
406 | int32_t kc); |
407 | |
408 | /** |
409 | * Instantiate the inst_set_t::avx512 instructions for store kernel. |
410 | * |
411 | */ |
412 | template void |
413 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<inst_set_t::avx512>( |
414 | x86::Emitter* a, |
415 | int rowRegs, |
416 | int colRegs, |
417 | x86::Gp C_Offset, |
418 | x86::Gp ldcReg, |
419 | bool accum); |
420 | |
421 | /** |
422 | * Instantiate the inst_set_t::avx512_ymm instructions for store kernel. |
423 | * |
424 | */ |
425 | template void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< |
426 | inst_set_t::avx512_ymm>( |
427 | x86::Emitter* a, |
428 | int rowRegs, |
429 | int colRegs, |
430 | x86::Gp C_Offset, |
431 | x86::Gp ldcReg, |
432 | bool accum); |
433 | |
434 | /** |
435 | * Instantiate the inst_set_t::avx2 instructions for store kernel. |
436 | * |
437 | */ |
438 | template void |
439 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<inst_set_t::avx2>( |
440 | x86::Emitter* a, |
441 | int rowRegs, |
442 | int colRegs, |
443 | x86::Gp C_Offset, |
444 | x86::Gp ldcReg, |
445 | bool accum); |
446 | |
447 | } // namespace fbgemm |
448 | |