1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include <asmjit/asmjit.h> |
9 | #include "./CodeGenHelpers.h" |
10 | #include "./GroupwiseConv.h" |
11 | #include "fbgemm/Fbgemm.h" |
12 | |
13 | namespace fbgemm { |
14 | |
15 | using namespace std; |
16 | |
17 | namespace x86 = asmjit::x86; |
18 | |
19 | GCONV_INST_DEF_AVX2_HEADER |
20 | GenConvKernel<SPATIAL_DIM, INST_SET>::genConstForPermutations(x86::Emitter* a) { |
21 | if (this->C_per_G_ == 4) { |
22 | x86::Gp permute_const_reg = a->gpz(12); |
23 | x86::Xmm const_reg_xmm = x86::xmm11; |
24 | // We have 1st group in even lanes and 2nd group in odd lanes. |
25 | // Permute to put 1st group to lower 128-bit and 2nd group in upper |
26 | // 128-bit. |
27 | // load 7, 5, 3, 1, 6, 4, 2, 0 in a 64-bit reg |
28 | a->mov(permute_const_reg, static_cast<asmjit::Imm>(0x0705030106040200)); |
29 | a->movq(const_reg_xmm, permute_const_reg); |
30 | // Zero extend 8 packed 8-bit integers in the low 8 bytes of const_reg_xmm |
31 | // to 8 packed 32-bit integers in stPermReg_V_ |
32 | a->vpmovzxbd(stPermReg_V_, const_reg_xmm); |
33 | } else { |
34 | // this->C_per_G_ == 2 |
35 | x86::Gp permute_const_reg = a->gpz(12); |
36 | x86::Xmm const_reg_xmm = x86::xmm11; |
37 | // We have 1st group in position 0 and 4, 2nd group 1 and 5 and so on. |
38 | // Permute to put 1st group to lower 64-bit and 2nd group to next |
39 | // 64-bit and so on. |
40 | // load 7, 3, 6, 2, 5, 1, 4, 0 in a 64-bit reg |
41 | a->mov(permute_const_reg, static_cast<asmjit::Imm>(0x0703060205010400)); |
42 | a->movq(const_reg_xmm, permute_const_reg); |
43 | a->vpmovzxbd(stPermReg_V_, const_reg_xmm); |
44 | } |
45 | } |
46 | |
47 | GCONV_INST_DEF_AVX2_HEADER |
48 | GenConvKernel<SPATIAL_DIM, INST_SET>::genForLoadingWeights(x86::Emitter* a) { |
49 | using WRegs = x86::Ymm; |
50 | int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4; |
51 | // load weights |
52 | for (int r = 0; r < this->R_; ++r) { |
53 | for (int s = 0; s < this->S_; ++s) { |
54 | // For other cases, weights are too big to be kept in registers |
55 | // and are loaded as they are used. |
56 | if (this->C_per_G_ == 4 || this->C_per_G_ == 2) { |
57 | a->vmovaps( |
58 | WRegs(r * this->S_ + s), |
59 | x86::dword_ptr( |
60 | wghts_R_, |
61 | (r * this->S_ + s) * this->K_per_G_ * GTogether_ * |
62 | paddedICPerG * sizeof(int8_t))); |
63 | } |
64 | } |
65 | } |
66 | } |
67 | |
GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeResult(
    x86::Emitter* a) {
  // Emits code that writes the 32-bit accumulators (held in ymm registers)
  // to out_acts_R_, adding to the values already there when this->accum_
  // is set.
  if (GTogether_ > 1) {
    // Several groups were computed together (C_per_G_ == 2 or 4): results
    // are interleaved across lanes, so restore group-contiguous order with
    // the permutation table prepared by genConstForPermutations, then store.
    a->vpermd(x86::Ymm(9), stPermReg_V_, x86::Ymm(9));
    if (this->accum_) {
      a->vpaddd(x86::Ymm(9), x86::Ymm(9), x86::dword_ptr(out_acts_R_));
    }
    a->vmovups(x86::dword_ptr(out_acts_R_), x86::Ymm(9));
  } else {
    // One group at a time: each ymm holds pairwise partial sums that still
    // need a horizontal reduction before the store.
    if (this->C_per_G_ == 8) {
      // One vphaddd merges ymm9/ymm8; vpermq with 0xd8 swaps the middle
      // 64-bit quadrants to restore output-channel order.
      a->vphaddd(x86::Ymm(9), x86::Ymm(9), x86::Ymm(8));
      a->vpermq(x86::Ymm(9), x86::Ymm(9), static_cast<asmjit::Imm>(0xd8));
      if (this->accum_) {
        a->vpaddd(x86::Ymm(9), x86::Ymm(9), x86::dword_ptr(out_acts_R_));
      }
      a->vmovups(x86::dword_ptr(out_acts_R_), x86::Ymm(9));
    } else if (this->K_per_G_ == 16) {
      // NOTE(review): the sibling branch tests C_per_G_ == 8 while this one
      // tests K_per_G_ == 16 — presumably K_per_G_ == C_per_G_ holds for
      // the groupwise kernels; confirm against the class invariants.
      // Partial sums live in ymm2..ymm9; reduce them pairwise in a tree of
      // vphaddd steps, fixing lane order with vpermq(0xd8) after each.
      a->vphaddd(x86::Ymm(9), x86::Ymm(9), x86::Ymm(8));
      a->vpermq(x86::Ymm(9), x86::Ymm(9), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(7), x86::Ymm(7), x86::Ymm(6));
      a->vpermq(x86::Ymm(7), x86::Ymm(7), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(5), x86::Ymm(5), x86::Ymm(4));
      a->vpermq(x86::Ymm(5), x86::Ymm(5), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(3), x86::Ymm(3), x86::Ymm(2));
      a->vpermq(x86::Ymm(3), x86::Ymm(3), static_cast<asmjit::Imm>(0xd8));

      // Second reduction level: ymm9 holds outputs 0..7, ymm5 outputs 8..15.
      a->vphaddd(x86::Ymm(9), x86::Ymm(9), x86::Ymm(7));
      a->vpermq(x86::Ymm(9), x86::Ymm(9), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(5), x86::Ymm(5), x86::Ymm(3));
      a->vpermq(x86::Ymm(5), x86::Ymm(5), static_cast<asmjit::Imm>(0xd8));

      if (this->accum_) {
        a->vpaddd(x86::Ymm(9), x86::Ymm(9), x86::dword_ptr(out_acts_R_));
        a->vpaddd(x86::Ymm(5), x86::Ymm(5), x86::dword_ptr(out_acts_R_, 32));
      }
      // Two 32-byte stores cover all 16 int32 outputs.
      a->vmovups(x86::dword_ptr(out_acts_R_), x86::Ymm(9));
      a->vmovups(x86::dword_ptr(out_acts_R_, 32), x86::Ymm(5));
    }
  }
}
114 | |
GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeOffset(
    x86::Emitter* a) {
  // Emits code that stores the per-group row-offset sums (sums of uint8
  // activations used for zero-point correction) to row_offset_R_,
  // accumulating with the existing values when this->accum_ is set.
  // How many groups share rowOffsetReg_V_ depends on C_per_G_.
  // NOTE(review): the paddd calls below emit legacy-SSE encodings amid
  // otherwise VEX-encoded code — presumably harmless here, but worth
  // checking against AVX/SSE transition-penalty guidance.
  switch (this->C_per_G_) {
    case 2:
      // store 128-bits containing rowoffset for four groups
      if (this->accum_) {
        a->paddd(rowOffsetReg_V_.half(), x86::dword_ptr(row_offset_R_));
      }
      a->vmovdqu(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_.half());
      break;
    case 4:
      // store 64-bits containing rowoffset for two groups
      if (this->accum_) {
        a->vmovq(tmpReg1_V_.half(), x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_.half(), tmpReg1_V_.half());
      }
      a->vmovq(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_.half());
      break;
    case 8:
      // single group: store one 32-bit sum
      if (this->accum_) {
        a->vmovd(tmpReg1_V_.half(), x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_.half(), tmpReg1_V_.half());
      }
      a->vmovd(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_.half());
      break;
    case 16:
      // rowOffsetReg_V_[0:63] has sum for first 8 and
      // rowOffsetReg_V_[64:127] has sum for second 8;
      // execute vphaddd twice to collapse them into one 32-bit total.
      a->vphaddd(rowOffsetReg_V_, rowOffsetReg_V_, rowOffsetReg_V_);
      a->vphaddd(rowOffsetReg_V_, rowOffsetReg_V_, rowOffsetReg_V_);
      if (this->accum_) {
        a->vmovd(tmpReg1_V_.half(), x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_.half(), tmpReg1_V_.half());
      }
      a->vmovd(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_.half());
      break;
    default:
      // C_per_G_ outside {2, 4, 8, 16} is not supported by this kernel.
      assert(0 && "not supported case" );
  }
}
156 | |
GCONV_INST_DEF_AVX2_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleFilterPoint(
    x86::Emitter* a,
    int r,
    int s,
    int act_s,
    bool use_zero_reg) {
  // Emits the inner-product code for one filter point (r, s) against the
  // activation column act_s, accumulating int32 partial sums into the ymm
  // accumulators that storeResult later reduces and stores. When
  // use_zero_reg is set, the activation is replaced with the input zero
  // point (padding handling).
  using WRegs = x86::Ymm;

  if (GTogether_ > 1) {
    // C_per_G_ == 2 or 4: several groups share one ymm of activations.
    if (this->C_per_G_ == 2) { // groups processed together = 4
      if (use_zero_reg) {
        a->vmovapd(actReg_V_, zeroPTReg_V_); // 32 x 8-bit zero points
      } else {
        // 64-bit broadcast: 2 (C_per_G_) * 4 (groups together) uint8
        // activations, replicated to align with the packed weight layout.
        // NOTE(review): the memory operand is declared word_ptr (16-bit)
        // although vbroadcastsd reads 64 bits — presumably asmjit does not
        // enforce the declared size here; confirm.
        a->vbroadcastsd(
            actReg_V_,
            x86::word_ptr(in_acts_R_, (act_s * this->C_) * sizeof(uint8_t)));
      }
      // Zero-extend 8 x 16-bit to 8 x 32-bit activations (C_per_G_ == 2):
      // zero extension is needed because vpmaddubsw and vpmaddwd together
      // sum 4 consecutive elements.
      a->vpmovzxwd(actReg_V_, actReg_V_.half());
    } else if (this->C_per_G_ == 4) { // groups processed together = 2
      if (use_zero_reg) {
        a->vmovapd(actReg_V_, zeroPTReg_V_); // 32 x 8-bit zero points
      } else {
        // 64-bit broadcast: 4 (C_per_G_) * 2 (groups together) uint8
        // activations.
        a->vbroadcastsd(
            actReg_V_,
            x86::dword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t)));
      }
    }
    // Row offset: accumulate the sum of the uint8 activations, needed for
    // the asymmetric-quantization zero-point correction.
    if (this->needRowOffset_) {
      genU8Sum4<INST_SET>(
          a, actReg_V_, rowOffsetReg_V_, oneReg16Bit_V_, tmpReg1_V_);
    }
    // 32 x int8 weights * 32 x uint8 activations -> 8 int32 outputs
    // (K_per_G_ * groups together), accumulated into ymm9.
    genU8I8S32FMA<INST_SET>(
        a,
        actReg_V_,
        WRegs(r * this->S_ + s),
        x86::Ymm(9),
        oneReg16Bit_V_,
        tmpReg1_V_);
  } else {
    if (this->C_per_G_ == 8) {
      if (use_zero_reg) {
        a->vmovapd(actReg_V_, zeroPTReg_V_);
      } else {
        // Broadcast this group's 8 uint8 activations (64 bits).
        a->vbroadcastsd(
            actReg_V_,
            x86::qword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t)));
      }
    } else {
      // this->C_per_G_ == 16: broadcast the full 128-bit activation vector.
      if (use_zero_reg) {
        a->vmovapd(actReg_V_, zeroPTReg_V_);
      } else {
        a->vbroadcasti128(
            actReg_V_,
            x86::oword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t)));
      }
    }
    // Row offset: sum the 8-bit activations for the zero-point correction.
    if (this->needRowOffset_) {
      genU8Sum8(a, actReg_V_, rowOffsetReg_V_, tmpReg1_V_);
    }
    // Each FMA consumes 32 weight bytes, covering 32 / C_per_G_ output
    // channels per iteration.
    int kLoopMultiplier = 32 / this->C_per_G_;
    for (int k = 0; k < kLoopIters_; ++k) {
      // Weights are too large to stay resident in registers for these
      // C_per_G_ values, so load each 32-byte slice just before use
      // (ymm0 serves as scratch).
      a->vmovaps(
          WRegs(0),
          x86::dword_ptr(
              wghts_R_,
              (((r * this->S_ + s) * this->K_per_G_) + k * kLoopMultiplier) *
                  this->C_per_G_ * sizeof(int8_t)));
      // The FMA result is not the final reduction over C_per_G_: it
      // produces 8 partial outputs in which each pair of consecutive
      // elements, once summed, forms one final output along the K_per_G_
      // dimension (storeResult performs that horizontal add).
      genU8I8S32FMA<INST_SET>(
          a, actReg_V_, WRegs(0), x86::Ymm(9 - k), oneReg16Bit_V_, tmpReg1_V_);
    }
  }
}
242 | |
// Explicit instantiations of the AVX2 member-template codegen entry points
// for 1-D, 2-D, and 3-D spatial convolutions, so they are emitted in this
// translation unit. (No comments inside the macro: a // comment on a
// backslash-continued line would swallow the following line after splicing.)
#define GENCONVKERNEL_FUNCS(S, IN)                                         \
  template void GenConvKernel<S, IN>::genForLoadingWeights<IN>(            \
      x86::Emitter * a);                                                   \
  template void GenConvKernel<S, IN>::genConstForPermutations<IN>(         \
      x86::Emitter * a);                                                   \
  template void GenConvKernel<S, IN>::genForSingleFilterPoint<IN>(         \
      x86::Emitter * a, int r, int s, int act_s, bool use_zero_reg);       \
  template void GenConvKernel<S, IN>::storeResult<IN>(x86::Emitter * a);   \
  template void GenConvKernel<S, IN>::storeOffset<IN>(x86::Emitter * a);
GENCONVKERNEL_FUNCS(1, inst_set_t::avx2)
GENCONVKERNEL_FUNCS(2, inst_set_t::avx2)
GENCONVKERNEL_FUNCS(3, inst_set_t::avx2)
#undef GENCONVKERNEL_FUNCS

// Instantiate the kernel classes themselves for each spatial dimension.
template class GenConvKernel<1, inst_set_t::avx2>;
template class GenConvKernel<2, inst_set_t::avx2>;
template class GenConvKernel<3, inst_set_t::avx2>;
260 | |
261 | } // namespace fbgemm |
262 | |