/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include <asmjit/asmjit.h>
#include "./CodeGenHelpers.h"
#include "./GroupwiseConv.h"

namespace fbgemm {

using namespace std;

namespace x86 = asmjit::x86;

GCONV_INST_DEF_AVX512_AND_VNNI_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::genConstForPermutations(x86::Emitter* a) {
  x86::Gp permute_const_reg_upper_half = a->gpz(12);
  x86::Gp permute_const_reg_lower_half = a->gpz(13);
  x86::Xmm const_reg_xmm = x86::xmm11;
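  // Build a 16-byte shuffle mask in two 64-bit GP registers, assemble it in
  // const_reg_xmm, and widen it into stPermReg_V_ for the vpermd used in
  // storeResult().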
  if (this->C_per_G_ == 4) {
    // 4 groups are processed together.
    // The 1st group's outputs are in positions 0, 4, 8, 12, the 2nd group's
    // in 1, 5, 9, 13, and so on. Permute to put the 1st group into the
    // lowest 128 bits, the 2nd group into the next 128 bits, and so on.
    // load f, b, 7, 3,
    //      e, a, 6, 2,
    //      d, 9, 5, 1,
    //      c, 8, 4, 0 in a 128-bit Xmm
    a->mov(
        permute_const_reg_lower_half,
        static_cast<asmjit::Imm>(0x0d0905010c080400));
    a->mov(
        permute_const_reg_upper_half,
        static_cast<asmjit::Imm>(0x0f0b07030e0a0602));
  } else {
    // this->C_per_G_ == 2
    // 8 groups are processed together.
    // The 1st group's outputs are in positions 0 and 8, the 2nd group's in
    // 1 and 9, and so on. Permute to put each group's outputs into adjacent
    // lanes: the 1st group into the lowest 64 bits, the 2nd group into the
    // next 64 bits, and so on.
    // load
    // f, 7
    // e, 6
    // d, 5
    // c, 4
    // b, 3
    // a, 2
    // 9, 1
    // 8, 0 in a 128-bit Xmm
    a->mov(
        permute_const_reg_lower_half,
        static_cast<asmjit::Imm>(0x0b030a0209010800));
    a->mov(
        permute_const_reg_upper_half,
        static_cast<asmjit::Imm>(0x0f070e060d050c04));
  }

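  // Assemble the two 64-bit halves of the shuffle mask into a single 128-bit
  // XMM register.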
  a->movq(const_reg_xmm, permute_const_reg_lower_half);
  a->pinsrq(const_reg_xmm, permute_const_reg_upper_half, 1);
  // Zero-extend the 16 packed 8-bit integers in const_reg_xmm to 16 packed
  // 32-bit integers in stPermReg_V_.
  a->vpmovzxbd(stPermReg_V_, const_reg_xmm);
}

GCONV_INST_DEF_AVX512_AND_VNNI_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::genForLoadingWeights(x86::Emitter* a) {
  using WRegs = x86::Zmm;
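  // C_per_G_ is rounded up to a multiple of 4 because the int8 FMA
  // accumulates over 4 input channels at a time; the packed weight buffer is
  // laid out with this padded stride.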
  int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4;
  // load weights
  for (int r = 0; r < this->R_; ++r) {
    for (int s = 0; s < this->S_; ++s) {
      // When C_per_G_ == 16, weights are too big to be kept in registers
      // and are loaded as they are used.
      if (this->C_per_G_ != 16) {
        // Still use an aligned move since the weight buffer is 64-byte
        // aligned.
        a->vmovaps(
            WRegs(r * this->S_ + s),
            // load 512 bits of weights; the grouping differs per workload
            x86::zmmword_ptr(
                wghts_R_,
                (r * this->S_ + s) * this->K_per_G_ * GTogether_ *
                    paddedICPerG * sizeof(int8_t)));
      }
    }
  }
}

GCONV_INST_DEF_AVX512_AND_VNNI_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::storeResult(x86::Emitter* a) {
  if (GTogether_ > 1) {
    // store with permutation
    a->vpermd(x86::Zmm(9), stPermReg_V_, x86::Zmm(9));
    if (this->accum_) {
      a->vpaddd(x86::Zmm(9), x86::Zmm(9), x86::zmmword_ptr(out_acts_R_));
    }
    a->vmovups(x86::zmmword_ptr(out_acts_R_), x86::Zmm(9));
  } else {
    // horizontal add and store
    if (this->C_per_G_ == 8) {
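      // Zmm9 holds 16 32-bit values, two adjacent partial sums per output
      // channel. Extract the upper 256 bits, horizontally add them with the
      // lower 256 bits, and undo the lane interleave left by vphaddd with
      // vpermq(0xd8) before storing the 8 final outputs.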
      a->vextracti32x8(tmpReg1_V_.ymm(), x86::Zmm(9), 1);
      a->vphaddd(x86::Ymm(9), x86::Ymm(9), tmpReg1_V_.ymm());
      a->vpermq(x86::Ymm(9), x86::Ymm(9), static_cast<asmjit::Imm>(0xd8));
      if (this->accum_) {
        a->vpaddd(x86::Ymm(9), x86::Ymm(9), x86::ymmword_ptr(out_acts_R_));
      }
      a->vmovups(x86::ymmword_ptr(out_acts_R_), x86::Ymm(9));
    } else if (this->K_per_G_ == 16) {
      // We have results in 4 Zmm registers and need to reduce them to 2 Ymm
      // registers (2 * 8 * 32 bits: one 32-bit output for each of the
      // K_per_G = 16 output channels).
      // First reduce 4 * 16 * 32 bits to 4 * 8 * 32 bits.
      for (int k = 0; k < kLoopIters_; ++k) {
        auto source_reg = x86::Zmm(9 - k);
        auto result_reg = x86::Ymm(9 - k);
        a->vextracti32x8(x86::Ymm(0), source_reg, 1);
        a->vphaddd(result_reg, result_reg, x86::Ymm(0));
        a->vpermq(result_reg, result_reg, static_cast<asmjit::Imm>(0xd8));
      }
      // Then reduce 4 * 8 * 32 bits to 2 * 8 * 32 bits.
      for (int k = 0, i = 0; k < kLoopIters_; k += 2, i++) {
        auto result_reg = x86::Ymm(9 - k);
        auto adjacent_result_reg = x86::Ymm(9 - k - 1);
        a->vphaddd(result_reg, result_reg, adjacent_result_reg);
        a->vpermq(result_reg, result_reg, static_cast<asmjit::Imm>(0xd8));
        if (this->accum_) {
          a->vpaddd(
              result_reg, result_reg, x86::ymmword_ptr(out_acts_R_, 32 * i));
        }
        a->vmovups(x86::ymmword_ptr(out_acts_R_, 32 * i), result_reg);
      }
    }
  }
}

GCONV_INST_DEF_AVX512_AND_VNNI_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::storeOffset(x86::Emitter* a) {
  auto rowOffsetReg_V_Ymm = rowOffsetReg_V_.half();
  auto rowOffsetReg_V_Xmm = rowOffsetReg_V_Ymm.half();
  auto tmpReg1_V_Xmm = tmpReg1_V_.half().half();
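  // The number of 32-bit row-offset sums written back depends on how many
  // groups are processed together: eight (C_per_G == 2), four (C_per_G == 4),
  // or one (C_per_G == 8 or 16).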
  switch (this->C_per_G_) {
    case 2:
      // store 256 bits containing the row offsets for eight groups
      if (this->accum_) {
        a->vpaddd(
            rowOffsetReg_V_Ymm,
            rowOffsetReg_V_Ymm,
            x86::ymmword_ptr(row_offset_R_));
      }
      a->vmovdqu(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_Ymm);
      break;
    case 4:
      // store 128 bits containing the row offsets for four groups
      if (this->accum_) {
        a->vmovdqu(tmpReg1_V_Xmm, x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_Xmm, tmpReg1_V_Xmm);
      }
      a->vmovups(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_Xmm);
      break;
    case 8:
      // store the 32-bit row offset of one group
      if (this->accum_) {
        a->vmovd(tmpReg1_V_Xmm, x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_Xmm, tmpReg1_V_Xmm);
      }
      a->vmovd(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_Xmm);
      break;
    case 16:
      // rowOffsetReg_V_[0:63] has the sum for the first 8 channels and
      // rowOffsetReg_V_[64:127] has the sum for the second 8;
      // execute vphaddd twice to sum the two.
      a->vphaddd(rowOffsetReg_V_Ymm, rowOffsetReg_V_Ymm, rowOffsetReg_V_Ymm);
      a->vphaddd(rowOffsetReg_V_Ymm, rowOffsetReg_V_Ymm, rowOffsetReg_V_Ymm);
      if (this->accum_) {
        a->vmovd(tmpReg1_V_Xmm, x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_Xmm, tmpReg1_V_Xmm);
      }
      a->vmovd(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_Xmm);
      break;
    default:
      assert(0 && "not supported case");
  }
}

GCONV_INST_DEF_AVX512_AND_VNNI_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleFilterPoint(
    x86::Emitter* a,
    int r,
    int s,
    int act_s,
    bool use_zero_reg) {
  using WRegs = x86::Zmm;
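
  // Load and broadcast the activations for one input pixel (or fill with the
  // activation zero points for padded pixels), optionally accumulate their
  // sums into the row-offset register, then run the int8 FMA against the
  // weights for filter position (r, s).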
  if (use_zero_reg) {
    a->vmovapd(actReg_V_, zeroPTReg_V_); // 64 8-bit zero points
  } else {
    if (this->C_per_G_ != 8) {
      // 2 (C_per_G) * 8 (groups together) or
      // 4 (C_per_G) * 4 (groups together) or
      // 16 (C_per_G) broadcast into 4 slots of the ZMM
      a->vbroadcasti32x4(
          actReg_V_,
          x86::oword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t)));
    } else {
      // 8 (C_per_G) broadcast into 8 slots of the ZMM
      a->vbroadcasti32x2(
          actReg_V_,
          x86::qword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t)));
    }
  }

  // Zero-extend if C_per_G is smaller than 4 (the accumulation width of our
  // FMA instruction).
  if (this->C_per_G_ == 2) {
    // Only use the lower half and zero-extend each 16-bit element (2 uint8s)
    // to 32 bits.
    a->vpmovzxwd(actReg_V_, actReg_V_.half());
  }

  // row offset
  if (this->needRowOffset_) {
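    // For 2 or 4 channels per group, sum the 4 uint8s in each 32-bit lane of
    // the activation register; for 8 or 16 channels per group, sum 8 uint8s
    // at a time using the Ymm halves.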
    if (this->C_per_G_ == 2 || this->C_per_G_ == 4) {
      genU8Sum4<INST_SET>(
          a, actReg_V_, rowOffsetReg_V_, oneReg16Bit_V_, tmpReg1_V_);
    } else {
      // still use Ymm for Sum8
      genU8Sum8(
          a, actReg_V_.half(), rowOffsetReg_V_.half(), tmpReg1_V_.half());
    }
  }
  // FMA
  if (this->C_per_G_ != 16) {
    genU8I8S32FMA<INST_SET>(
        a,
        actReg_V_,
        WRegs(r * this->S_ + s),
        WRegs(9),
        oneReg16Bit_V_,
        tmpReg1_V_);
  } else {
    // 64 == simd_info<inst_set_t::avx512>::WIDTH_BYTES
    int kLoopMultiplier = 64 / this->C_per_G_;
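    // kLoopMultiplier == 4: each 512-bit weight load covers 4 output
    // channels * 16 input channels, so kLoopIters_ iterations cover all
    // K_per_G == 16 output channels.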
    for (int k = 0; k < kLoopIters_; ++k) {
      a->vmovaps(
          WRegs(0),
          // copy 512 bits of weights into the ZMM:
          // 16 (C_per_G) * 4 (a quarter of K_per_G)
          x86::zmmword_ptr(
              wghts_R_,
              (((r * this->S_ + s) * this->K_per_G_) + k * kLoopMultiplier) *
                  this->C_per_G_ * sizeof(int8_t)));
      // The FMA result is not the final reduction over C_per_G: it produces
      // 4 * 16 outputs in which each run of 4 consecutive elements, once
      // summed, forms one final output over the K_per_G dimension; we need
      // 16 final 32-bit outputs.
      genU8I8S32FMA<INST_SET>(
          a, actReg_V_, WRegs(0), WRegs(9 - k), oneReg16Bit_V_, tmpReg1_V_);
    }
  }
}

#define GENCONVKERNEL_FUNCS(S, IN)                                       \
  template void GenConvKernel<S, IN>::genForLoadingWeights<IN>(          \
      x86::Emitter * a);                                                 \
  template void GenConvKernel<S, IN>::genConstForPermutations<IN>(       \
      x86::Emitter * a);                                                 \
  template void GenConvKernel<S, IN>::genForSingleFilterPoint<IN>(       \
      x86::Emitter * a, int r, int s, int act_s, bool use_zero_reg);     \
  template void GenConvKernel<S, IN>::storeResult<IN>(x86::Emitter * a); \
  template void GenConvKernel<S, IN>::storeOffset<IN>(x86::Emitter * a);

GENCONVKERNEL_FUNCS(1, inst_set_t::avx512)
GENCONVKERNEL_FUNCS(1, inst_set_t::avx512_vnni)
GENCONVKERNEL_FUNCS(2, inst_set_t::avx512)
GENCONVKERNEL_FUNCS(2, inst_set_t::avx512_vnni)
GENCONVKERNEL_FUNCS(3, inst_set_t::avx512)
GENCONVKERNEL_FUNCS(3, inst_set_t::avx512_vnni)
#undef GENCONVKERNEL_FUNCS

template class GenConvKernel<1, inst_set_t::avx512>;
template class GenConvKernel<1, inst_set_t::avx512_vnni>;
template class GenConvKernel<2, inst_set_t::avx512>;
template class GenConvKernel<2, inst_set_t::avx512_vnni>;
template class GenConvKernel<3, inst_set_t::avx512>;
template class GenConvKernel<3, inst_set_t::avx512_vnni>;

} // namespace fbgemm