/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmI64.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <cmath>
#include <iostream>
#include <vector>

#include "./GenerateKernel.h"
#include "./RefImplementations.h"
#include "fbgemm/PackingTraits-inl.h"

using namespace std;

namespace fbgemm {

/**
 * Generate AVX512 instructions for computing a block in the rank-k update of
 * the 64-bit accumulation kernel.
 */
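// A scalar sketch of what one generated compute block does (illustrative
// only; VLEN denotes the number of int64_t lanes per vector register):
//
//   for (int j = 0; j < colRegs; ++j)
//     for (int i = 0; i < rowRegs; ++i)
//       for (int v = 0; v < VLEN; ++v)
//         CReg[i][j * VLEN + v] += A[i * lda] * B[j * VLEN + v];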
template <>
template <inst_set_t instSet>
void CodeGenBase<int64_t, int64_t, int64_t, int64_t>::genComputeBlock(
    x86::Emitter* a,
    x86::Gp buffer_A,
    x86::Gp buffer_B,
    x86::Gp B_pf,
    int rowRegs,
    int colRegs,
    int lda) {
  using VecRegT = typename simd_info<instSet>::vec_reg_t;
  constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64;

  // used for matrix B
  VecRegT BReg(31);

  // temporary register
  VecRegT res1(30);

  for (int j = 0; j < colRegs; ++j) {
    // load B
    a->vmovaps(
        BReg,
        x86::Mem(
            buffer_B,
            j * vectorLen * sizeof(int64_t),
            simd_info<instSet>::WIDTH_BYTES));
    // load A, broadcast and fmas
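    // Note: vpmullq keeps only the low 64 bits of each 64x64-bit product, so
    // multiplication wraps on overflow (hence the NO_SANITIZE("undefined")
    // annotation on cblas_gemm_i64_i64acc below).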
    for (int i = 0; i < rowRegs; ++i) {
      a->vpmullq(
          res1,
          BReg,
          x86::qword_ptr(buffer_A, (i * lda) * sizeof(int64_t))._1to8());
      a->vpaddq(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j));
    }
    // TODO: need to tune
    a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int64_t)));
  }
}

/**
 * Generate AVX512 instructions for storing the C registers back to memory in
 * the 64-bit accumulation kernel.
 */
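// A scalar sketch of the generated store (illustrative only; CReg[i][j]
// stands for the j-th accumulated lane of row i):
//
//   for (int i = 0; i < rowRegs; ++i)
//     for (int j = 0; j < colRegs * VLEN; ++j)
//       C[i * ldc + j] = (accum ? C[i * ldc + j] : 0) + CReg[i][j];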
template <>
template <inst_set_t instSet>
void CodeGenBase<int64_t, int64_t, int64_t, int64_t>::storeCRegs(
    x86::Emitter* a,
    int rowRegs,
    int colRegs,
    x86::Gp C_Offset,
    x86::Gp ldcReg,
    bool accum) {
  using VecT = typename simd_info<instSet>::vec_reg_t;
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64;

  for (int i = 0; i < rowRegs; ++i) {
    if (i != 0) {
      a->add(C_Offset, ldcReg);
    } else {
      a->xor_(C_Offset.r32(), C_Offset.r32());
    }
    for (int j = 0; j < colRegs; ++j) {
      if (accum) {
        a->vpaddq(
            VecT(i * colRegs + j),
            VecT(i * colRegs + j),
            x86::dword_ptr(
                a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int64_t)));
      }
      a->vmovups(
          x86::dword_ptr(
              a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int64_t)),
          VecT(i * colRegs + j));
    }
  }
}

/**
 * Get or create the AVX512 instructions for the int64_t GEMM macro-kernel.
 */
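// A sketch of the generated function's calling convention (mirroring the
// FuncSignatureT setup below); kSize and ldc are passed as element counts
// and converted to byte strides inside the kernel:
//
//   void fn(
//       int64_t* buffer_A, // packed block of A
//       int64_t* buffer_B, // packed block of B
//       int64_t* B_pf,     // prefetch pointer into B
//       int64_t* C,        // C tile base address
//       int kSize,         // reduction depth in elements
//       int ldc);          // leading dimension of C in elements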
template <>
template <inst_set_t instSet>
CodeGenBase<int64_t, int64_t, int64_t, int64_t>::jit_micro_kernel_fp
CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate(
    bool accum,
    int32_t mc,
    int32_t nc,
    int32_t /* unused */) {
  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BITS / 64;

  tuple<bool, int, int, int, int, int, int> kernelSig;
  int kBlock;
  int nBlock;
  int mRegBlockSize;
  int nRegBlockSize;

  if (blocking_params) {
    kBlock = blocking_params->KCB;
    nBlock = blocking_params->NCB;
    mRegBlockSize = blocking_params->MR;
    nRegBlockSize = blocking_params->NR;
  } else {
    kBlock = PackingTraits<int64_t, int64_t, instSet>::KCB;
    nBlock = PackingTraits<int64_t, int64_t, instSet>::NCB;
    mRegBlockSize = PackingTraits<int64_t, int64_t, instSet>::MR;
    nRegBlockSize = PackingTraits<int64_t, int64_t, instSet>::NR;
  }

  kernelSig =
      make_tuple(accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* a = assembler.as<x86::Emitter>();
#ifdef FBGEMM_LOG_CODE
    // generated code logging
    FILE* codeLogfile = fopen(
        getCodeLoggingFile<instSet>(
            accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize)
            .c_str(),
        "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
    if (codeLogger) {
      code.setLogger(codeLogger);
    }
#endif

    const int maxMRegs = mRegBlockSize;
    (void)maxMRegs; // Suppress unused variable warning
    const int maxNRegs = nRegBlockSize / vectorLen;
    assert(
        maxMRegs * maxNRegs <= 30 &&
        "MR * (NR / vectorLen) must be <= 30 "
        "(available register constraint: zmm30 and zmm31 are reserved as "
        "temporaries)");

    const int mRegBlocks = mc / mRegBlockSize;
    const int mRegBlocksRem = mc % mRegBlockSize;

    // arguments to the function created
    x86::Gp buffer_A = a->zdi();
    x86::Gp buffer_B = a->zsi();
    x86::Gp B_pf = a->zdx();
    x86::Gp CBase = a->zcx();
    x86::Gp kSize = a->gpz(8);
    x86::Gp ldcReg = a->gpz(9);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            int64_t*,
            int64_t*,
            int64_t*,
            int64_t*,
            int,
            int>(asmjit::CallConvId::kHost),
        a->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    frame.setDirtyRegs(
        asmjit::RegGroup::kVec,
        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
            asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
            asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);

    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);
    a->emitArgsAssignment(frame, args);

    asmjit::Label LoopMBlocks = a->newLabel();
    asmjit::Label LoopNBlocks = a->newLabel();
    asmjit::Label Loopk = a->newLabel();

    x86::Gp buffer_B_saved = a->gpz(10);
    x86::Gp C_Offset = a->gpz(11);
    x86::Gp B_pf_saved = a->gpz(12);
    x86::Gp iIdx = a->gpz(13);
    x86::Gp jIdx = a->gpz(14);
    x86::Gp kIdx = a->gpz(15);

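    // convert ldc and the k-loop bound from element counts to byte offsets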
    a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int64_t)));
    a->imul(kSize, kSize, static_cast<asmjit::Imm>(sizeof(int64_t)));

    // save B_buffer address
    a->mov(buffer_B_saved, buffer_B);
    a->mov(B_pf_saved, B_pf);

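    // vector registers needed to cover nc columns, capped at maxNRegs per tile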
    int currColRegs = nc / vectorLen;
    int colRegs = std::min(currColRegs, maxNRegs);
    if (mRegBlocks > 0) {
      // move 0 to iteration variables
      a->xor_(iIdx.r32(), iIdx.r32());

      a->bind(LoopMBlocks);
      a->inc(iIdx);
      a->xor_(jIdx.r32(), jIdx.r32());

      a->bind(LoopNBlocks);
      a->inc(jIdx);

      int rowRegs = mRegBlockSize;

      // init C registers
      initCRegs(a, rowRegs, colRegs);

      // init k loop index
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(Loopk);

      // advance the k index by one element (kIdx counts bytes)
      a->add(kIdx, static_cast<asmjit::Imm>(sizeof(int64_t)));

      genComputeBlock<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(buffer_A, static_cast<asmjit::Imm>(sizeof(int64_t)));

      // update buffer_B address for next k iteration
      a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t)));
      a->add(B_pf, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t)));

      a->cmp(kIdx, kSize);
      a->jl(Loopk);

      // store C matrix
      storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);

      // reset A
      a->sub(buffer_A, kSize);

      // B for next block
      a->mov(buffer_B, buffer_B_saved);
      // using C_Offset as temp reg
      a->imul(
          C_Offset,
          jIdx,
          static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t)));
      a->add(buffer_B, C_Offset);
      a->mov(B_pf, B_pf_saved);
      a->add(B_pf, C_Offset);

      // increment C for next B block
      a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t)));

      int jLoopTrips = currColRegs / maxNRegs;
      // jLoopTrips should be at least 1
      jLoopTrips = jLoopTrips ? jLoopTrips : 1;
      a->cmp(jIdx, jLoopTrips);
      a->jl(LoopNBlocks);

      // increment A for next block
      a->add(
          buffer_A,
          static_cast<asmjit::Imm>(rowRegs * kBlock * sizeof(int64_t)));

      // increment C for next A block
      a->sub(
          CBase,
          static_cast<asmjit::Imm>(
              jLoopTrips * nRegBlockSize * sizeof(int64_t)));
      a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
      a->add(CBase, C_Offset);

      // reset B
      a->mov(buffer_B, buffer_B_saved);
      a->mov(B_pf, B_pf_saved);
      a->cmp(iIdx, mRegBlocks);
      a->jl(LoopMBlocks);
    }
    // generate code for the remainder (mc not a multiple of mRegBlockSize)
    if (mRegBlocksRem > 0) {
      // This path is never exercised by cblas_gemm_i64_i64acc, which always
      // calls with mc == MCB and MCB % MR == 0 (enforced by a static_assert
      // below), so it is left unsupported.
      assert(false);
      asmjit::Label LoopNRem = a->newLabel();
      asmjit::Label LoopkRem = a->newLabel();
      int rowRegs = mRegBlocksRem;

      a->xor_(jIdx.r32(), jIdx.r32());
      a->bind(LoopNRem);
      a->inc(jIdx);

      // init C registers
      initCRegs(a, rowRegs, colRegs);

      // init k loop index
      a->xor_(kIdx.r32(), kIdx.r32());
      a->bind(LoopkRem);

      // advance the k index by one element (kIdx counts bytes)
      a->add(kIdx, static_cast<asmjit::Imm>(sizeof(int64_t)));

      genComputeBlock<instSet>(
          a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);

      // update buffer_A address for next k iteration
      a->add(buffer_A, static_cast<asmjit::Imm>(sizeof(int64_t)));

      // update buffer_B address for next k iteration
      a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t)));
      a->add(B_pf, static_cast<asmjit::Imm>(nBlock * sizeof(int64_t)));

      a->cmp(kIdx, kSize);
      a->jl(LoopkRem);

      // reset A
      a->sub(buffer_A, kSize);

      // B for next block
      // using C_Offset as temp reg
      a->imul(
          C_Offset,
          jIdx,
          static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t)));
      a->mov(buffer_B, buffer_B_saved);
      a->add(buffer_B, C_Offset);
      a->mov(B_pf, B_pf_saved);
      a->add(B_pf, C_Offset);

      // store C matrix
      storeCRegs<instSet>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);

      // increment C for next B block
      a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int64_t)));

      int jLoopTrips = currColRegs / maxNRegs;
      // jLoopTrips should be at least 1
      jLoopTrips = jLoopTrips ? jLoopTrips : 1;
      a->cmp(jIdx, jLoopTrips);
      a->jl(LoopNRem);
    }

    a->emitEpilog(frame);

    jit_micro_kernel_fp fn;
    asmjit::Error err;
    {
      unique_lock<mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      cout << "Error: in fn add" << endl;
      return nullptr;
    }

#ifdef FBGEMM_LOG_CODE
    fclose(codeLogfile);
    delete codeLogger;
#endif

    return fn;
  });
}

/**
 * Instantiate the AVX512 instructions for the int64_t GEMM macro-kernel.
 */
template CodeGenBase<int64_t, int64_t, int64_t, int64_t>::jit_micro_kernel_fp
CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate<
    inst_set_t::avx512>(bool accum, int32_t mc, int32_t nc, int32_t kc);

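// A minimal usage sketch (illustrative; assumes row-major A, B, and C, with
// M, N, K chosen by the caller):
//
//   std::vector<int64_t> A(M * K), B(K * N), C(M * N);
//   cblas_gemm_i64_i64acc(
//       matrix_op_t::NoTranspose,
//       matrix_op_t::NoTranspose,
//       M, N, K,
//       A.data(), K, // lda = K for non-transposed, row-major A
//       B.data(), N, // ldb = N for non-transposed, row-major B
//       false /* accumulate */, // overwrite C rather than add into it
//       C.data(), N); // ldc = N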
// Products are expected to overflow and wrap around (mod 2^64), hence the
// UBSan suppression.
NO_SANITIZE("undefined")
void cblas_gemm_i64_i64acc(
    matrix_op_t transa,
    matrix_op_t transb,
    int M,
    int N,
    int K,
    const int64_t* A,
    int lda,
    const int64_t* B,
    int ldb,
    bool accumulate,
    int64_t* C,
    int ldc) {
  cpuinfo_initialize();
  if (!fbgemmHasAvx512Support()) {
    cblas_gemm_i64_i64acc_ref(
        transa, transb, M, N, K, A, lda, B, ldb, accumulate, C, ldc);
    return;
  }
  constexpr int MCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MCB;
  constexpr int NCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NCB;
  constexpr int KCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::KCB;
  constexpr int MR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MR;
  constexpr int NR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NR;
  static_assert(MCB % MR == 0, "MR must divide MCB");
  static_assert(NCB % NR == 0, "NR must divide NCB");
  constexpr int VLEN =
      simd_info<inst_set_t::avx512>::WIDTH_BYTES / sizeof(int64_t);
  static_assert(NR % VLEN == 0, "VLEN must divide NR");

  using CodeGenType = CodeGenBase<int64_t, int64_t, int64_t, int64_t>;
  CodeGenType codeObj;
  CodeGenType::jit_micro_kernel_fp fn =
      codeObj.getOrCreate<inst_set_t::avx512>(true /* accum */, MCB, NCB, KCB);
  CodeGenType::jit_micro_kernel_fp fn_noacc = nullptr;
  if (!accumulate) {
    fn_noacc = codeObj.getOrCreate<inst_set_t::avx512>(
        false /* accum */, MCB, NCB, KCB);
  }

  vector<int64_t> At, Bt;
  // TODO: handle transpose during packing
  if (transa == matrix_op_t::Transpose) {
    At.resize(M * K);
    for (int i = 0; i < M; ++i) {
      for (int k = 0; k < K; ++k) {
        At.at(i * K + k) = A[i + k * lda];
      }
    }
    A = At.data();
    lda = K;
  }
  if (transb == matrix_op_t::Transpose) {
    Bt.resize(K * N);
    for (int k = 0; k < K; ++k) {
      for (int j = 0; j < N; ++j) {
        Bt.at(k * N + j) = B[k + j * ldb];
      }
    }
    B = Bt.data();
    ldb = N;
  }

  alignas(64) array<int64_t, MCB * KCB> packA;
  alignas(64) array<int64_t, KCB * NCB> packB;
  alignas(64) array<int64_t, MCB * NCB> packC;

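  // Blocking structure (a sketch): for each MCB x KCB panel of A and
  // KCB x NCB panel of B, the JIT kernel updates an MCB x NCB tile of C.
  // Full tiles are written directly into C; edge tiles are computed into
  // packC and copied back row by row.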
  for (int ic = 0; ic < M; ic += MCB) {
    for (int kc = 0; kc < K; kc += KCB) {
      // pack A
      for (int i = 0; i < std::min(MCB, M - ic); ++i) {
        memcpy(
            &packA[i * KCB],
            A + (ic + i) * lda + kc,
            std::min(K - kc, KCB) * sizeof(int64_t));
      }

      for (int jc = 0; jc < N; jc += NCB) {
        // pack B
        for (int i = 0; i < std::min(KCB, K - kc); ++i) {
          memcpy(
              &packB[i * NCB],
              B + (kc + i) * ldb + jc,
              std::min(NCB, N - jc) * sizeof(int64_t));
        }

        if (M - ic >= MCB && N - jc >= NCB) {
          if (kc == 0 && !accumulate) {
            fn_noacc(
                packA.data(),
                packB.data(),
                packB.data(),
                C + ic * ldc + jc,
                std::min(KCB, K - kc),
                ldc);
          } else {
            fn(packA.data(),
               packB.data(),
               packB.data(),
               C + ic * ldc + jc,
               std::min(KCB, K - kc),
               ldc);
          }
        } else {
          // remainder
          if (kc == 0 && !accumulate) {
            fn_noacc(
                packA.data(),
                packB.data(),
                packB.data(),
                packC.data(),
                std::min(KCB, K - kc),
                NCB);
          } else {
            for (int i = 0; i < std::min(MCB, M - ic); ++i) {
              memcpy(
                  &packC[i * NCB],
                  C + (ic + i) * ldc + jc,
                  std::min(NCB, N - jc) * sizeof(int64_t));
            }
            fn(packA.data(),
               packB.data(),
               packB.data(),
               packC.data(),
               std::min(KCB, K - kc),
               NCB);
          }
          for (int i = 0; i < std::min(MCB, M - ic); ++i) {
            memcpy(
                C + (ic + i) * ldc + jc,
                &packC[i * NCB],
                std::min(NCB, N - jc) * sizeof(int64_t));
          }
        }
      } // jc
    } // kc
  } // ic
}

} // namespace fbgemm