1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#pragma once
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <cstdint>
#include <map>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include "./CodeCache.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/SimdUtils.h"
18//#define FBGEMM_LOG_CODE 1
19
20namespace fbgemm {
21
22namespace x86 = asmjit::x86;
23
24/**
25 * @brief Generate instructions for initializing the C registers to 0.
26 */
27void initCRegs(x86::Emitter* a, int rowRegs, int colRegs);
28
29/**
30 * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator.
31 * @tparam TA Type of matrix A.
32 * @tparam TB Type of matrix B.
33 * @tparam TC Type of matrix C.
34 * @tparam accT Accumulation type, currently we support 16-bit (std::int16_t) or
35 * 32-bit (std::int32_t) accumulation.
36 */
37template <typename TA, typename TB, typename TC, typename accT>
38class CodeGenBase {
39 public:
40 using jit_micro_kernel_fp = void (*)(
41 const TA* bufferA,
42 const TB* bufferB,
43 const TB* b_pf,
44 TC* bufferC,
45 int kc,
46 int ldc);
47
48 /**
49 * @brief Constructor for initializing AVX2/AVX512 registers.
50 */
51 CodeGenBase(const BlockingFactors* params = nullptr)
52 : blocking_params(params) {}
53
54 /**
55 * @brief Get or Create the instructions for macro-kernel.
56 *
57 * If the problem size (mc, nc) and accumulation flag (accum) can be found in
58 * the code cache (a hash map), then get the macro-kernel instructions
59 * directly from it. Otherwise, create the instructions for macro-kernel, and
60 * store that into the code cache.
61 */
62 template <inst_set_t instSet>
63 jit_micro_kernel_fp
64 getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc);
65
66 /**
67 * @brief Generate instructions for computing block in the rank-k update.
68 */
69 template <inst_set_t instSet>
70 void genComputeBlock(
71 x86::Emitter* a,
72 x86::Gp buffer_A,
73 x86::Gp buffer_B,
74 x86::Gp B_pf,
75 int rowRegs,
76 int colRegs,
77 int lda);
78
79 /**
80 * @brief Generate instructions for storing the C registers back to the
81 * memory.
82 */
83 template <inst_set_t instSet>
84 void storeCRegs(
85 x86::Emitter* a,
86 int rowRegs,
87 int colRegs,
88 x86::Gp C_Offset,
89 x86::Gp ldcReg,
90 bool accum);
91
92 const BlockingFactors* blocking_params;
93 /**
94 * @brief Generate filename to dump generated code
95 * (debug-only)
96 */
97 template <inst_set_t instSet>
98 static std::string getCodeLoggingFile(
99 bool accum,
100 int mc,
101 int nc,
102 int NCB,
103 int KCB,
104 int MR,
105 int NR) {
106 std::ostringstream oss;
107 oss << "gemm_";
108 if (std::is_same<accT, std::int16_t>::value) {
109 oss << "acc16_";
110 } else if (std::is_same<accT, std::int32_t>::value) {
111 oss << "acc32_";
112 } else {
113 oss << "unknown_";
114 }
115 oss << "accum-" + std::to_string(accum) << "_MC-" + std::to_string(mc)
116 << "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB)
117 << "_KCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR)
118 << "_NR-" + std::to_string(NR);
119 if (instSet == inst_set_t::avx512_vnni) {
120 oss << "_avx512vnni";
121 } else if (instSet == inst_set_t::avx512) {
122 oss << "_avx512";
123 } else if (instSet == inst_set_t::avx512_ymm) {
124 oss << "_avx512_ymm";
125 } else if (instSet == inst_set_t::avx2) {
126 oss << "_avx2";
127 }
128 oss << ".txt";
129 return oss.str();
130 }
131
132 private:
133 static asmjit::JitRuntime& runtime() {
134 static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
135 // depents on other static
136 // variables. Required to prevent
137 // initialization order fiasco
138 return rt;
139 }
140
141 static std::mutex rtMutex_; ///< Controll access to runtime;
142
143 // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr
144 static CodeCache<
145 std::tuple<bool, int, int, int, int, int, int>,
146 jit_micro_kernel_fp>
147 codeCache_; ///< JIT Code Cache for reuse.
148};
149
150template <typename TA, typename TB, typename TC, typename accT>
151std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;
152
153template <typename TA, typename TB, typename TC, typename accT>
154CodeCache<
155 std::tuple<bool, int, int, int, int, int, int>,
156 typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
157 CodeGenBase<TA, TB, TC, accT>::codeCache_;
158
159} // namespace fbgemm
160