1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | #include <asmjit/asmjit.h> |
9 | #include <cpuinfo.h> |
10 | #include <cassert> |
11 | #include <cstdint> |
12 | #include <map> |
13 | #include <mutex> |
14 | #include <sstream> |
15 | #include <string> |
16 | #include <tuple> |
17 | #include <type_traits> |
18 | #include "./CodeCache.h" |
19 | #include "fbgemm/ConvUtils.h" |
20 | #include "fbgemm/Fbgemm.h" |
21 | #include "fbgemm/SimdUtils.h" |
22 | #include "fbgemm/Utils.h" |
23 | /*#define FBGEMM_LOG_CODE 1*/ |
24 | |
// SFINAE headers for the per-instruction-set member function templates of
// GenConvKernel. Each macro expands to a `template <...>` header followed by a
// `std::enable_if<...>::type` return type that is `void` only when ISET is one
// of the instruction sets named by the macro, so an AVX2 overload and an
// AVX512/VNNI overload of the same member can coexist.
//
// GCONV_INST_*_HEADER     : member template header only; ISET defaults to
//                           INST_SET. Used for the in-class declarations.
// GCONV_INST_DEF_*_HEADER : class template header + member template header,
//                           without the default argument (the shape needed
//                           for an out-of-class member definition).
//
// NOTE: the enable_if expressions of a declaration macro and its DEF
// counterpart must stay token-for-token identical, otherwise a definition
// would not match its declaration.

// Declaration header, enabled only for inst_set_t::avx2.
#define GCONV_INST_AVX2_HEADER          \
  template <inst_set_t ISET = INST_SET> \
  typename std::enable_if<ISET == inst_set_t::avx2, void>::type

// Declaration header, enabled for inst_set_t::avx512 and
// inst_set_t::avx512_vnni.
#define GCONV_INST_AVX512_AND_VNNI_HEADER                            \
  template <inst_set_t ISET = INST_SET>                              \
  typename std::enable_if<                                           \
      ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
      void>::type

// Definition header, enabled only for inst_set_t::avx2.
#define GCONV_INST_DEF_AVX2_HEADER                \
  template <int SPATIAL_DIM, inst_set_t INST_SET> \
  template <inst_set_t ISET>                      \
  typename std::enable_if<ISET == inst_set_t::avx2, void>::type

// Definition header, enabled for inst_set_t::avx512 and
// inst_set_t::avx512_vnni.
#define GCONV_INST_DEF_AVX512_AND_VNNI_HEADER                        \
  template <int SPATIAL_DIM, inst_set_t INST_SET>                    \
  template <inst_set_t ISET>                                         \
  typename std::enable_if<                                           \
      ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
      void>::type
46 | |
47 | namespace fbgemm { |
48 | |
49 | namespace x86 = asmjit::x86; |
50 | |
51 | template <typename> |
52 | struct is_requantization : std::false_type {}; |
53 | |
54 | template < |
55 | bool FUSE_RELU, |
56 | QuantizationGranularity Q_GRAN, |
57 | typename BIAS_TYPE, |
58 | typename outT, |
59 | typename inT, |
60 | typename nextOPType> |
61 | struct is_requantization< |
62 | ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>> |
63 | : std::true_type {}; |
64 | |
// Pointer-to-function type of a generated convolution kernel; this is the
// value type stored in the kernel code cache. The kernel returns nothing and
// takes, in order:
//   in_acts    : pointer to read-only uint8 data
//   wghts      : pointer to int8 data
//   out_acts   : pointer to int32 data
//   a_zero_pt  : int32 (presumably the activation zero point -- confirm)
//   oh_start   : int32 (presumably first output row to process -- confirm)
//   oh_end     : int32 (presumably end of the output row range -- confirm)
//   ow         : int32 (presumably output width -- confirm)
//   row_offset : pointer to int32 data
using jit_conv_kernel_fp = std::add_pointer<void(
    const uint8_t* in_acts,
    int8_t* wghts,
    int32_t* out_acts,
    int32_t a_zero_pt,
    int32_t oh_start,
    int32_t oh_end,
    int32_t ow,
    int32_t* row_offset)>::type;
74 | |
// Signature that identifies one generated kernel variant (used as the key of
// the kernel code cache). Tuple element layout:
//   0  bool  is A zero point 0
//   1  bool  should row offset be calculated
//   2  bool  is top edge included
//   3  bool  is bottom edge included
//   4  bool  is top bottom edge same?
//   5  bool  use paddings on bottom side?
//   6  bool  use paddings on right side?
//   7  bool  accumulate rowoffsets and output instead of overwrite?
//   8  int   groups
//   9  int   stride
//   10 int   number of input channels per group
//   11 int   number of output channels per group
using kernel_sig_t = std::tuple<
    bool, bool, bool, bool, bool, bool, bool, bool, // elements 0-7: flags
    int, int, int, int>; // elements 8-11: sizes
88 | |
89 | // Common code in a base class |
90 | template <int SPATIAL_DIM, inst_set_t INST_SET> |
91 | class GenConvKernelBase { |
92 | public: |
93 | GenConvKernelBase( |
94 | const conv_param_t<SPATIAL_DIM>& conv_param, |
95 | std::int32_t a_zero_point, |
96 | bool needRowOffset, |
97 | bool isTopEdgeIncluded, |
98 | bool isBottomEdgeIncluded, |
99 | bool isTopBottomEdgeSame, |
100 | bool accum) { |
101 | assert(fbgemmOptimizedGConv(conv_param)); |
102 | |
103 | isAZeroPointZero_ = a_zero_point == 0; |
104 | needRowOffset_ = needRowOffset; |
105 | isTopEdgeIncluded_ = isTopEdgeIncluded; |
106 | isBottomEdgeIncluded_ = isBottomEdgeIncluded; |
107 | isTopBottomEdgeSame_ = isTopBottomEdgeSame; |
108 | accum_ = accum; |
109 | |
110 | G_ = conv_param.G; |
111 | K_per_G_ = conv_param.OC / conv_param.G; |
112 | K_ = conv_param.OC; |
113 | C_per_G_ = conv_param.IC / conv_param.G; |
114 | C_ = conv_param.IC; |
115 | |
116 | // Strides are assumed to be the same in all directions |
117 | STRIDE_ = conv_param.stride[0]; |
118 | R_ = conv_param.K[0]; |
119 | S_ = conv_param.K[1]; |
120 | OH_ = conv_param.OUT_DIM[0]; |
121 | OW_ = conv_param.OUT_DIM[1]; |
122 | H_PAD_ = conv_param.pad[0]; |
123 | W_PAD_ = conv_param.pad[1]; |
124 | |
125 | use_bottom_padding_ = |
126 | !(STRIDE_ > 1 && conv_param.IN_DIM[SPATIAL_DIM - 2] % 2 == 0); |
127 | use_right_padding_ = |
128 | !(STRIDE_ > 1 && conv_param.IN_DIM[SPATIAL_DIM - 1] % 2 == 0); |
129 | } |
130 | |
131 | ~GenConvKernelBase() {} |
132 | |
133 | static std::string getCodeLoggingFile(kernel_sig_t kernel_sig) { |
134 | std::ostringstream oss; |
135 | oss << "conv" ; |
136 | oss << "_G-" << std::get<8>(kernel_sig); |
137 | oss << "_stride-" << std::get<9>(kernel_sig); |
138 | oss << "_IC_per_G-" << std::get<10>(kernel_sig); |
139 | oss << "_OC_per_G-" << std::get<11>(kernel_sig); |
140 | oss << "_isZeroPointZero-" << std::get<0>(kernel_sig); |
141 | oss << "_rowoffset-" << std::get<1>(kernel_sig); |
142 | oss << "_topEdge-" << std::get<2>(kernel_sig); |
143 | oss << "_bottomEdge-" << std::get<3>(kernel_sig); |
144 | oss << "_isTopBottomSame-" << std::get<4>(kernel_sig); |
145 | oss << "_useBottomPadding-" << std::get<5>(kernel_sig); |
146 | oss << "_useRightPadding-" << std::get<6>(kernel_sig); |
147 | oss << "_accum-" << std::get<7>(kernel_sig); |
148 | |
149 | if (INST_SET == inst_set_t::avx512) { |
150 | oss << "_avx512" ; |
151 | } else if (INST_SET == inst_set_t::avx2) { |
152 | oss << "_avx2" ; |
153 | } else { |
154 | oss << "_unknown" ; |
155 | } |
156 | |
157 | oss << ".txt" ; |
158 | return oss.str(); |
159 | } |
160 | |
161 | static asmjit::JitRuntime& runtime() { |
162 | static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, |
163 | // depents on other static |
164 | // variables. Required to prevent |
165 | // initialization order fiasco |
166 | return rt; |
167 | } |
168 | |
169 | static std::mutex rtMutex_; ///< Control access to runtime; |
170 | |
171 | static CodeCache< |
172 | kernel_sig_t, |
173 | jit_conv_kernel_fp> |
174 | codeCache_; ///< JIT Code Cache for reuse. |
175 | |
176 | protected: |
177 | // current conv parameters |
178 | int G_; ///< Number of groups |
179 | int K_; ///< Number of output channels |
180 | int K_per_G_; ///< Number of output channels per group |
181 | int C_; ///< Number of input channels |
182 | int STRIDE_; ///< Stride in either direction |
183 | int C_per_G_; ///< Number of input channels per group |
184 | int R_; ///< Filter/Kernel height |
185 | int S_; ///< Filter/Kernel width |
186 | int OH_; ///< output height |
187 | int OW_; ///< output width |
188 | int H_PAD_; ///< Padding for height (top and bottom) |
189 | int W_PAD_; ///< Padding for width (left and right) |
190 | |
191 | // Other parameters |
192 | bool isAZeroPointZero_; |
193 | bool needRowOffset_; |
194 | bool isTopEdgeIncluded_; |
195 | bool isBottomEdgeIncluded_; |
196 | bool isTopBottomEdgeSame_; |
197 | bool accum_; |
198 | // For 3x3 kernels with pad == 1: If stride is 2 and image height/width are |
199 | // even, the right or bottom paddings are not used. This variables is set to |
200 | // false if paddings on the left and bottom are not used and kernel generation |
201 | // takes care to not generate code with paddings on the right and bottom side. |
202 | bool use_bottom_padding_; |
203 | bool use_right_padding_; |
204 | }; |
205 | |
206 | // Generic class |
207 | template <int SPATIAL_DIM, inst_set_t INST_SET> |
208 | class FBGEMM_API GenConvKernel |
209 | : public GenConvKernelBase<SPATIAL_DIM, INST_SET> { |
210 | typedef typename simd_info<INST_SET>::vec_reg_t vec_reg_t; |
211 | |
212 | public: |
213 | GenConvKernel( |
214 | const conv_param_t<SPATIAL_DIM>& conv_param, |
215 | std::int32_t a_zero_point, |
216 | bool needRowoffset, |
217 | bool isTopEdgeIncluded, |
218 | bool isBottomEdgeIncluded, |
219 | bool isTopBottomEdgeSame, |
220 | bool accum) |
221 | : GenConvKernelBase<SPATIAL_DIM, INST_SET>( |
222 | conv_param, |
223 | a_zero_point, |
224 | needRowoffset, |
225 | isTopEdgeIncluded, |
226 | isBottomEdgeIncluded, |
227 | isTopBottomEdgeSame, |
228 | accum) { |
229 | constexpr int SIMD_WIDTH = simd_info<INST_SET>::WIDTH_BYTES; |
230 | GTogether_ = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>:: |
231 | numOfGroupsTogether(conv_param); |
232 | kLoopIters_ = this->K_per_G_ * this->C_per_G_ / SIMD_WIDTH; |
233 | // y/zmm0-8 are used for holding weights |
234 | zeroPTReg_V_ = vec_reg_t(10); |
235 | tmpReg1_V_ = vec_reg_t(11); |
236 | stPermReg_V_ = vec_reg_t(12); |
237 | actReg_V_ = vec_reg_t(13); |
238 | oneReg16Bit_V_ = vec_reg_t(15); |
239 | rowOffsetReg_V_ = vec_reg_t(14); |
240 | } |
241 | |
242 | jit_conv_kernel_fp getOrCreate(); |
243 | |
244 | GCONV_INST_AVX2_HEADER genForLoadingWeights(x86::Emitter* a); |
245 | |
246 | GCONV_INST_AVX512_AND_VNNI_HEADER genForLoadingWeights(x86::Emitter* a); |
247 | |
248 | GCONV_INST_AVX2_HEADER genConstForPermutations(x86::Emitter* a); |
249 | |
250 | GCONV_INST_AVX512_AND_VNNI_HEADER genConstForPermutations(x86::Emitter* a); |
251 | |
252 | GCONV_INST_AVX2_HEADER genForSingleFilterPoint( |
253 | x86::Emitter* a, |
254 | int r, |
255 | int s, |
256 | int act_s, |
257 | bool use_zero_reg); |
258 | |
259 | GCONV_INST_AVX512_AND_VNNI_HEADER genForSingleFilterPoint( |
260 | x86::Emitter* a, |
261 | int r, |
262 | int s, |
263 | int act_s, |
264 | bool use_zero_reg); |
265 | |
266 | GCONV_INST_AVX2_HEADER storeResult(x86::Emitter* a); |
267 | |
268 | GCONV_INST_AVX512_AND_VNNI_HEADER storeResult(x86::Emitter* a); |
269 | |
270 | GCONV_INST_AVX2_HEADER storeOffset(x86::Emitter* a); |
271 | |
272 | GCONV_INST_AVX512_AND_VNNI_HEADER storeOffset(x86::Emitter* a); |
273 | |
274 | void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom); |
275 | |
276 | void initResultRegs(x86::Emitter* a); |
277 | |
278 | void genCoreInsts(x86::Emitter* a); |
279 | |
280 | void genForSingleOutput( |
281 | x86::Emitter* a, |
282 | bool isLeft, |
283 | bool isRight, |
284 | bool isTop, |
285 | bool isBottom); |
286 | |
287 | private: |
288 | int GTogether_; |
289 | // The number of iterations needed for K dim. |
290 | // e.g., C_per_G_ = K_per_G_ = 8, we have to iterate |
291 | // twice on K dim because 4 (from K dim) * 8 ( from C dim) |
292 | // fill the full avx2 vector width. |
293 | int kLoopIters_; |
294 | asmjit::FuncDetail func_; |
295 | asmjit::FuncFrame frame_; |
296 | vec_reg_t zeroPTReg_V_; |
297 | vec_reg_t tmpReg1_V_; |
298 | vec_reg_t stPermReg_V_; |
299 | vec_reg_t actReg_V_; |
300 | vec_reg_t resultReg_V_; |
301 | vec_reg_t oneReg8Bit_V_; |
302 | vec_reg_t oneReg16Bit_V_; |
303 | vec_reg_t rowOffsetReg_V_; |
304 | |
305 | // arguments to the function created |
306 | x86::Gp in_acts_R_; |
307 | x86::Gp wghts_R_; |
308 | x86::Gp out_acts_R_; |
309 | x86::Gp a_zero_pt_R_; |
310 | x86::Gp H_R_; |
311 | x86::Gp H_start_R_; |
312 | x86::Gp H_end_R_; |
313 | x86::Gp W_R_; |
314 | x86::Gp row_offset_R_; |
315 | |
316 | // Used registers |
317 | x86::Gp loopR1_; |
318 | x86::Gp loopR2_; |
319 | x86::Gp scratchReg1_; |
320 | x86::Gp scratchReg2_; |
321 | }; |
322 | |
323 | template <int SPATIAL_DIM, inst_set_t INST_SET> |
324 | std::mutex GenConvKernelBase<SPATIAL_DIM, INST_SET>::rtMutex_; |
325 | |
326 | template <int SPATIAL_DIM, inst_set_t INST_SET> |
327 | CodeCache<kernel_sig_t, jit_conv_kernel_fp> |
328 | GenConvKernelBase<SPATIAL_DIM, INST_SET>::codeCache_; |
329 | |
330 | } // namespace fbgemm |
331 | |