1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | #include <asmjit/asmjit.h> |
9 | #include <cpuinfo.h> |
10 | #include <cassert> |
11 | #include <cstdint> |
12 | #include <map> |
13 | #include <mutex> |
14 | #include <sstream> |
15 | #include <string> |
16 | #include <tuple> |
17 | #include <type_traits> |
18 | #include "./CodeCache.h" |
19 | #include "fbgemm/ConvUtils.h" |
20 | #include "fbgemm/Fbgemm.h" |
21 | #include "fbgemm/Utils.h" |
22 | /*#define FBGEMM_LOG_CODE 1*/ |
23 | |
24 | namespace fbgemm { |
25 | |
26 | namespace x86 = asmjit::x86; |
27 | |
28 | /** |
29 | * @brief Generate instructions for initializing the C registers to 0. |
30 | */ |
31 | void initCRegs(x86::Emitter* a, int rowRegs, int colRegs); |
32 | |
33 | template <typename TA, typename TB, typename TC, typename accT> |
34 | class DirectConvCodeGenBase { |
35 | public: |
36 | using jit_micro_kernel_fp = void (*)( |
37 | const TA* bufferA, |
38 | const TB* bufferB, |
39 | const TB* b_pf, |
40 | TC* bufferC, |
41 | int kc, |
42 | int ldc); |
43 | |
44 | // microkernel signature for transposed direct conv |
45 | // ic: input channel |
46 | // ldcReg: leading dimension of output, a.k.a OC |
47 | // o1Xoc: output width multiply output channel: |
48 | // OUT_DIM[1] x OC |
49 | using jit_micro_kernel_fp_convT = void (*)( |
50 | const TA* bufferA, |
51 | const TB* bufferB, |
52 | TC* bufferC, |
53 | int ic, |
54 | int ldcReg, |
55 | int o1Xoc, |
56 | int i1); |
57 | |
58 | static std::mutex rtMutex_; ///< Control access to runtime; |
59 | |
60 | // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr |
61 | static CodeCache< |
62 | std::tuple<bool, int, int, int, int, int, int>, |
63 | jit_micro_kernel_fp> |
64 | codeCache_; ///< JIT Code Cache for reuse. |
65 | |
66 | // The hash depends on accumulate, stride, mr, nr |
67 | static CodeCache< |
68 | std::tuple<bool, int, int, int>, |
69 | jit_micro_kernel_fp_convT> |
70 | codeCacheT_; ///< JIT Code Cache for reuse. |
71 | |
72 | /** |
73 | * @brief Generate instructions for storing the C registers back to the |
74 | * memory. |
75 | */ |
76 | template <inst_set_t instSet> |
77 | void storeCRegs( |
78 | x86::Emitter* a, |
79 | int rowRegs, |
80 | int colRegs, |
81 | x86::Gp C_Offset, |
82 | x86::Gp ldcReg, |
83 | bool accum); |
84 | |
85 | /** |
86 | * @brief Generate instructions for storing the C registers back to the |
87 | * memory. |
88 | */ |
89 | template <inst_set_t instSet> |
90 | void storeCRegsTrans( |
91 | x86::Emitter* a, |
92 | int rowRegs, |
93 | int colRegs, |
94 | x86::Gp C_offset, |
95 | x86::Gp o1XocReg, |
96 | x86::Gp ldcReg, |
97 | bool accum); |
98 | |
99 | /** |
100 | * @brief Generate filename to dump generated code |
101 | * (debug-only) |
102 | */ |
103 | template <inst_set_t instSet> |
104 | static std::string getCodeLoggingFile( |
105 | bool accum, |
106 | int mc, |
107 | int nc, |
108 | int NCB, |
109 | int KCB, |
110 | int MR, |
111 | int NR) { |
112 | std::ostringstream oss; |
113 | oss << "directconv_" ; |
114 | if (std::is_same<accT, std::int16_t>::value) { |
115 | oss << "acc16_" ; |
116 | } else if (std::is_same<accT, std::int32_t>::value) { |
117 | oss << "acc32_" ; |
118 | } else { |
119 | oss << "unknown_" ; |
120 | } |
121 | oss << "accum-" + std::to_string(accum) << "_MC-" + std::to_string(mc) |
122 | << "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB) |
123 | << "_KCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR) |
124 | << "_NR-" + std::to_string(NR); |
125 | if (instSet == inst_set_t::avx512_vnni) { |
126 | oss << "_avx512vnni" ; |
127 | } else if (instSet == inst_set_t::avx512) { |
128 | oss << "_avx512" ; |
129 | } else if (instSet == inst_set_t::avx512_ymm) { |
130 | oss << "_avx512_ymm" ; |
131 | } else if (instSet == inst_set_t::avx2) { |
132 | oss << "_avx2" ; |
133 | } |
134 | oss << ".txt" ; |
135 | return oss.str(); |
136 | } |
137 | |
138 | /** |
139 | * @brief Get or Create the instructions for macro-kernel. |
140 | * |
141 | * If the problem size (mc, nc) and accumulation flag (accum) can be found in |
142 | * the code cache (a hash map), then get the macro-kernel instructions |
143 | * directly from it. Otherwise, create the instructions for macro-kernel, and |
144 | * store that into the code cache. |
145 | */ |
146 | template <inst_set_t instSet> |
147 | jit_micro_kernel_fp |
148 | getOrCreateDirectConv(bool accum, int32_t mc, int32_t nc, int32_t kc); |
149 | |
150 | /** |
151 | * @brief Get or Create the instructions for macro-kernel. |
152 | * |
153 | * If the problem size (mc, nc) and accumulation flag (accum) can be found in |
154 | * the code cache (a hash map), then get the macro-kernel instructions |
155 | * directly from it. Otherwise, create the instructions for macro-kernel, and |
156 | * store that into the code cache. |
157 | */ |
158 | template <inst_set_t instSet> |
159 | jit_micro_kernel_fp_convT |
160 | getOrCreateDirectConvTrans(bool accum, int32_t stride, int32_t numColRegs); |
161 | |
162 | /** |
163 | * @brief Generate instructions for computing block in the rank-k update. |
164 | */ |
165 | template <inst_set_t instSet> |
166 | void genComputeBlock( |
167 | x86::Emitter* a, |
168 | x86::Gp buffer_A, |
169 | x86::Gp buffer_B, |
170 | x86::Gp B_pf, |
171 | int rowRegs, |
172 | int colRegs, |
173 | int lda); |
174 | /** |
175 | * @brief Generate instructions for computing block in the rank-k update. |
176 | */ |
177 | template <inst_set_t instSet> |
178 | void genComputeBlockDirectConv( |
179 | x86::Emitter* a, |
180 | x86::Gp buffer_A, |
181 | x86::Gp buffer_B, |
182 | x86::Gp B_pf, |
183 | int rowRegs, |
184 | int colRegs, |
185 | int strideXich); |
186 | |
187 | /** |
188 | * @brief Generate instructions for computing block in the rank-k update. |
189 | */ |
190 | template <inst_set_t instSet> |
191 | void genComputeBlockDirectConvTrans( |
192 | x86::Emitter* a, |
193 | x86::Gp buffer_A, |
194 | x86::Gp buffer_B, |
195 | x86::Gp icReg, |
196 | x86::Gp C_offset, |
197 | int rowRegs, |
198 | int colRegs); |
199 | |
200 | private: |
201 | static asmjit::JitRuntime& runtime() { |
202 | static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, |
203 | // depents on other static |
204 | // variables. Required to prevent |
205 | // initialization order fiasco |
206 | return rt; |
207 | } |
208 | }; |
209 | |
210 | template <typename TA, typename TB, typename TC, typename accT> |
211 | std::mutex DirectConvCodeGenBase<TA, TB, TC, accT>::rtMutex_; |
212 | |
213 | template <typename TA, typename TB, typename TC, typename accT> |
214 | CodeCache< |
215 | std::tuple<bool, int, int, int, int, int, int>, |
216 | typename DirectConvCodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> |
217 | DirectConvCodeGenBase<TA, TB, TC, accT>::codeCache_; |
218 | |
219 | template <typename TA, typename TB, typename TC, typename accT> |
220 | CodeCache< |
221 | std::tuple<bool, int, int, int>, |
222 | typename DirectConvCodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp_convT> |
223 | DirectConvCodeGenBase<TA, TB, TC, accT>::codeCacheT_; |
224 | |
225 | }; // namespace fbgemm |
226 | |