1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #include <iostream> |
8 | #include "./CodeGenHelpers.h" |
9 | #include "./GenerateKernel.h" |
10 | |
11 | namespace fbgemm { |
12 | |
13 | namespace x86 = asmjit::x86; |
14 | |
15 | /** |
16 | * Generate AVX2 instructions for computing block in the rank-k update of 16-bit |
17 | * Accmulation kernel. |
18 | */ |
19 | template <> |
20 | template <> |
21 | void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< |
22 | inst_set_t::avx2>( |
23 | x86::Emitter* a, |
24 | x86::Gp buffer_A, |
25 | x86::Gp buffer_B, |
26 | x86::Gp /* unused (reserved for prefetching)*/, |
27 | int rowRegs, |
28 | int colRegs, |
29 | int lda) { |
30 | using CRegs = x86::Ymm; |
31 | static constexpr int vectorLen = simd_info<inst_set_t::avx2>::WIDTH_BYTES; |
32 | |
33 | // used for matrix A |
34 | x86::Ymm AReg = x86::ymm13; |
35 | x86::Ymm tmpReg = x86::ymm14; |
36 | for (int i = 0; i < rowRegs; ++i) { |
37 | // broadcast A |
38 | a->vpbroadcastw( |
39 | AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); |
40 | for (int j = 0; j < colRegs; ++j) { |
41 | a->vpmaddubsw( |
42 | tmpReg, |
43 | AReg, |
44 | x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); |
45 | a->vpaddsw(CRegs(i * colRegs + j), tmpReg, CRegs(i * colRegs + j)); |
46 | // Prefetching is hurting performance in some cases |
47 | // because prefetch instructions itself consumes a slot |
48 | // in pipeline issue thus slowing down the kernel. |
49 | // if((i == rowRegs - 1) && j % 2 == 0){ |
50 | // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t))); |
51 | //} |
52 | } |
53 | } |
54 | } |
55 | |
56 | /** |
57 | * Generate instructions for storing the C registers back to the memory |
58 | * in 16-bit Accumulation kernel. |
59 | */ |
60 | template <> |
61 | template <inst_set_t instSet> |
62 | void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs( |
63 | x86::Emitter* a, |
64 | int rowRegs, |
65 | int colRegs, |
66 | x86::Gp C_Offset, |
67 | x86::Gp ldcReg, |
68 | bool accum) { |
69 | using VecT = typename simd_info<instSet>::vec_reg_t; |
70 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
71 | |
72 | VecT (simd_info<instSet>::NUM_VEC_REGS - 1); |
73 | auto = extractDestFull.half(); |
74 | |
75 | for (int i = 0; i < rowRegs; ++i) { |
76 | a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); |
77 | for (int j = 0; j < colRegs; ++j) { |
78 | for (int idx = 0; idx < 2; ++idx) { |
79 | emitExtractHalfVector<instSet, VecT>( |
80 | a, extractDestHalf, VecT(i * colRegs + j), idx); |
81 | a->vpmovsxwd(extractDestFull, extractDestHalf); |
82 | x86::Mem destAddr = |
83 | x86::dword_ptr(a->zcx(), C_Offset, 0, (j * 2 + idx) * vectorLen); |
84 | if (accum) { |
85 | a->vpaddd(extractDestFull, extractDestFull, destAddr); |
86 | } |
87 | a->vmovups(destAddr, extractDestFull); |
88 | } |
89 | } |
90 | } |
91 | } |
92 | |
93 | /** |
94 | * Get or Create the AVX2 instructions for 16-bit Accumulation macro-kernel. |
95 | * |
96 | */ |
template <>
template <>
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
    bool accum,
    int32_t mc,
    int32_t nc,
    int32_t kc) {
  (void)kc; // Suppress unused variable warning
  constexpr int vectorLen = simd_info<inst_set_t::avx2>::WIDTH_BYTES;

  // Blocking parameters: taken from the caller-provided blocking_params when
  // set, otherwise from the compile-time defaults in PackingTraits for the
  // (uint8 x int8 -> int16 accumulation) AVX2 path.
  std::tuple<bool, int, int, int, int, int, int> kernelSig;
  int kBlock;
  int nBlock;
  int mRegBlockSize;
  int nRegBlockSize;
  int nRegBlockSizeMin;
  int row_interleave;

  if (blocking_params) {
    kBlock = blocking_params->KCB;
    nBlock = blocking_params->NCB;
    mRegBlockSize = blocking_params->MR;
    nRegBlockSize = blocking_params->NR;
    nRegBlockSizeMin = blocking_params->NR_MIN;
    row_interleave = blocking_params->ROW_INTERLEAVE;
  } else {
    kBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::KCB;
    nBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NCB;
    mRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::MR;
    nRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NR;
    nRegBlockSizeMin =
        PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NR_MIN;
    row_interleave =
        PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE;
  }
  (void)nRegBlockSizeMin; // Suppress unused variable warning

  // The tuple uniquely identifies a generated kernel in the code cache; the
  // cache returns the existing function pointer or invokes the lambda below
  // to JIT a new one.
  kernelSig = std::make_tuple(
      accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();

#if defined(FBGEMM_LOG_CODE)
    // generated code logging
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<inst_set_t::avx2>(
            accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    assert(
        kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
    assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
    const int maxMRegs = mRegBlockSize;
    const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
    (void)maxMRegs; // Suppress unused variable warning
    (void)maxNRegs; // Suppress unused variable warning
    // ymm13/ymm14 are scratch in genComputeBlock and the top register is used
    // by storeCRegs, so at most 13 ymm registers can hold the C tile.
    assert(
        maxMRegs * maxNRegs <= 13 &&
        "MR*(NR*ROW_INTERLEAVE*8/256"
        "must be <= 13(available registers constraint)");

    // Number of full MR-row blocks, plus a possible remainder block.
    int mRegBlocks = mc / mRegBlockSize;
    int mRegBlocksRem = mc % mRegBlockSize;

    // assert((nc == nRegBlockSize) &&
    //"nc must be equal to the number of register blocks");

    // arguments to the function created
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp B_pf = a->zdx();
    x86::Gp CBase = a->zcx();
    x86::Gp kSize = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);

    // Signature of the generated function:
    // void(uint8_t* A, int8_t* B, int8_t* B_pf, int32_t* C, int k, int ldc)
    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            uint8_t*,
            int8_t*,
            int8_t*,
            int32_t*,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    // Declare every register the kernel clobbers so the prolog/epilog
    // save/restore them as required by the calling convention.
    asmjit::FuncFrame frame;
    frame.init(func);
    frame.setDirtyRegs(
        asmjit::RegGroup::kVec,
        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label Loopk = a->newLabel();
    asmjit::Label LoopMBlocks = a->newLabel();

    // Caller-saved scratch GPs used for loop bookkeeping.
    x86::Gp buffer_B_saved = a->gpz(10);
    x86::Gp C_Offset = a->gpz(11);
    // x86::Gp B_pf_saved = a->gpz(12);
    x86::Gp iIdx = a->gpz(13);
    x86::Gp kIdx = a->gpz(14);

    int colRegs = nc * row_interleave / vectorLen;
    // Main loop over full MR-row blocks of the M dimension.
    if (mRegBlocks > 0) {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      // save B_buffer address
      a->mov(buffer_B_saved, buffer_B);
      // a->mov(B_pf_saved, B_pf);

      a->bind(LoopMBlocks);
      a->inc(iIdx);

      int rowRegs = mRegBlockSize;

      // init C registers
      initCRegs(a, rowRegs, colRegs);

      // init k loop index
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(Loopk);
      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      // Emit one rank-k-update step for this register tile.
      genComputeBlock<inst_set_t::avx2>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(
          buffer_B,
          static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
      // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
      // sizeof(int8_t)));

      a->cmp(kIdx, kSize);
      a->jl(Loopk);

      // store C matrix
      storeCRegs<inst_set_t::avx2>(
          a, rowRegs, colRegs, C_Offset, ldcReg, accum);

      // increment A for next block: rewind the k advance, then skip a full
      // MR x KCB panel.
      a->sub(buffer_A, kSize);
      a->add(
          buffer_A,
          static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
      // increment C for next block
      a->imul(
          C_Offset,
          ldcReg,
          static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
      a->add(CBase, C_Offset);
      // reset B
      a->mov(buffer_B, buffer_B_saved);
      // a->mov(B_pf, B_pf_saved);

      a->cmp(iIdx, mRegBlocks);
      a->jl(LoopMBlocks);
    }
    // generate code for remainder: same k loop with fewer row registers; no
    // pointer advance afterwards since this is the last block.
    if (mRegBlocksRem > 0) {
      asmjit::Label LoopkRem = a->newLabel();
      int rowRegs = mRegBlocksRem;

      // init C registers
      initCRegs(a, rowRegs, colRegs);

      // init k loop index
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopkRem);

      // k is incremented by row_interleave
      a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));

      genComputeBlock<inst_set_t::avx2>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(
          buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));

      // update buffer_B address for next k iteration
      a->add(
          buffer_B,
          static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
      // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
      // sizeof(int8_t)));

      a->cmp(kIdx, kSize);
      a->jl(LoopkRem);

      // store C matrix
      storeCRegs<inst_set_t::avx2>(
          a, rowRegs, colRegs, C_Offset, ldcReg, accum);
    }

    a->emitEpilog(frame);

    // Finalize: hand the assembled code to the JIT runtime (guarded by a
    // mutex since the runtime is shared across threads).
    jit_micro_kernel_fp fn;
    asmjit::Error err;
    {
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

#if defined(FBGEMM_LOG_CODE)
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}
343 | |
344 | /** |
345 | * Instantiate the inst_set_t::avx2 instructions for store kernel. |
346 | * |
347 | */ |
348 | template void |
349 | CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<inst_set_t::avx2>( |
350 | x86::Emitter* a, |
351 | int rowRegs, |
352 | int colRegs, |
353 | x86::Gp C_Offset, |
354 | x86::Gp ldcReg, |
355 | bool accum); |
356 | |
357 | /** |
358 | * Instantiate the inst_set_t::avx512 instructions for store kernel. |
359 | * |
360 | */ |
361 | template void |
362 | CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<inst_set_t::avx512>( |
363 | x86::Emitter* a, |
364 | int rowRegs, |
365 | int colRegs, |
366 | x86::Gp C_Offset, |
367 | x86::Gp ldcReg, |
368 | bool accum); |
369 | |
370 | /** |
371 | * Instantiate the inst_set_t::avx512_ymm instructions for store kernel. |
372 | * |
373 | */ |
374 | template void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< |
375 | inst_set_t::avx512_ymm>( |
376 | x86::Emitter* a, |
377 | int rowRegs, |
378 | int colRegs, |
379 | x86::Gp C_Offset, |
380 | x86::Gp ldcReg, |
381 | bool accum); |
382 | |
383 | } // namespace fbgemm |
384 | |