1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmI64.h" |
9 | |
10 | #if defined(__x86_64__) || defined(__i386__) || \ |
11 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
12 | #include <immintrin.h> |
13 | #endif |
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>
17 | |
18 | #include "./GenerateKernel.h" |
19 | #include "./RefImplementations.h" |
20 | #include "fbgemm/PackingTraits-inl.h" |
21 | |
22 | using namespace std; |
23 | |
24 | namespace fbgemm { |
25 | |
26 | /** |
27 | * Generate AVX2 instructions for computing block in the rank-k update of 32-bit |
28 | * Accmulation kernel. |
29 | */ |
30 | template <> |
31 | template <inst_set_t instSet> |
32 | void CodeGenBase<int64_t, int64_t, int64_t, int64_t>::genComputeBlock( |
33 | x86::Emitter* a, |
34 | x86::Gp buffer_A, |
35 | x86::Gp buffer_B, |
36 | x86::Gp B_pf, |
37 | int rowRegs, |
38 | int colRegs, |
39 | int lda) { |
40 | using VecRegT = typename simd_info<instSet>::vec_reg_t; |
41 | constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64; |
42 | |
43 | // used for matrix B |
44 | VecRegT BReg(31); |
45 | |
46 | // temporary register |
47 | VecRegT res1(30); |
48 | |
49 | for (int j = 0; j < colRegs; ++j) { |
50 | // load B |
51 | a->vmovaps( |
52 | BReg, |
53 | x86::Mem( |
54 | buffer_B, |
55 | j * vectorLen * sizeof(int64_t), |
56 | simd_info<instSet>::WIDTH_BYTES)); |
57 | // load A, broadcast and fmas |
58 | for (int i = 0; i < rowRegs; ++i) { |
59 | a->vpmullq( |
60 | res1, |
61 | BReg, |
62 | x86::qword_ptr(buffer_A, (i * lda) * sizeof(int64_t))._1to8()); |
63 | a->vpaddq(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j)); |
64 | } |
65 | // TODO: need to tune |
66 | a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int64_t))); |
67 | } |
68 | } |
69 | |
70 | /** |
71 | * Generate AVX2 instructions for storing the C registers back to the memory in |
72 | * 32-bit Accumulation kernel. |
73 | */ |
74 | template <> |
75 | template <inst_set_t instSet> |
76 | void CodeGenBase<int64_t, int64_t, int64_t, int64_t>::storeCRegs( |
77 | x86::Emitter* a, |
78 | int rowRegs, |
79 | int colRegs, |
80 | x86::Gp C_Offset, |
81 | x86::Gp ldcReg, |
82 | bool accum) { |
83 | using VecT = typename simd_info<instSet>::vec_reg_t; |
84 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64; |
85 | |
86 | for (int i = 0; i < rowRegs; ++i) { |
87 | if (i != 0) { |
88 | a->add(C_Offset, ldcReg); |
89 | } else { |
90 | a->xor_(C_Offset.r32(), C_Offset.r32()); |
91 | } |
92 | for (int j = 0; j < colRegs; ++j) { |
93 | if (accum) { |
94 | a->vpaddq( |
95 | VecT(i * colRegs + j), |
96 | VecT(i * colRegs + j), |
97 | x86::dword_ptr( |
98 | a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int64_t))); |
99 | } |
100 | a->vmovups( |
101 | x86::dword_ptr( |
102 | a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int64_t)), |
103 | VecT(i * colRegs + j)); |
104 | } |
105 | } |
106 | } |
107 | |
108 | /** |
109 | * Get or Create the avx512 instructions for int64_t GEMM macro-kernel. |
110 | */ |
111 | template <> |
112 | template <inst_set_t instSet> |
113 | CodeGenBase<int64_t, int64_t, int64_t, int64_t>::jit_micro_kernel_fp |
114 | CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate( |
115 | bool accum, |
116 | int32_t mc, |
117 | int32_t nc, |
118 | int32_t /* unused */) { |
119 | static constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64; |
120 | |
121 | tuple<bool, int, int, int, int, int, int> kernelSig; |
122 | int kBlock; |
123 | int nBlock; |
124 | int mRegBlockSize; |
125 | int nRegBlockSize; |
126 | |
127 | if (blocking_params) { |
128 | kBlock = blocking_params->KCB; |
129 | nBlock = blocking_params->NCB; |
130 | mRegBlockSize = blocking_params->MR; |
131 | nRegBlockSize = blocking_params->NR; |
132 | } else { |
133 | kBlock = PackingTraits<int64_t, int64_t, instSet>::KCB; |
134 | nBlock = PackingTraits<int64_t, int64_t, instSet>::NCB; |
135 | mRegBlockSize = PackingTraits<int64_t, int64_t, instSet>::MR; |
136 | nRegBlockSize = PackingTraits<int64_t, int64_t, instSet>::NR; |
137 | } |
138 | |
139 | kernelSig = |
140 | make_tuple(accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); |
141 | |
142 | return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { |
143 | asmjit::CodeHolder code; |
144 | code.init(runtime().environment()); |
145 | x86::Assembler assembler(&code); |
146 | x86::Emitter* a = assembler.as<x86::Emitter>(); |
147 | #ifdef FBGEMM_LOG_CODE |
148 | // generated code logging |
149 | FILE* codeLogfile = fopen( |
150 | getCodeLoggingFile<instSet>( |
151 | accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) |
152 | .c_str(), |
153 | "w" ); |
154 | asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); |
155 | if (codeLogger) { |
156 | code.setLogger(codeLogger); |
157 | } |
158 | #endif |
159 | |
160 | const int maxMRegs = mRegBlockSize; |
161 | (void)maxMRegs; // Suppress unused variable warning |
162 | const int maxNRegs = nRegBlockSize / vectorLen; |
163 | assert( |
164 | maxMRegs * maxNRegs <= 30 && |
165 | "MR*(NR*64/512) \ |
166 | must be <= 29 (available registers constraint)" ); |
167 | |
168 | const int mRegBlocks = mc / mRegBlockSize; |
169 | const int mRegBlocksRem = mc % mRegBlockSize; |
170 | |
171 | // arguments to the function created |
172 | x86::Gp buffer_A = a->zdi(); |
173 | x86::Gp buffer_B = a->zsi(); |
174 | x86::Gp B_pf = a->zdx(); |
175 | x86::Gp CBase = a->zcx(); |
176 | x86::Gp kSize = a->gpz(8); |
177 | x86::Gp ldcReg = a->gpz(9); |
178 | |
179 | asmjit::FuncDetail func; |
180 | func.init( |
181 | asmjit::FuncSignatureT< |
182 | void, |
183 | int64_t*, |
184 | int64_t*, |
185 | int64_t*, |
186 | int64_t*, |
187 | int, |
188 | int>(asmjit::CallConvId::kHost), |
189 | a->environment()); |
190 | |
191 | asmjit::FuncFrame frame; |
192 | frame.init(func); |
193 | |
194 | frame.setDirtyRegs( |
195 | asmjit::RegGroup::kVec, |
196 | asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | |
197 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | |
198 | asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | |
199 | asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); |
200 | frame.setDirtyRegs( |
201 | asmjit::RegGroup::kGp, |
202 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); |
203 | |
204 | asmjit::FuncArgsAssignment args(&func); |
205 | args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); |
206 | |
207 | args.updateFuncFrame(frame); |
208 | frame.finalize(); |
209 | |
210 | a->emitProlog(frame); |
211 | a->emitArgsAssignment(frame, args); |
212 | |
213 | asmjit::Label LoopMBlocks = a->newLabel(); |
214 | asmjit::Label LoopNBlocks = a->newLabel(); |
215 | asmjit::Label Loopk = a->newLabel(); |
216 | |
217 | x86::Gp buffer_B_saved = a->gpz(10); |
218 | x86::Gp C_Offset = a->gpz(11); |
219 | x86::Gp B_pf_saved = a->gpz(12); |
220 | x86::Gp iIdx = a->gpz(13); |
221 | x86::Gp jIdx = a->gpz(14); |
222 | x86::Gp kIdx = a->gpz(15); |
223 | |
224 | a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int64_t))); |
225 | a->imul(kSize, kSize, static_cast<asmjit::Imm>(sizeof(int64_t))); |
226 | |
227 | // save B_buffer address |
228 | a->mov(buffer_B_saved, buffer_B); |
229 | a->mov(B_pf_saved, B_pf); |
230 | |
231 | int currColRegs = nc / vectorLen; |
232 | int colRegs = std::min(currColRegs, maxNRegs); |
233 | if (mRegBlocks > 0) { |
234 | // move 0 to iteration variables |
235 | a->xor_(iIdx.r32(), iIdx.r32()); |
236 | |
237 | a->bind(LoopMBlocks); |
238 | a->inc(iIdx); |
239 | a->xor_(jIdx.r32(), jIdx.r32()); |
240 | |
241 | a->bind(LoopNBlocks); |
242 | a->inc(jIdx); |
243 | |
244 | int rowRegs = mRegBlockSize; |
245 | |
246 | // init C registers |
247 | initCRegs(a, rowRegs, colRegs); |
248 | |
249 | // init k loop index |
250 | a->xor_(kIdx.r32(), kIdx.r32()); |
251 | a->bind(Loopk); |
252 | |
253 | // k is incremented by 1 |
254 | a->add(kIdx, static_cast<asmjit::Imm>(sizeof(int64_t))); |
255 | |
256 | genComputeBlock<instSet>( |
257 | a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); |
258 | |
259 | // update buffer_A address for next k iteration |
260 | a->add(buffer_A, static_cast<asmjit::Imm>(sizeof(int64_t))); |
261 | |
262 | // update buffer_B address for next k iteration |
263 | a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t))); |
264 | a->add(B_pf, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t))); |
265 | |
266 | a->cmp(kIdx, kSize); |
267 | a->jl(Loopk); |
268 | |
269 | // store C matrix |
270 | storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); |
271 | |
272 | // reset A |
273 | a->sub(buffer_A, kSize); |
274 | |
275 | // B for next block |
276 | a->mov(buffer_B, buffer_B_saved); |
277 | // using C_Offset as temp reg |
278 | a->imul( |
279 | C_Offset, |
280 | jIdx, |
281 | static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t))); |
282 | a->add(buffer_B, C_Offset); |
283 | a->mov(B_pf, B_pf_saved); |
284 | a->add(B_pf, C_Offset); |
285 | |
286 | // increment C for next B block |
287 | a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t))); |
288 | |
289 | int jLoopTrips = currColRegs / maxNRegs; |
290 | // jLoopTrips should be at least 1 |
291 | jLoopTrips = jLoopTrips ? jLoopTrips : 1; |
292 | a->cmp(jIdx, jLoopTrips); |
293 | a->jl(LoopNBlocks); |
294 | |
295 | // increment A for next block |
296 | a->add( |
297 | buffer_A, |
298 | static_cast<asmjit::Imm>(rowRegs * kBlock * sizeof(int64_t))); |
299 | |
300 | // increment C for next A block |
301 | a->sub( |
302 | CBase, |
303 | static_cast<asmjit::Imm>( |
304 | jLoopTrips * nRegBlockSize * sizeof(int64_t))); |
305 | a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); |
306 | a->add(CBase, C_Offset); |
307 | |
308 | // reset B |
309 | a->mov(buffer_B, buffer_B_saved); |
310 | a->mov(B_pf, B_pf_saved); |
311 | a->cmp(iIdx, mRegBlocks); |
312 | a->jl(LoopMBlocks); |
313 | } |
314 | // generate code for remainder |
315 | if (mRegBlocksRem > 0) { |
316 | assert(false); |
317 | asmjit::Label LoopNRem = a->newLabel(); |
318 | asmjit::Label LoopkRem = a->newLabel(); |
319 | int rowRegs = mRegBlocksRem; |
320 | |
321 | a->xor_(jIdx.r32(), jIdx.r32()); |
322 | a->bind(LoopNRem); |
323 | a->inc(jIdx); |
324 | |
325 | // init C registers |
326 | initCRegs(a, rowRegs, colRegs); |
327 | |
328 | // init k loop index |
329 | a->xor_(kIdx.r32(), kIdx.r32()); |
330 | a->bind(LoopkRem); |
331 | |
332 | // k is incremented by 1 |
333 | a->add(kIdx, static_cast<asmjit::Imm>(sizeof(int64_t))); |
334 | |
335 | genComputeBlock<instSet>( |
336 | a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); |
337 | |
338 | // update buffer_A address for next k iteration |
339 | a->add(buffer_A, static_cast<asmjit::Imm>(sizeof(int64_t))); |
340 | |
341 | // update buffer_B address for next k iteration |
342 | a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t))); |
343 | a->add(B_pf, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t))); |
344 | |
345 | a->cmp(kIdx, kSize); |
346 | a->jl(LoopkRem); |
347 | |
348 | // reset A |
349 | a->sub(buffer_A, kSize); |
350 | |
351 | // B for next block |
352 | // using C_Offset as temp reg |
353 | a->imul( |
354 | C_Offset, |
355 | jIdx, |
356 | static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t))); |
357 | a->mov(buffer_B, buffer_B_saved); |
358 | a->add(buffer_B, C_Offset); |
359 | a->mov(B_pf, B_pf_saved); |
360 | a->add(B_pf, C_Offset); |
361 | |
362 | // store C matrix |
363 | storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); |
364 | |
365 | // increment C for next B block |
366 | a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t))); |
367 | |
368 | int jLoopTrips = currColRegs / maxNRegs; |
369 | // jLoopTrips should be at least 1 |
370 | jLoopTrips = jLoopTrips ? jLoopTrips : 1; |
371 | a->cmp(jIdx, jLoopTrips); |
372 | a->jl(LoopNRem); |
373 | } |
374 | |
375 | a->emitEpilog(frame); |
376 | |
377 | jit_micro_kernel_fp fn; |
378 | asmjit::Error err; |
379 | { |
380 | unique_lock<mutex> lock(rtMutex_); |
381 | err = runtime().add(&fn, &code); |
382 | } |
383 | if (err) { |
384 | cout << "Error: in fn add" << endl; |
385 | return nullptr; |
386 | } |
387 | |
388 | #ifdef FBGEMM_LOG_CODE |
389 | fclose(codeLogfile); |
390 | delete codeLogger; |
391 | #endif |
392 | |
393 | return fn; |
394 | }); |
395 | } |
396 | |
/**
 * Explicitly instantiate the AVX-512 int64_t GEMM micro-kernel generator
 * (getOrCreate<inst_set_t::avx512>) used by cblas_gemm_i64_i64acc below.
 */
template CodeGenBase<int64_t, int64_t, int64_t, int64_t>::jit_micro_kernel_fp
CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate<
    inst_set_t::avx512>(bool accum, int32_t mc, int32_t nc, int32_t kc);
403 | |
404 | // Expected to have overflows |
405 | NO_SANITIZE("undefined" ) |
406 | void cblas_gemm_i64_i64acc( |
407 | matrix_op_t transa, |
408 | matrix_op_t transb, |
409 | int M, |
410 | int N, |
411 | int K, |
412 | const int64_t* A, |
413 | int lda, |
414 | const int64_t* B, |
415 | int ldb, |
416 | bool accumulate, |
417 | int64_t* C, |
418 | int ldc) { |
419 | cpuinfo_initialize(); |
420 | if (!fbgemmHasAvx512Support()) { |
421 | cblas_gemm_i64_i64acc_ref( |
422 | transa, transb, M, N, K, A, lda, B, ldb, accumulate, C, ldc); |
423 | return; |
424 | } |
425 | constexpr int MCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MCB; |
426 | constexpr int NCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NCB; |
427 | constexpr int KCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::KCB; |
428 | constexpr int MR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MR; |
429 | constexpr int NR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NR; |
430 | static_assert(MCB % MR == 0, "MR must divide MCB" ); |
431 | static_assert(NCB % NR == 0, "NR must divide NCB" ); |
432 | constexpr int VLEN = |
433 | simd_info<inst_set_t::avx512>::WIDTH_BYTES / sizeof(int64_t); |
434 | static_assert(NR % VLEN == 0, "VLEN must divide NR" ); |
435 | |
436 | using CodeGenType = CodeGenBase<int64_t, int64_t, int64_t, int64_t>; |
437 | CodeGenType codeObj; |
438 | CodeGenType::jit_micro_kernel_fp fn = |
439 | codeObj.getOrCreate<inst_set_t::avx512>(true /* accum */, MCB, NCB, KCB); |
440 | CodeGenType::jit_micro_kernel_fp fn_noacc; |
441 | if (!accumulate) { |
442 | fn_noacc = codeObj.getOrCreate<inst_set_t::avx512>( |
443 | false /* accum */, MCB, NCB, KCB); |
444 | } |
445 | |
446 | vector<int64_t> At, Bt; |
447 | // TODO: handle transpose during packing |
448 | if (transa == matrix_op_t::Transpose) { |
449 | At.resize(M * K); |
450 | for (int i = 0; i < M; ++i) { |
451 | for (int k = 0; k < K; ++k) { |
452 | At.at(i * K + k) = A[i + k * lda]; |
453 | } |
454 | } |
455 | A = At.data(); |
456 | lda = K; |
457 | } |
458 | if (transb == matrix_op_t::Transpose) { |
459 | Bt.resize(K * N); |
460 | for (int k = 0; k < K; ++k) { |
461 | for (int j = 0; j < N; ++j) { |
462 | Bt.at(k * N + j) = B[k + j * ldb]; |
463 | } |
464 | } |
465 | B = Bt.data(); |
466 | ldb = N; |
467 | } |
468 | |
469 | alignas(64) array<int64_t, MCB * KCB> packA; |
470 | alignas(64) array<int64_t, KCB * NCB> packB; |
471 | alignas(64) array<int64_t, MCB * NCB> packC; |
472 | |
473 | for (int ic = 0; ic < M; ic += MCB) { |
474 | for (int kc = 0; kc < K; kc += KCB) { |
475 | // pack A |
476 | for (int i = 0; i < std::min(MCB, M - ic); ++i) { |
477 | memcpy( |
478 | &packA[i * KCB], |
479 | A + (ic + i) * lda + kc, |
480 | std::min(K - kc, KCB) * sizeof(int64_t)); |
481 | } |
482 | |
483 | for (int jc = 0; jc < N; jc += NCB) { |
484 | // pack B |
485 | for (int i = 0; i < std::min(KCB, K - kc); ++i) { |
486 | memcpy( |
487 | &packB[i * NCB], |
488 | B + (kc + i) * ldb + jc, |
489 | std::min(NCB, N - jc) * sizeof(int64_t)); |
490 | } |
491 | |
492 | if (M - ic >= MCB && N - jc >= NCB) { |
493 | if (kc == 0 && !accumulate) { |
494 | fn_noacc( |
495 | packA.data(), |
496 | packB.data(), |
497 | packB.data(), |
498 | C + ic * ldc + jc, |
499 | std::min(KCB, K - kc), |
500 | ldc); |
501 | } else { |
502 | fn(packA.data(), |
503 | packB.data(), |
504 | packB.data(), |
505 | C + ic * ldc + jc, |
506 | std::min(KCB, K - kc), |
507 | ldc); |
508 | } |
509 | } else { |
510 | // remainder |
511 | if (kc == 0 && !accumulate) { |
512 | fn_noacc( |
513 | packA.data(), |
514 | packB.data(), |
515 | packB.data(), |
516 | packC.data(), |
517 | std::min(KCB, K - kc), |
518 | NCB); |
519 | } else { |
520 | for (int i = 0; i < std::min(MCB, M - ic); ++i) { |
521 | memcpy( |
522 | &packC[i * NCB], |
523 | C + (ic + i) * ldc + jc, |
524 | std::min(NCB, N - jc) * sizeof(int64_t)); |
525 | } |
526 | fn(packA.data(), |
527 | packB.data(), |
528 | packB.data(), |
529 | packC.data(), |
530 | std::min(KCB, K - kc), |
531 | NCB); |
532 | } |
533 | for (int i = 0; i < std::min(MCB, M - ic); ++i) { |
534 | memcpy( |
535 | C + (ic + i) * ldc + jc, |
536 | &packC[i * NCB], |
537 | std::min(NCB, N - jc) * sizeof(int64_t)); |
538 | } |
539 | } |
540 | } // jc |
541 | } // kc |
542 | } // ic |
543 | } |
544 | |
545 | } // namespace fbgemm |
546 | |