1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <map>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include "./CodeCache.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/SimdUtils.h"
18 | //#define FBGEMM_LOG_CODE 1 |
19 | |
20 | namespace fbgemm { |
21 | |
// Shorthand for asmjit's x86 backend namespace (x86::Emitter, x86::Gp, ...)
// used by the code-generator declarations below.
namespace x86 = asmjit::x86;

/**
 * @brief Generate instructions for initializing the C registers to 0.
 *
 * Only the declaration is visible here; the definition lives in another
 * translation unit.
 *
 * @param a       Emitter that receives the generated instructions.
 * @param rowRegs Number of register rows of the C tile (presumably the rows of
 *                accumulator registers to zero — confirm against definition).
 * @param colRegs Number of register columns of the C tile (same caveat).
 */
void initCRegs(x86::Emitter* a, int rowRegs, int colRegs);
28 | |
29 | /** |
30 | * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator. |
31 | * @tparam TA Type of matrix A. |
32 | * @tparam TB Type of matrix B. |
33 | * @tparam TC Type of matrix C. |
34 | * @tparam accT Accumulation type, currently we support 16-bit (std::int16_t) or |
35 | * 32-bit (std::int32_t) accumulation. |
36 | */ |
37 | template <typename TA, typename TB, typename TC, typename accT> |
38 | class CodeGenBase { |
39 | public: |
40 | using jit_micro_kernel_fp = void (*)( |
41 | const TA* bufferA, |
42 | const TB* bufferB, |
43 | const TB* b_pf, |
44 | TC* bufferC, |
45 | int kc, |
46 | int ldc); |
47 | |
48 | /** |
49 | * @brief Constructor for initializing AVX2/AVX512 registers. |
50 | */ |
51 | CodeGenBase(const BlockingFactors* params = nullptr) |
52 | : blocking_params(params) {} |
53 | |
54 | /** |
55 | * @brief Get or Create the instructions for macro-kernel. |
56 | * |
57 | * If the problem size (mc, nc) and accumulation flag (accum) can be found in |
58 | * the code cache (a hash map), then get the macro-kernel instructions |
59 | * directly from it. Otherwise, create the instructions for macro-kernel, and |
60 | * store that into the code cache. |
61 | */ |
62 | template <inst_set_t instSet> |
63 | jit_micro_kernel_fp |
64 | getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc); |
65 | |
66 | /** |
67 | * @brief Generate instructions for computing block in the rank-k update. |
68 | */ |
69 | template <inst_set_t instSet> |
70 | void genComputeBlock( |
71 | x86::Emitter* a, |
72 | x86::Gp buffer_A, |
73 | x86::Gp buffer_B, |
74 | x86::Gp B_pf, |
75 | int rowRegs, |
76 | int colRegs, |
77 | int lda); |
78 | |
79 | /** |
80 | * @brief Generate instructions for storing the C registers back to the |
81 | * memory. |
82 | */ |
83 | template <inst_set_t instSet> |
84 | void storeCRegs( |
85 | x86::Emitter* a, |
86 | int rowRegs, |
87 | int colRegs, |
88 | x86::Gp C_Offset, |
89 | x86::Gp ldcReg, |
90 | bool accum); |
91 | |
92 | const BlockingFactors* blocking_params; |
93 | /** |
94 | * @brief Generate filename to dump generated code |
95 | * (debug-only) |
96 | */ |
97 | template <inst_set_t instSet> |
98 | static std::string getCodeLoggingFile( |
99 | bool accum, |
100 | int mc, |
101 | int nc, |
102 | int NCB, |
103 | int KCB, |
104 | int MR, |
105 | int NR) { |
106 | std::ostringstream oss; |
107 | oss << "gemm_" ; |
108 | if (std::is_same<accT, std::int16_t>::value) { |
109 | oss << "acc16_" ; |
110 | } else if (std::is_same<accT, std::int32_t>::value) { |
111 | oss << "acc32_" ; |
112 | } else { |
113 | oss << "unknown_" ; |
114 | } |
115 | oss << "accum-" + std::to_string(accum) << "_MC-" + std::to_string(mc) |
116 | << "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB) |
117 | << "_KCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR) |
118 | << "_NR-" + std::to_string(NR); |
119 | if (instSet == inst_set_t::avx512_vnni) { |
120 | oss << "_avx512vnni" ; |
121 | } else if (instSet == inst_set_t::avx512) { |
122 | oss << "_avx512" ; |
123 | } else if (instSet == inst_set_t::avx512_ymm) { |
124 | oss << "_avx512_ymm" ; |
125 | } else if (instSet == inst_set_t::avx2) { |
126 | oss << "_avx2" ; |
127 | } |
128 | oss << ".txt" ; |
129 | return oss.str(); |
130 | } |
131 | |
132 | private: |
133 | static asmjit::JitRuntime& runtime() { |
134 | static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, |
135 | // depents on other static |
136 | // variables. Required to prevent |
137 | // initialization order fiasco |
138 | return rt; |
139 | } |
140 | |
141 | static std::mutex rtMutex_; ///< Controll access to runtime; |
142 | |
143 | // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr |
144 | static CodeCache< |
145 | std::tuple<bool, int, int, int, int, int, int>, |
146 | jit_micro_kernel_fp> |
147 | codeCache_; ///< JIT Code Cache for reuse. |
148 | }; |
149 | |
// Out-of-class definition of the static mutex declared in CodeGenBase. As a
// static member of a class template, each <TA, TB, TC, accT> instantiation
// gets its own mutex.
template <typename TA, typename TB, typename TC, typename accT>
std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;
152 | |
// Out-of-class definition of the static JIT code cache declared in
// CodeGenBase (one cache per template instantiation). The type must match the
// in-class declaration; `typename CodeGenBase<...>::` is required here because
// the class scope does not yet apply to the type preceding the declarator.
template <typename TA, typename TB, typename TC, typename accT>
CodeCache<
    std::tuple<bool, int, int, int, int, int, int>,
    typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
    CodeGenBase<TA, TB, TC, accT>::codeCache_;
158 | |
159 | } // namespace fbgemm |
160 | |