1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include <asmjit/asmjit.h> |
9 | #include "./CodeGenHelpers.h" |
10 | #include "./GroupwiseConv.h" |
11 | |
12 | namespace fbgemm { |
13 | |
14 | using namespace std; |
15 | |
16 | namespace x86 = asmjit::x86; |
17 | |
18 | GCONV_INST_DEF_AVX512_AND_VNNI_HEADER |
19 | GenConvKernel<SPATIAL_DIM, INST_SET>::genConstForPermutations(x86::Emitter* a) { |
20 | x86::Gp permute_const_reg_upper_half = a->gpz(12); |
21 | x86::Gp permute_const_reg_lower_half = a->gpz(13); |
22 | x86::Xmm const_reg_xmm = x86::xmm11; |
23 | if (this->C_per_G_ == 4) { |
    // 4 groups together.
    // Elements of the 1st group sit in positions 0, 4, 8, and 12; the 2nd
    // group in 1, 5, 9, and 13; and so on. Permute to put the 1st group
    // into the lowest 128 bits, the 2nd group into the next 128 bits, etc.
28 | // load f, b, 7, 3, |
29 | // e, a, 6, 2, |
30 | // d, 9, 5, 1, |
31 | // c, 8, 4, 0 in a 128-bit Xmm |
32 | a->mov( |
33 | permute_const_reg_lower_half, |
34 | static_cast<asmjit::Imm>(0x0d0905010c080400)); |
35 | a->mov( |
36 | permute_const_reg_upper_half, |
37 | static_cast<asmjit::Imm>(0x0f0b07030e0a0602)); |
38 | } else { |
    // this->C_per_G_ == 2
    // 8 groups together.
    // Elements of the 1st group sit in positions 0 and 8, the 2nd group in
    // 1 and 9, and so on. Permute so each group's two elements become
    // adjacent (64 bits per group): groups 0 and 1 end up in the lowest
    // 128 bits, groups 2 and 3 in the next 128 bits, etc.
44 | // load |
45 | // f, 7 |
46 | // e, 6 |
47 | // d, 5 |
48 | // c, 4 |
49 | // b, 3 |
50 | // a, 2 |
51 | // 9, 1 |
52 | // 8, 0 in a 128-bit Xmm |
53 | a->mov( |
54 | permute_const_reg_lower_half, |
55 | static_cast<asmjit::Imm>(0x0b030a0209010800)); |
56 | a->mov( |
57 | permute_const_reg_upper_half, |
58 | static_cast<asmjit::Imm>(0x0f070e060d050c04)); |
59 | } |
60 | |
61 | a->movq(const_reg_xmm, permute_const_reg_lower_half); |
62 | a->pinsrq(const_reg_xmm, permute_const_reg_upper_half, 1); |
  // Zero-extend the 16 packed 8-bit integers in const_reg_xmm to 16 packed
  // 32-bit integers in stPermReg_V_.
65 | a->vpmovzxbd(stPermReg_V_, const_reg_xmm); |
66 | } |
67 | |
68 | GCONV_INST_DEF_AVX512_AND_VNNI_HEADER |
69 | GenConvKernel<SPATIAL_DIM, INST_SET>::genForLoadingWeights(x86::Emitter* a) { |
70 | using WRegs = x86::Zmm; |
71 | int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4; |
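  // e.g. C_per_G_ == 2 -> paddedICPerG == 4; 8 -> 8; 16 -> 16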
72 | // load weights |
73 | for (int r = 0; r < this->R_; ++r) { |
74 | for (int s = 0; s < this->S_; ++s) { |
      // When C_per_G_ == 16, the weights are too big to keep in registers
      // and are instead loaded as they are used.
      if (this->C_per_G_ != 16) {
        // Use an aligned move since the weight buffer is 64-byte aligned.
79 | a->vmovaps( |
80 | WRegs(r * this->S_ + s), |
            // load 512 bits of weights; the grouping differs per workload
83 | x86::zmmword_ptr( |
84 | wghts_R_, |
85 | (r * this->S_ + s) * this->K_per_G_ * GTogether_ * |
86 | paddedICPerG * sizeof(int8_t))); |
87 | } |
88 | } |
89 | } |
90 | } |
91 | |
92 | GCONV_INST_DEF_AVX512_AND_VNNI_HEADER |
93 | GenConvKernel<SPATIAL_DIM, INST_SET>::storeResult(x86::Emitter* a) { |
94 | if (GTogether_ > 1) { |
95 | // store with permutation |
96 | a->vpermd(x86::Zmm(9), stPermReg_V_, x86::Zmm(9)); |
97 | if (this->accum_) { |
98 | a->vpaddd(x86::Zmm(9), x86::Zmm(9), x86::zmmword_ptr(out_acts_R_)); |
99 | } |
100 | a->vmovups(x86::zmmword_ptr(out_acts_R_), x86::Zmm(9)); |
101 | } else { |
102 | // horizontal add and store |
103 | if (this->C_per_G_ == 8) { |
104 | a->vextracti32x8(tmpReg1_V_.ymm(), x86::Zmm(9), 1); |
105 | a->vphaddd(x86::Ymm(9), x86::Ymm(9), tmpReg1_V_.ymm()); |
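      // vphaddd adds within each 128-bit lane, leaving the low-half and
      // high-half sums interleaved across lanes; the vpermq with 0xd8
      // (qword order 0, 2, 1, 3) below restores the natural order.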
106 | a->vpermq(x86::Ymm(9), x86::Ymm(9), static_cast<asmjit::Imm>(0xd8)); |
107 | if (this->accum_) { |
108 | a->vpaddd(x86::Ymm(9), x86::Ymm(9), x86::ymmword_ptr(out_acts_R_)); |
109 | } |
110 | a->vmovups(x86::ymmword_ptr(out_acts_R_), x86::Ymm(9)); |
111 | } else if (this->K_per_G_ == 16) { |
      // We have results in 4 Zmm registers and need to reduce them to
      // 2 Ymm registers (2 * 8 * 32 bits, holding the 16 outputs of
      // K_per_G_). First reduce 4 * 16 * 32 bits to 4 * 8 * 32 bits.
115 | for (int k = 0; k < kLoopIters_; ++k) { |
116 | auto source_reg = x86::Zmm(9 - k); |
117 | auto result_reg = x86::Ymm(9 - k); |
118 | a->vextracti32x8(x86::Ymm(0), source_reg, 1); |
119 | a->vphaddd(result_reg, result_reg, x86::Ymm(0)); |
120 | a->vpermq(result_reg, result_reg, static_cast<asmjit::Imm>(0xd8)); |
121 | } |
      // then reduce 4 * 8 * 32 bits to 2 * 8 * 32 bits
123 | for (int k = 0, i = 0; k < kLoopIters_; k += 2, i++) { |
124 | auto result_reg = x86::Ymm(9 - k); |
125 | auto adjacent_result_reg = x86::Ymm(9 - k - 1); |
126 | a->vphaddd(result_reg, result_reg, adjacent_result_reg); |
127 | a->vpermq(result_reg, result_reg, static_cast<asmjit::Imm>(0xd8)); |
128 | if (this->accum_) { |
129 | a->vpaddd( |
130 | result_reg, result_reg, x86::ymmword_ptr(out_acts_R_, 32 * i)); |
131 | } |
132 | a->vmovups(x86::ymmword_ptr(out_acts_R_, 32 * i), result_reg); |
133 | } |
134 | } |
135 | } |
136 | } |
137 | |
138 | GCONV_INST_DEF_AVX512_AND_VNNI_HEADER |
139 | GenConvKernel<SPATIAL_DIM, INST_SET>::storeOffset(x86::Emitter* a) { |
140 | auto rowOffsetReg_V_Ymm = rowOffsetReg_V_.half(); |
141 | auto rowOffsetReg_V_Xmm = rowOffsetReg_V_Ymm.half(); |
142 | auto tmpReg1_V_Xmm = tmpReg1_V_.half().half(); |
143 | switch (this->C_per_G_) { |
144 | case 2: |
      // store 256 bits containing the row offsets of eight groups
146 | if (this->accum_) { |
147 | a->vpaddd( |
148 | rowOffsetReg_V_Ymm, |
149 | rowOffsetReg_V_Ymm, |
150 | x86::ymmword_ptr(row_offset_R_)); |
151 | } |
152 | a->vmovdqu(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_Ymm); |
153 | break; |
154 | case 4: |
      // store 128 bits containing the row offsets of four groups
156 | if (this->accum_) { |
157 | a->vmovdqu(tmpReg1_V_Xmm, x86::dword_ptr(row_offset_R_)); |
158 | a->paddd(rowOffsetReg_V_Xmm, tmpReg1_V_Xmm); |
159 | } |
160 | a->vmovups(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_Xmm); |
161 | break; |
162 | case 8: |
      // store the 32-bit row offset of one group
164 | if (this->accum_) { |
165 | a->vmovd(tmpReg1_V_Xmm, x86::dword_ptr(row_offset_R_)); |
166 | a->paddd(rowOffsetReg_V_Xmm, tmpReg1_V_Xmm); |
167 | } |
168 | a->vmovd(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_Xmm); |
169 | break; |
170 | case 16: |
      // rowOffsetReg_V_[0:63] has the sum for the first 8 and
      // rowOffsetReg_V_[64:127] has the sum for the second 8;
      // execute vphaddd twice to add the two together
174 | a->vphaddd(rowOffsetReg_V_Ymm, rowOffsetReg_V_Ymm, rowOffsetReg_V_Ymm); |
175 | a->vphaddd(rowOffsetReg_V_Ymm, rowOffsetReg_V_Ymm, rowOffsetReg_V_Ymm); |
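      // after the two in-lane horizontal adds, the low 32 bits of
      // rowOffsetReg_V_Xmm hold the complete sum for the group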
176 | if (this->accum_) { |
177 | a->vmovd(tmpReg1_V_Xmm, x86::dword_ptr(row_offset_R_)); |
178 | a->paddd(rowOffsetReg_V_Xmm, tmpReg1_V_Xmm); |
179 | } |
180 | a->vmovd(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_Xmm); |
181 | break; |
182 | default: |
      assert(false && "unsupported case");
184 | } |
185 | } |
186 | |
187 | GCONV_INST_DEF_AVX512_AND_VNNI_HEADER |
188 | GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleFilterPoint( |
189 | x86::Emitter* a, |
190 | int r, |
191 | int s, |
192 | int act_s, |
193 | bool use_zero_reg) { |
194 | using WRegs = x86::Zmm; |
195 | |
196 | if (use_zero_reg) { |
    a->vmovapd(actReg_V_, zeroPTReg_V_); // 64 x 8-bit zero points
198 | } else { |
199 | if (this->C_per_G_ != 8) { |
      // 2 (C_per_G_) * 8 (GTogether_) or
      // 4 (C_per_G_) * 4 (GTogether_) or
      // 16 (C_per_G_) broadcast into the 4 128-bit lanes of the ZMM
203 | a->vbroadcasti32x4( |
204 | actReg_V_, |
205 | x86::oword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t))); |
206 | } else { |
      // 8 (C_per_G_) broadcast into the 8 64-bit lanes of the ZMM
208 | a->vbroadcasti32x2( |
209 | actReg_V_, |
210 | x86::qword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t))); |
211 | } |
212 | } |
213 | |
  // Zero-extend if C_per_G_ is smaller than 4 (the accumulation width of
  // our FMA instruction).
216 | if (this->C_per_G_ == 2) { |
    // only use the lower half and zero-extend each 16-bit pair of uint8s
    // to 32 bits (4 uint8 slots)
218 | a->vpmovzxwd(actReg_V_, actReg_V_.half()); |
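    // e.g. a 16-bit lane holding uint8 bytes {a, b} becomes the 32-bit lane
    // {a, b, 0, 0}, so the zero bytes contribute nothing to the 4-way dot
    // product performed by the FMA below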
219 | } |
220 | |
221 | // row offset |
222 | if (this->needRowOffset_) { |
223 | if (this->C_per_G_ == 2 || this->C_per_G_ == 4) { |
224 | genU8Sum4<INST_SET>( |
225 | a, actReg_V_, rowOffsetReg_V_, oneReg16Bit_V_, tmpReg1_V_); |
226 | } else { |
227 | // still use Ymm for Sum8 |
228 | genU8Sum8(a, actReg_V_.half(), rowOffsetReg_V_.half(), tmpReg1_V_.half()); |
229 | } |
230 | } |
231 | // FMA |
232 | if (this->C_per_G_ != 16) { |
233 | genU8I8S32FMA<INST_SET>( |
234 | a, |
235 | actReg_V_, |
236 | WRegs(r * this->S_ + s), |
237 | WRegs(9), |
238 | oneReg16Bit_V_, |
239 | tmpReg1_V_); |
240 | } else { |
    // 64 == simd_info<inst_set_t::avx512>::WIDTH_BYTES
242 | int kLoopMultiplier = 64 / this->C_per_G_; |
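    // with C_per_G_ == 16, kLoopMultiplier == 4: each 64-byte (one ZMM)
    // load below covers the 16 int8 weights of 4 consecutive output
    // channels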
243 | for (int k = 0; k < kLoopIters_; ++k) { |
244 | a->vmovaps( |
245 | WRegs(0), |
          // copy 512 bits of weights into the ZMM:
          // 16 (C_per_G_) * 4 (a quarter of K_per_G_)
247 | x86::zmmword_ptr( |
248 | wghts_R_, |
249 | (((r * this->S_ + s) * this->K_per_G_) + k * kLoopMultiplier) * |
250 | this->C_per_G_ * sizeof(int8_t))); |
      // The FMA does not yet fully reduce over C_per_G_: it produces
      // 4 * 16 partial outputs in which each run of 4 consecutive elements,
      // once summed, forms one final output along the K_per_G_ dimension;
      // we need 16 final 32-bit outputs.
254 | genU8I8S32FMA<INST_SET>( |
255 | a, actReg_V_, WRegs(0), WRegs(9 - k), oneReg16Bit_V_, tmpReg1_V_); |
256 | } |
257 | } |
258 | } |
259 | #define GENCONVKERNEL_FUNCS(S, IN) \ |
260 | template void GenConvKernel<S, IN>::genForLoadingWeights<IN>( \ |
261 | x86::Emitter * a); \ |
262 | template void GenConvKernel<S, IN>::genConstForPermutations<IN>( \ |
263 | x86::Emitter * a); \ |
264 | template void GenConvKernel<S, IN>::genForSingleFilterPoint<IN>( \ |
265 | x86::Emitter * a, int r, int s, int act_s, bool use_zero_reg); \ |
266 | template void GenConvKernel<S, IN>::storeResult<IN>(x86::Emitter * a); \ |
267 | template void GenConvKernel<S, IN>::storeOffset<IN>(x86::Emitter * a); |
268 | GENCONVKERNEL_FUNCS(1, inst_set_t::avx512) |
269 | GENCONVKERNEL_FUNCS(1, inst_set_t::avx512_vnni) |
270 | GENCONVKERNEL_FUNCS(2, inst_set_t::avx512) |
271 | GENCONVKERNEL_FUNCS(2, inst_set_t::avx512_vnni) |
272 | GENCONVKERNEL_FUNCS(3, inst_set_t::avx512) |
273 | GENCONVKERNEL_FUNCS(3, inst_set_t::avx512_vnni) |
274 | #undef GENCONVKERNEL_FUNCS |
275 | |
276 | template class GenConvKernel<1, inst_set_t::avx512>; |
277 | template class GenConvKernel<1, inst_set_t::avx512_vnni>; |
278 | template class GenConvKernel<2, inst_set_t::avx512>; |
279 | template class GenConvKernel<2, inst_set_t::avx512_vnni>; |
280 | template class GenConvKernel<3, inst_set_t::avx512>; |
281 | template class GenConvKernel<3, inst_set_t::avx512_vnni>; |
282 | |
283 | } // namespace fbgemm |
284 | |