1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#pragma once
8#include <asmjit/asmjit.h>
9#include <cpuinfo.h>
10#include <cassert>
11#include <cstdint>
12#include <map>
13#include <mutex>
14#include <sstream>
15#include <string>
16#include <tuple>
17#include <type_traits>
18#include "./CodeCache.h"
19#include "fbgemm/ConvUtils.h"
20#include "fbgemm/Fbgemm.h"
21#include "fbgemm/SimdUtils.h"
22#include "fbgemm/Utils.h"
23/*#define FBGEMM_LOG_CODE 1*/
24
// Return-type prefix for the in-class *declarations* of GenConvKernel's
// AVX2-specific member-function templates. It introduces a member template
// parameter ISET (defaulting to the enclosing class's INST_SET) and, through
// std::enable_if, yields `void` only when ISET == inst_set_t::avx2. For any
// other ISET the return type is ill-formed, so SFINAE discards the overload.
#define GCONV_INST_AVX2_HEADER \
  template <inst_set_t ISET = INST_SET> \
  typename std::enable_if<ISET == inst_set_t::avx2, void>::type

// Same as GCONV_INST_AVX2_HEADER, but the overload is enabled when ISET is
// inst_set_t::avx512 or inst_set_t::avx512_vnni.
#define GCONV_INST_AVX512_AND_VNNI_HEADER \
  template <inst_set_t ISET = INST_SET> \
  typename std::enable_if< \
      ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
      void>::type

// Prefix for the out-of-class *definitions* that match
// GCONV_INST_AVX2_HEADER declarations. The outer template header
// re-introduces the class template parameters (SPATIAL_DIM, INST_SET). The
// inner one re-introduces the member template parameter ISET without its
// default, because a default template argument may not be repeated on an
// out-of-class member definition. Keep the enable_if condition identical to
// GCONV_INST_AVX2_HEADER so each definition matches its declaration.
#define GCONV_INST_DEF_AVX2_HEADER \
  template <int SPATIAL_DIM, inst_set_t INST_SET> \
  template <inst_set_t ISET> \
  typename std::enable_if<ISET == inst_set_t::avx2, void>::type

// Out-of-class definition prefix matching
// GCONV_INST_AVX512_AND_VNNI_HEADER. It has the same two-level template
// structure as GCONV_INST_DEF_AVX2_HEADER, and its enable_if condition must
// stay identical to the declaration macro's.
#define GCONV_INST_DEF_AVX512_AND_VNNI_HEADER \
  template <int SPATIAL_DIM, inst_set_t INST_SET> \
  template <inst_set_t ISET> \
  typename std::enable_if< \
      ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
      void>::type
46
47namespace fbgemm {
48
// Short alias for asmjit's x86 backend namespace. The declarations below use
// x86::Emitter and x86::Gp through it.
namespace x86 = asmjit::x86;
50
// Trait telling whether an output-processing type is a ReQuantizeOutput
// instantiation. The primary template answers "no" for every type; a
// partial specialization for ReQuantizeOutput<...> answers "yes".
template <typename OutputProcessT>
struct is_requantization : public std::integral_constant<bool, false> {};
53
54template <
55 bool FUSE_RELU,
56 QuantizationGranularity Q_GRAN,
57 typename BIAS_TYPE,
58 typename outT,
59 typename inT,
60 typename nextOPType>
61struct is_requantization<
62 ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>>
63 : std::true_type {};
64
// Pointer type of a JIT-generated groupwise-convolution kernel. It is the
// value type of the kernel cache and what GenConvKernel::getOrCreate()
// returns. Parameter meanings below are inferred from their names; confirm
// against the code generator:
//   in_acts    - input activations (uint8)
//   wghts      - packed weights (int8)
//   out_acts   - int32 output accumulators
//   a_zero_pt  - activation zero point
//   oh_start / oh_end - presumably the output-row range handled by one call
//   ow         - presumably the output width
//   row_offset - row-offset buffer (written when row offsets are requested)
using jit_conv_kernel_fp = std::add_pointer<void(
    const uint8_t* in_acts,
    int8_t* wghts,
    int32_t* out_acts,
    int32_t a_zero_pt,
    int32_t oh_start,
    int32_t oh_end,
    int32_t ow,
    int32_t* row_offset)>::type;
74
// Key identifying one generated kernel variant; it is the key type of the
// kernel code cache. It consists of eight boolean code-generation switches
// followed by four integer shape parameters. The std::get<I> index of each
// field is shown in brackets.
using kernel_sig_t = std::tuple<
    // [0] is the A (activation) zero point 0?
    bool,
    // [1] should row offsets be calculated?
    bool,
    // [2] is the top edge included?
    bool,
    // [3] is the bottom edge included?
    bool,
    // [4] are the top and bottom edges the same?
    bool,
    // [5] use paddings on the bottom side?
    bool,
    // [6] use paddings on the right side?
    bool,
    // [7] accumulate row offsets and output instead of overwriting?
    bool,
    // [8] number of groups
    int,
    // [9] stride
    int,
    // [10] number of input channels per group
    int,
    // [11] number of output channels per group
    int>;
88
89// Common code in a base class
90template <int SPATIAL_DIM, inst_set_t INST_SET>
91class GenConvKernelBase {
92 public:
93 GenConvKernelBase(
94 const conv_param_t<SPATIAL_DIM>& conv_param,
95 std::int32_t a_zero_point,
96 bool needRowOffset,
97 bool isTopEdgeIncluded,
98 bool isBottomEdgeIncluded,
99 bool isTopBottomEdgeSame,
100 bool accum) {
101 assert(fbgemmOptimizedGConv(conv_param));
102
103 isAZeroPointZero_ = a_zero_point == 0;
104 needRowOffset_ = needRowOffset;
105 isTopEdgeIncluded_ = isTopEdgeIncluded;
106 isBottomEdgeIncluded_ = isBottomEdgeIncluded;
107 isTopBottomEdgeSame_ = isTopBottomEdgeSame;
108 accum_ = accum;
109
110 G_ = conv_param.G;
111 K_per_G_ = conv_param.OC / conv_param.G;
112 K_ = conv_param.OC;
113 C_per_G_ = conv_param.IC / conv_param.G;
114 C_ = conv_param.IC;
115
116 // Strides are assumed to be the same in all directions
117 STRIDE_ = conv_param.stride[0];
118 R_ = conv_param.K[0];
119 S_ = conv_param.K[1];
120 OH_ = conv_param.OUT_DIM[0];
121 OW_ = conv_param.OUT_DIM[1];
122 H_PAD_ = conv_param.pad[0];
123 W_PAD_ = conv_param.pad[1];
124
125 use_bottom_padding_ =
126 !(STRIDE_ > 1 && conv_param.IN_DIM[SPATIAL_DIM - 2] % 2 == 0);
127 use_right_padding_ =
128 !(STRIDE_ > 1 && conv_param.IN_DIM[SPATIAL_DIM - 1] % 2 == 0);
129 }
130
131 ~GenConvKernelBase() {}
132
133 static std::string getCodeLoggingFile(kernel_sig_t kernel_sig) {
134 std::ostringstream oss;
135 oss << "conv";
136 oss << "_G-" << std::get<8>(kernel_sig);
137 oss << "_stride-" << std::get<9>(kernel_sig);
138 oss << "_IC_per_G-" << std::get<10>(kernel_sig);
139 oss << "_OC_per_G-" << std::get<11>(kernel_sig);
140 oss << "_isZeroPointZero-" << std::get<0>(kernel_sig);
141 oss << "_rowoffset-" << std::get<1>(kernel_sig);
142 oss << "_topEdge-" << std::get<2>(kernel_sig);
143 oss << "_bottomEdge-" << std::get<3>(kernel_sig);
144 oss << "_isTopBottomSame-" << std::get<4>(kernel_sig);
145 oss << "_useBottomPadding-" << std::get<5>(kernel_sig);
146 oss << "_useRightPadding-" << std::get<6>(kernel_sig);
147 oss << "_accum-" << std::get<7>(kernel_sig);
148
149 if (INST_SET == inst_set_t::avx512) {
150 oss << "_avx512";
151 } else if (INST_SET == inst_set_t::avx2) {
152 oss << "_avx2";
153 } else {
154 oss << "_unknown";
155 }
156
157 oss << ".txt";
158 return oss.str();
159 }
160
161 static asmjit::JitRuntime& runtime() {
162 static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
163 // depents on other static
164 // variables. Required to prevent
165 // initialization order fiasco
166 return rt;
167 }
168
169 static std::mutex rtMutex_; ///< Control access to runtime;
170
171 static CodeCache<
172 kernel_sig_t,
173 jit_conv_kernel_fp>
174 codeCache_; ///< JIT Code Cache for reuse.
175
176 protected:
177 // current conv parameters
178 int G_; ///< Number of groups
179 int K_; ///< Number of output channels
180 int K_per_G_; ///< Number of output channels per group
181 int C_; ///< Number of input channels
182 int STRIDE_; ///< Stride in either direction
183 int C_per_G_; ///< Number of input channels per group
184 int R_; ///< Filter/Kernel height
185 int S_; ///< Filter/Kernel width
186 int OH_; ///< output height
187 int OW_; ///< output width
188 int H_PAD_; ///< Padding for height (top and bottom)
189 int W_PAD_; ///< Padding for width (left and right)
190
191 // Other parameters
192 bool isAZeroPointZero_;
193 bool needRowOffset_;
194 bool isTopEdgeIncluded_;
195 bool isBottomEdgeIncluded_;
196 bool isTopBottomEdgeSame_;
197 bool accum_;
198 // For 3x3 kernels with pad == 1: If stride is 2 and image height/width are
199 // even, the right or bottom paddings are not used. This variables is set to
200 // false if paddings on the left and bottom are not used and kernel generation
201 // takes care to not generate code with paddings on the right and bottom side.
202 bool use_bottom_padding_;
203 bool use_right_padding_;
204};
205
206// Generic class
207template <int SPATIAL_DIM, inst_set_t INST_SET>
208class FBGEMM_API GenConvKernel
209 : public GenConvKernelBase<SPATIAL_DIM, INST_SET> {
210 typedef typename simd_info<INST_SET>::vec_reg_t vec_reg_t;
211
212 public:
213 GenConvKernel(
214 const conv_param_t<SPATIAL_DIM>& conv_param,
215 std::int32_t a_zero_point,
216 bool needRowoffset,
217 bool isTopEdgeIncluded,
218 bool isBottomEdgeIncluded,
219 bool isTopBottomEdgeSame,
220 bool accum)
221 : GenConvKernelBase<SPATIAL_DIM, INST_SET>(
222 conv_param,
223 a_zero_point,
224 needRowoffset,
225 isTopEdgeIncluded,
226 isBottomEdgeIncluded,
227 isTopBottomEdgeSame,
228 accum) {
229 constexpr int SIMD_WIDTH = simd_info<INST_SET>::WIDTH_BYTES;
230 GTogether_ = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
231 numOfGroupsTogether(conv_param);
232 kLoopIters_ = this->K_per_G_ * this->C_per_G_ / SIMD_WIDTH;
233 // y/zmm0-8 are used for holding weights
234 zeroPTReg_V_ = vec_reg_t(10);
235 tmpReg1_V_ = vec_reg_t(11);
236 stPermReg_V_ = vec_reg_t(12);
237 actReg_V_ = vec_reg_t(13);
238 oneReg16Bit_V_ = vec_reg_t(15);
239 rowOffsetReg_V_ = vec_reg_t(14);
240 }
241
242 jit_conv_kernel_fp getOrCreate();
243
244 GCONV_INST_AVX2_HEADER genForLoadingWeights(x86::Emitter* a);
245
246 GCONV_INST_AVX512_AND_VNNI_HEADER genForLoadingWeights(x86::Emitter* a);
247
248 GCONV_INST_AVX2_HEADER genConstForPermutations(x86::Emitter* a);
249
250 GCONV_INST_AVX512_AND_VNNI_HEADER genConstForPermutations(x86::Emitter* a);
251
252 GCONV_INST_AVX2_HEADER genForSingleFilterPoint(
253 x86::Emitter* a,
254 int r,
255 int s,
256 int act_s,
257 bool use_zero_reg);
258
259 GCONV_INST_AVX512_AND_VNNI_HEADER genForSingleFilterPoint(
260 x86::Emitter* a,
261 int r,
262 int s,
263 int act_s,
264 bool use_zero_reg);
265
266 GCONV_INST_AVX2_HEADER storeResult(x86::Emitter* a);
267
268 GCONV_INST_AVX512_AND_VNNI_HEADER storeResult(x86::Emitter* a);
269
270 GCONV_INST_AVX2_HEADER storeOffset(x86::Emitter* a);
271
272 GCONV_INST_AVX512_AND_VNNI_HEADER storeOffset(x86::Emitter* a);
273
274 void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom);
275
276 void initResultRegs(x86::Emitter* a);
277
278 void genCoreInsts(x86::Emitter* a);
279
280 void genForSingleOutput(
281 x86::Emitter* a,
282 bool isLeft,
283 bool isRight,
284 bool isTop,
285 bool isBottom);
286
287 private:
288 int GTogether_;
289 // The number of iterations needed for K dim.
290 // e.g., C_per_G_ = K_per_G_ = 8, we have to iterate
291 // twice on K dim because 4 (from K dim) * 8 ( from C dim)
292 // fill the full avx2 vector width.
293 int kLoopIters_;
294 asmjit::FuncDetail func_;
295 asmjit::FuncFrame frame_;
296 vec_reg_t zeroPTReg_V_;
297 vec_reg_t tmpReg1_V_;
298 vec_reg_t stPermReg_V_;
299 vec_reg_t actReg_V_;
300 vec_reg_t resultReg_V_;
301 vec_reg_t oneReg8Bit_V_;
302 vec_reg_t oneReg16Bit_V_;
303 vec_reg_t rowOffsetReg_V_;
304
305 // arguments to the function created
306 x86::Gp in_acts_R_;
307 x86::Gp wghts_R_;
308 x86::Gp out_acts_R_;
309 x86::Gp a_zero_pt_R_;
310 x86::Gp H_R_;
311 x86::Gp H_start_R_;
312 x86::Gp H_end_R_;
313 x86::Gp W_R_;
314 x86::Gp row_offset_R_;
315
316 // Used registers
317 x86::Gp loopR1_;
318 x86::Gp loopR2_;
319 x86::Gp scratchReg1_;
320 x86::Gp scratchReg2_;
321};
322
323template <int SPATIAL_DIM, inst_set_t INST_SET>
324std::mutex GenConvKernelBase<SPATIAL_DIM, INST_SET>::rtMutex_;
325
326template <int SPATIAL_DIM, inst_set_t INST_SET>
327CodeCache<kernel_sig_t, jit_conv_kernel_fp>
328 GenConvKernelBase<SPATIAL_DIM, INST_SET>::codeCache_;
329
330} // namespace fbgemm
331