1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #include <iostream> |
8 | #include "./CodeGenHelpers.h" |
9 | #include "./GenerateKernel.h" |
10 | |
11 | namespace fbgemm { |
12 | |
13 | namespace x86 = asmjit::x86; |
14 | |
15 | /** |
16 | * Generate AVX512 instructions for computing block in the rank-k update of |
17 | * 16-bit Accmulation kernel. |
18 | */ |
19 | template <> |
20 | template <inst_set_t instSet> |
21 | void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock( |
22 | x86::Emitter* a, |
23 | x86::Gp buffer_A, |
24 | x86::Gp buffer_B, |
25 | x86::Gp /* unused (reserved for prefetching)*/, |
26 | int rowRegs, |
27 | int colRegs, |
28 | int lda) { |
29 | using VecRegT = typename simd_info<instSet>::vec_reg_t; |
30 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
31 | |
32 | // used for matrix A |
33 | VecRegT AReg(29); |
34 | |
35 | VecRegT tmpReg(30); |
36 | |
37 | // We start allocating BRegs from zmm28 and then allocate zmm27 and so on. |
38 | for (int j = 0; j < colRegs; ++j) { |
39 | a->vmovups( |
40 | VecRegT(28 - j), |
41 | x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); |
42 | } |
43 | |
44 | for (int i = 0; i < rowRegs; ++i) { |
45 | // broadcast A |
46 | a->vpbroadcastw( |
47 | AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); |
48 | for (int j = 0; j < colRegs; ++j) { |
49 | a->vpmaddubsw(tmpReg, AReg, VecRegT(28 - j)); |
50 | a->vpaddsw(VecRegT(i * colRegs + j), tmpReg, VecRegT(i * colRegs + j)); |
51 | // Prefetching is hurting performance in some cases |
52 | // because prefetch instructions itself consumes a slot |
53 | // in pipeline issue thus slowing down the kernel. |
54 | // if((i == rowRegs - 1) && j % 2 == 0){ |
55 | // a->prefetcht0(x86::dword_ptr(B_pf, j*VecLen*sizeof(int8_t))); |
56 | //} |
57 | } |
58 | } |
59 | } |
60 | |
61 | /** |
62 | * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel. |
63 | * |
64 | */ |
65 | template <> |
66 | template <inst_set_t instSet> |
67 | CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp |
68 | CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate( |
69 | bool accum, |
70 | int32_t mc, |
71 | int32_t nc, |
72 | int32_t kc) { |
73 | (void)kc; // Suppress unused variable warning |
74 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES; |
75 | |
76 | std::tuple<bool, int, int, int, int, int, int> kernelSig; |
77 | int kBlock; |
78 | int nBlock; |
79 | int mRegBlockSize; |
80 | int nRegBlockSize; |
81 | int nRegBlockSizeMin; |
82 | int row_interleave; |
83 | |
84 | if (blocking_params) { |
85 | kBlock = blocking_params->KCB; |
86 | nBlock = blocking_params->NCB; |
87 | mRegBlockSize = blocking_params->MR; |
88 | nRegBlockSize = blocking_params->NR; |
89 | nRegBlockSizeMin = blocking_params->NR_MIN; |
90 | row_interleave = blocking_params->ROW_INTERLEAVE; |
91 | } else { |
92 | kBlock = PackingTraits<uint8_t, int16_t, instSet>::KCB; |
93 | nBlock = PackingTraits<uint8_t, int16_t, instSet>::NCB; |
94 | mRegBlockSize = PackingTraits<uint8_t, int16_t, instSet>::MR; |
95 | nRegBlockSize = PackingTraits<uint8_t, int16_t, instSet>::NR; |
96 | nRegBlockSizeMin = PackingTraits<uint8_t, int16_t, instSet>::NR_MIN; |
97 | row_interleave = PackingTraits<uint8_t, int16_t, instSet>::ROW_INTERLEAVE; |
98 | } |
99 | (void)nRegBlockSizeMin; // Suppress unused variable warning |
100 | |
101 | kernelSig = std::make_tuple( |
102 | accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); |
103 | |
104 | return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { |
105 | asmjit::CodeHolder code; |
106 | code.init(runtime().environment()); |
107 | x86::Assembler assembler(&code); |
108 | x86::Emitter* a = assembler.as<x86::Emitter>(); |
109 | |
110 | #if defined(FBGEMM_LOG_CODE) |
111 | // generated code logging |
112 | FILE* codeLogfile = fopen( |
113 | getCodeLoggingFile<instSet>( |
114 | accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) |
115 | .c_str(), |
116 | "w" ); |
117 | asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); |
118 | if (codeLogger) { |
119 | code.setLogger(codeLogger); |
120 | } |
121 | #endif |
122 | |
123 | assert( |
124 | kc % row_interleave == 0 && "kc must be a multiple of row_interleave" ); |
125 | assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN" ); |
126 | const int maxMRegs = mRegBlockSize; |
127 | (void)maxMRegs; // Suppress unused variable warning |
128 | const int maxNRegs = nRegBlockSize * row_interleave / vectorLen; |
129 | assert( |
130 | (maxMRegs + 1) * maxNRegs <= 29 && |
131 | "number of zmm registers for C + one row for loading B: \ |
132 | MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \ |
133 | must be <= 29(available registers constraint)" ); |
134 | int mRegBlocks = mc / mRegBlockSize; |
135 | int mRegBlocksRem = mc % mRegBlockSize; |
136 | |
137 | // arguments to the function created |
138 | x86::Gp buffer_A = a->zdi(); |
139 | x86::Gp buffer_B = a->zsi(); |
140 | x86::Gp B_pf = a->zdx(); |
141 | x86::Gp CBase = a->zcx(); |
142 | x86::Gp kSize = a->gpz(8); |
143 | x86::Gp ldcReg = a->gpz(9); |
144 | |
145 | asmjit::FuncDetail func; |
146 | func.init( |
147 | asmjit::FuncSignatureT< |
148 | void, |
149 | uint8_t*, |
150 | int8_t*, |
151 | int8_t*, |
152 | int32_t*, |
153 | int, |
154 | int>(asmjit::CallConvId::kHost), |
155 | a->environment()); |
156 | |
157 | asmjit::FuncFrame frame; |
158 | frame.init(func); |
159 | |
160 | frame.setDirtyRegs( |
161 | asmjit::RegGroup::kVec, |
162 | asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | |
163 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | |
164 | asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | |
165 | asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); |
166 | frame.setDirtyRegs( |
167 | asmjit::RegGroup::kGp, |
168 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); |
169 | |
170 | asmjit::FuncArgsAssignment args(&func); |
171 | args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); |
172 | |
173 | args.updateFuncFrame(frame); |
174 | frame.finalize(); |
175 | |
176 | a->emitProlog(frame); |
177 | a->emitArgsAssignment(frame, args); |
178 | |
179 | asmjit::Label LoopMBlocks = a->newLabel(); |
180 | asmjit::Label LoopNBlocks = a->newLabel(); |
181 | asmjit::Label Loopk = a->newLabel(); |
182 | |
183 | x86::Gp buffer_B_saved = a->gpz(10); |
184 | x86::Gp C_Offset = a->gpz(11); |
185 | // x86::Gp B_pf_saved = a->gpz(12); |
186 | x86::Gp iIdx = a->gpz(13); |
187 | x86::Gp jIdx = a->gpz(14); |
188 | x86::Gp kIdx = a->gpz(15); |
189 | |
190 | // save B_buffer address |
191 | a->mov(buffer_B_saved, buffer_B); |
192 | // a->mov(B_pf_saved, B_pf); |
193 | |
194 | int currColRegs = nc * row_interleave / vectorLen; |
195 | int colRegs = std::min(currColRegs, maxNRegs); |
196 | if (mRegBlocks > 0) { |
197 | // move 0 to iteration variables |
198 | a->xor_(iIdx.r32(), iIdx.r32()); |
199 | |
200 | a->bind(LoopMBlocks); |
201 | a->inc(iIdx); |
202 | a->xor_(jIdx.r32(), jIdx.r32()); |
203 | |
204 | a->bind(LoopNBlocks); |
205 | a->inc(jIdx); |
206 | |
207 | int rowRegs = mRegBlockSize; |
208 | |
209 | // init C registers |
210 | initCRegs(a, rowRegs, colRegs); |
211 | |
212 | // init k loop index |
213 | a->xor_(kIdx.r32(), kIdx.r32()); |
214 | a->bind(Loopk); |
215 | // k is incremented by row_interleave |
216 | a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); |
217 | |
218 | genComputeBlock<instSet>( |
219 | a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); |
220 | |
221 | // update buffer_A address for next k iteration |
222 | a->add( |
223 | buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); |
224 | |
225 | // update buffer_B address for next k iteration |
226 | a->add( |
227 | buffer_B, |
228 | static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); |
229 | // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * |
230 | // sizeof(int8_t))); |
231 | |
232 | a->cmp(kIdx, kSize); |
233 | a->jl(Loopk); |
234 | |
235 | // store C matrix |
236 | storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); |
237 | |
238 | // reset A |
239 | a->sub(buffer_A, kSize); |
240 | |
241 | // B for next block |
242 | a->mov(buffer_B, buffer_B_saved); |
243 | // using C_Offset as temp reg |
244 | a->imul( |
245 | C_Offset, |
246 | jIdx, |
247 | static_cast<asmjit::Imm>( |
248 | nRegBlockSize * row_interleave * sizeof(int8_t))); |
249 | a->add(buffer_B, C_Offset); |
250 | |
251 | // increment C for next block |
252 | a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); |
253 | |
254 | int jLoopTrips = currColRegs / maxNRegs; |
255 | // jLoopTrips should be at least 1 |
256 | jLoopTrips = jLoopTrips ? jLoopTrips : 1; |
257 | a->cmp(jIdx, jLoopTrips); |
258 | a->jl(LoopNBlocks); |
259 | |
260 | // increment A for next block |
261 | a->add( |
262 | buffer_A, |
263 | static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); |
264 | |
265 | // increment C for next A block |
266 | a->sub( |
267 | CBase, |
268 | static_cast<asmjit::Imm>( |
269 | jLoopTrips * nRegBlockSize * sizeof(int32_t))); |
270 | a->imul( |
271 | C_Offset, |
272 | ldcReg, |
273 | static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); |
274 | a->add(CBase, C_Offset); |
275 | |
276 | // reset B |
277 | a->mov(buffer_B, buffer_B_saved); |
278 | // a->mov(B_pf, B_pf_saved); |
279 | |
280 | a->cmp(iIdx, mRegBlocks); |
281 | a->jl(LoopMBlocks); |
282 | } |
283 | // generate code for remainder |
284 | if (mRegBlocksRem > 0) { |
285 | asmjit::Label LoopNRem = a->newLabel(); |
286 | asmjit::Label LoopkRem = a->newLabel(); |
287 | int rowRegs = mRegBlocksRem; |
288 | |
289 | a->xor_(jIdx.r32(), jIdx.r32()); |
290 | a->bind(LoopNRem); |
291 | a->inc(jIdx); |
292 | |
293 | // init C registers |
294 | initCRegs(a, rowRegs, colRegs); |
295 | |
296 | // init k loop index |
297 | a->xor_(kIdx.r32(), kIdx.r32()); |
298 | a->bind(LoopkRem); |
299 | |
300 | // k is incremented by row_interleave |
301 | a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); |
302 | |
303 | genComputeBlock<instSet>( |
304 | a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); |
305 | |
306 | // update buffer_A address for next k iteration |
307 | a->add( |
308 | buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); |
309 | |
310 | // update buffer_B address for next k iteration |
311 | a->add( |
312 | buffer_B, |
313 | static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); |
314 | // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * |
315 | // sizeof(int8_t))); |
316 | |
317 | a->cmp(kIdx, kSize); |
318 | a->jl(LoopkRem); |
319 | |
320 | // reset A |
321 | a->sub(buffer_A, kSize); |
322 | |
323 | // B for next block |
324 | a->mov(buffer_B, buffer_B_saved); |
325 | // using C_Offset as temp reg |
326 | a->imul( |
327 | C_Offset, |
328 | jIdx, |
329 | static_cast<asmjit::Imm>( |
330 | nRegBlockSize * row_interleave * sizeof(int8_t))); |
331 | a->add(buffer_B, C_Offset); |
332 | |
333 | // store C matrix |
334 | storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); |
335 | |
336 | // increment C for next block |
337 | a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); |
338 | |
339 | int jLoopTrips = currColRegs / maxNRegs; |
340 | // jLoopTrips should be at least 1 |
341 | jLoopTrips = jLoopTrips ? jLoopTrips : 1; |
342 | a->cmp(jIdx, jLoopTrips); |
343 | a->jl(LoopNRem); |
344 | } |
345 | |
346 | a->emitEpilog(frame); |
347 | |
348 | jit_micro_kernel_fp fn; |
349 | asmjit::Error err; |
350 | { |
351 | std::unique_lock<std::mutex> lock(rtMutex_); |
352 | err = runtime().add(&fn, &code); |
353 | } |
354 | if (err) { |
355 | std::cout << "Error: in fn add" << std::endl; |
356 | return nullptr; |
357 | } |
358 | |
359 | #if defined(FBGEMM_LOG_CODE) |
360 | fclose(codeLogfile); |
361 | delete codeLogger; |
362 | #endif |
363 | |
364 | return fn; |
365 | }); |
366 | } |
367 | |
368 | /** |
369 | * Instatiate the AVX512 instructions for 16-bit Accumulation macro-kernel. |
370 | * |
371 | */ |
372 | template CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp |
373 | CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( |
374 | bool accum, |
375 | int32_t mc, |
376 | int32_t nc, |
377 | int32_t kc); |
378 | |
379 | /** |
380 | * Instatiate the AVX512_256 instructions for 16-bit Accumulation macro-kernel. |
381 | * |
382 | */ |
383 | template CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp |
384 | CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate< |
385 | inst_set_t::avx512_ymm>(bool accum, int32_t mc, int32_t nc, int32_t kc); |
386 | |
387 | } // namespace fbgemm |
388 | |