1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #include <iostream> |
8 | #include "./GenerateKernel.h" |
9 | |
10 | namespace fbgemm { |
11 | |
12 | namespace x86 = asmjit::x86; |
13 | |
14 | /** |
15 | * Generate AVX512 instructions for computing block in the rank-k update of |
16 | * 32-bit Accmulation kernel. |
17 | */ |
18 | template <> |
19 | template <inst_set_t instSet> |
20 | void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock( |
21 | x86::Emitter* a, |
22 | x86::Gp buffer_A, |
23 | x86::Gp buffer_B, |
24 | x86::Gp /*B_pf*/, |
25 | int rowRegs, |
26 | int colRegs, |
27 | int lda) { |
28 | assert(colRegs * (rowRegs + 1) <= 31); |
29 | using VecRegT = typename simd_info<instSet>::vec_reg_t; |
30 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
31 | |
32 | // used for matrix A |
33 | VecRegT AReg(31); |
34 | for (int j = 0; j < colRegs; ++j) { |
35 | a->vmovdqa32( |
36 | VecRegT(30 - j), |
37 | x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); |
38 | } |
39 | |
40 | for (int i = 0; i < rowRegs; i++) { |
41 | a->vpbroadcastd( |
42 | AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); |
43 | for (int j = 0; j < colRegs; ++j) { |
44 | a->vpdpbusd(VecRegT(i * colRegs + j), AReg, VecRegT(30 - j)); |
45 | } |
46 | } |
47 | } |
48 | |
49 | /** |
50 | * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. |
51 | * |
52 | */ |
53 | template <> |
54 | template <inst_set_t instSet> |
55 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp |
56 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate( |
57 | bool accum, |
58 | int32_t mc, |
59 | int32_t nc, |
60 | int32_t kc) { |
61 | (void)kc; // Suppress unused variable warning |
62 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
63 | static constexpr inst_set_t storeInstType = |
64 | simd_info<instSet>::WIDTH_BITS == 512 ? inst_set_t::avx512 |
65 | : inst_set_t::avx512_ymm; |
66 | |
67 | std::tuple<bool, int, int, int, int, int, int> kernelSig; |
68 | int kBlock; |
69 | int nBlock; |
70 | int mRegBlockSize; |
71 | int nRegBlockSize; |
72 | int nRegBlockSizeMin; |
73 | int row_interleave; |
74 | |
75 | if (blocking_params) { |
76 | kBlock = blocking_params->KCB; |
77 | nBlock = blocking_params->NCB; |
78 | mRegBlockSize = blocking_params->MR; |
79 | nRegBlockSize = blocking_params->NR; |
80 | nRegBlockSizeMin = blocking_params->NR_MIN; |
81 | row_interleave = blocking_params->ROW_INTERLEAVE; |
82 | } else { |
83 | kBlock = PackingTraits<uint8_t, int32_t, instSet>::KCB; |
84 | nBlock = PackingTraits<uint8_t, int32_t, instSet>::NCB; |
85 | mRegBlockSize = PackingTraits<uint8_t, int32_t, instSet>::MR; |
86 | nRegBlockSize = PackingTraits<uint8_t, int32_t, instSet>::NR; |
87 | nRegBlockSizeMin = PackingTraits<uint8_t, int32_t, instSet>::NR_MIN; |
88 | row_interleave = PackingTraits<uint8_t, int32_t, instSet>::ROW_INTERLEAVE; |
89 | } |
90 | (void)nRegBlockSizeMin; // Suppress unused variable warning |
91 | |
92 | kernelSig = std::make_tuple( |
93 | accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); |
94 | |
95 | return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { |
96 | asmjit::CodeHolder code; |
97 | code.init(runtime().environment()); |
98 | x86::Assembler assembler(&code); |
99 | x86::Emitter* a = assembler.as<x86::Emitter>(); |
100 | |
101 | #if defined(FBGEMM_LOG_CODE) |
102 | // generated code logging |
103 | FILE* codeLogfile = fopen( |
104 | getCodeLoggingFile<instSet>( |
105 | accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) |
106 | .c_str(), |
107 | "w" ); |
108 | asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); |
109 | if (codeLogger) { |
110 | code.setLogger(codeLogger); |
111 | } |
112 | #endif |
113 | |
114 | assert( |
115 | kc % row_interleave == 0 && "kc must be a multiple of row_interleave" ); |
116 | assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN" ); |
117 | const int maxMRegs = mRegBlockSize; |
118 | const int maxNRegs = nRegBlockSize * row_interleave / vectorLen; |
119 | (void)maxMRegs; // Suppress unused variable warning |
120 | assert( |
121 | maxMRegs * maxNRegs <= 30 && |
122 | "MR*(NR*ROW_INTERLEAVE*8/512) \ |
123 | must be <= 30(available registers constraint)" ); |
124 | |
125 | int mRegBlocks = mc / mRegBlockSize; |
126 | int mRegBlocksRem = mc % mRegBlockSize; |
127 | |
128 | // arguments to the function created |
129 | x86::Gp buffer_A = a->zdi(); |
130 | x86::Gp buffer_B = a->zsi(); |
131 | x86::Gp B_pf = a->zdx(); |
132 | x86::Gp CBase = a->zcx(); |
133 | x86::Gp kSize = a->gpz(8); |
134 | x86::Gp ldcReg = a->gpz(9); |
135 | |
136 | asmjit::FuncDetail func; |
137 | func.init( |
138 | asmjit::FuncSignatureT< |
139 | void, |
140 | uint8_t*, |
141 | int8_t*, |
142 | int8_t*, |
143 | int32_t*, |
144 | int, |
145 | int>(asmjit::CallConvId::kHost), |
146 | a->environment()); |
147 | |
148 | asmjit::FuncFrame frame; |
149 | frame.init(func); |
150 | |
151 | frame.setDirtyRegs( |
152 | asmjit::RegGroup::kVec, |
153 | asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | |
154 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | |
155 | asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | |
156 | asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); |
157 | frame.setDirtyRegs( |
158 | asmjit::RegGroup::kGp, |
159 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); |
160 | |
161 | asmjit::FuncArgsAssignment args(&func); |
162 | args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); |
163 | |
164 | args.updateFuncFrame(frame); |
165 | frame.finalize(); |
166 | |
167 | a->emitProlog(frame); |
168 | a->emitArgsAssignment(frame, args); |
169 | |
170 | asmjit::Label LoopMBlocks = a->newLabel(); |
171 | asmjit::Label LoopNBlocks = a->newLabel(); |
172 | asmjit::Label Loopk = a->newLabel(); |
173 | |
174 | x86::Gp buffer_B_saved = a->gpz(10); |
175 | x86::Gp C_Offset = a->gpz(11); |
176 | x86::Gp B_pf_saved = a->gpz(12); |
177 | x86::Gp iIdx = a->gpz(13); |
178 | x86::Gp jIdx = a->gpz(14); |
179 | x86::Gp kIdx = a->gpz(15); |
180 | // x86::Gp B_pf = a->gpz(8); |
181 | |
182 | x86::Zmm oneReg = x86::zmm29; |
183 | // create 16-bit 1s |
184 | // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 |
185 | // and so on |
186 | // a->vpcmpeqw(oneReg, oneReg, oneReg); |
187 | a->vpternlogd(oneReg, oneReg, oneReg, 0xff); |
188 | a->vpsrlw(oneReg, oneReg, 15); |
189 | a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); |
190 | |
191 | // save B_buffer address |
192 | a->mov(buffer_B_saved, buffer_B); |
193 | a->mov(B_pf_saved, B_pf); |
194 | |
195 | int currColRegs = nc * row_interleave / vectorLen; |
196 | int colRegs = std::min(currColRegs, maxNRegs); |
197 | if (mRegBlocks > 0) { |
198 | // move 0 to iteration variables |
199 | a->xor_(iIdx.r32(), iIdx.r32()); |
200 | |
201 | a->bind(LoopMBlocks); |
202 | a->inc(iIdx); |
203 | a->xor_(jIdx.r32(), jIdx.r32()); |
204 | |
205 | a->bind(LoopNBlocks); |
206 | a->inc(jIdx); |
207 | |
208 | int rowRegs = mRegBlockSize; |
209 | |
210 | // init C registers |
211 | initCRegs(a, rowRegs, colRegs); |
212 | |
213 | // init k loop index |
214 | a->xor_(kIdx.r32(), kIdx.r32()); |
215 | a->bind(Loopk); |
216 | |
217 | // k is incremented by row_interleave |
218 | a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); |
219 | |
220 | genComputeBlock<instSet>( |
221 | a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); |
222 | |
223 | // update buffer_A address for next k iteration |
224 | a->add( |
225 | buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); |
226 | |
227 | // update buffer_B address for next k iteration |
228 | a->add( |
229 | buffer_B, |
230 | static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); |
231 | a->add( |
232 | B_pf, |
233 | static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); |
234 | |
235 | // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); |
236 | |
237 | a->cmp(kIdx, kSize); |
238 | a->jl(Loopk); |
239 | |
240 | // store C matrix |
241 | storeCRegs<storeInstType>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); |
242 | |
243 | // reset A |
244 | a->sub(buffer_A, kSize); |
245 | |
246 | // B for next block |
247 | a->mov(buffer_B, buffer_B_saved); |
248 | // using C_Offset as temp reg |
249 | a->imul( |
250 | C_Offset, |
251 | jIdx, |
252 | static_cast<asmjit::Imm>( |
253 | nRegBlockSize * row_interleave * sizeof(int8_t))); |
254 | a->add(buffer_B, C_Offset); |
255 | a->mov(B_pf, B_pf_saved); |
256 | a->add(B_pf, C_Offset); |
257 | |
258 | // increment C for next B block |
259 | a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); |
260 | |
261 | int jLoopTrips = currColRegs / maxNRegs; |
262 | // jLoopTrips should be at least 1 |
263 | jLoopTrips = jLoopTrips ? jLoopTrips : 1; |
264 | a->cmp(jIdx, jLoopTrips); |
265 | a->jl(LoopNBlocks); |
266 | |
267 | // increment A for next block |
268 | a->add( |
269 | buffer_A, |
270 | static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); |
271 | |
272 | // increment C for next A block |
273 | a->sub( |
274 | CBase, |
275 | static_cast<asmjit::Imm>( |
276 | jLoopTrips * nRegBlockSize * sizeof(int32_t))); |
277 | a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); |
278 | a->add(CBase, C_Offset); |
279 | |
280 | // reset B |
281 | a->mov(buffer_B, buffer_B_saved); |
282 | a->mov(B_pf, B_pf_saved); |
283 | a->cmp(iIdx, mRegBlocks); |
284 | a->jl(LoopMBlocks); |
285 | } |
286 | // generate code for remainder |
287 | if (mRegBlocksRem > 0) { |
288 | asmjit::Label LoopNRem = a->newLabel(); |
289 | asmjit::Label LoopkRem = a->newLabel(); |
290 | int rowRegs = mRegBlocksRem; |
291 | |
292 | a->xor_(jIdx.r32(), jIdx.r32()); |
293 | a->bind(LoopNRem); |
294 | a->inc(jIdx); |
295 | |
296 | // init C registers |
297 | initCRegs(a, rowRegs, colRegs); |
298 | |
299 | // init k loop index |
300 | a->xor_(kIdx.r32(), kIdx.r32()); |
301 | a->bind(LoopkRem); |
302 | |
303 | // k is incremented by row_interleave |
304 | a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); |
305 | |
306 | genComputeBlock<instSet>( |
307 | a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); |
308 | |
309 | // update buffer_A address for next k iteration |
310 | a->add( |
311 | buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); |
312 | |
313 | // update buffer_B address for next k iteration |
314 | a->add( |
315 | buffer_B, |
316 | static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); |
317 | a->add( |
318 | B_pf, |
319 | static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); |
320 | |
321 | a->cmp(kIdx, kSize); |
322 | a->jl(LoopkRem); |
323 | |
324 | // reset A |
325 | a->sub(buffer_A, kSize); |
326 | // B for next block |
327 | // using C_Offset as temp reg |
328 | a->imul( |
329 | C_Offset, |
330 | jIdx, |
331 | static_cast<asmjit::Imm>( |
332 | nRegBlockSize * row_interleave * sizeof(int8_t))); |
333 | a->mov(buffer_B, buffer_B_saved); |
334 | a->add(buffer_B, C_Offset); |
335 | a->mov(B_pf, B_pf_saved); |
336 | a->add(B_pf, C_Offset); |
337 | |
338 | // store C matrix |
339 | storeCRegs<storeInstType>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); |
340 | |
341 | // increment C for next B block |
342 | a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); |
343 | |
344 | int jLoopTrips = currColRegs / maxNRegs; |
345 | // jLoopTrips should be at least 1 |
346 | jLoopTrips = jLoopTrips ? jLoopTrips : 1; |
347 | a->cmp(jIdx, jLoopTrips); |
348 | a->jl(LoopNRem); |
349 | } |
350 | |
351 | a->emitEpilog(frame); |
352 | |
353 | jit_micro_kernel_fp fn; |
354 | asmjit::Error err; |
355 | { |
356 | std::unique_lock<std::mutex> lock(rtMutex_); |
357 | err = runtime().add(&fn, &code); |
358 | } |
359 | if (err) { |
360 | std::cout << "Error: in fn add" << std::endl; |
361 | return nullptr; |
362 | } |
363 | |
364 | #if defined(FBGEMM_LOG_CODE) |
365 | fclose(codeLogfile); |
366 | delete codeLogger; |
367 | #endif |
368 | |
369 | return fn; |
370 | }); |
371 | } |
372 | |
373 | /** |
374 | * Instatiate the AVX512_VNNI instructions for 32-bit Accumulation macro-kernel. |
375 | * |
376 | */ |
377 | template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp |
378 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate< |
379 | inst_set_t::avx512_vnni>(bool accum, int32_t mc, int32_t nc, int32_t kc); |
380 | |
381 | /** |
382 | * Instatiate the AVX512_VNNI_256 instructions for 32-bit Accumulation |
383 | * macro-kernel. |
384 | * |
385 | */ |
386 | template CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp |
387 | CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate< |
388 | inst_set_t::avx512_vnni_ymm>( |
389 | bool accum, |
390 | int32_t mc, |
391 | int32_t nc, |
392 | int32_t kc); |
393 | |
394 | } // namespace fbgemm |
395 | |