/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include <asmjit/asmjit.h>
#include "./CodeGenHelpers.h"
#include "./GroupwiseConv.h"
#include "fbgemm/Fbgemm.h"

namespace fbgemm {

using namespace std;

namespace x86 = asmjit::x86;

GCONV_INST_DEF_AVX2_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::genConstForPermutations(x86::Emitter* a) {
  if (this->C_per_G_ == 4) {
    x86::Gp permute_const_reg = a->gpz(12);
    x86::Xmm const_reg_xmm = x86::xmm11;
    // We have the 1st group in the even lanes and the 2nd group in the odd
    // lanes. Permute to put the 1st group in the lower 128 bits and the 2nd
    // group in the upper 128 bits.
    // load 7, 5, 3, 1, 6, 4, 2, 0 in a 64-bit reg
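    // After vpmovzxbd below, stPermReg_V_ holds the dword indices
    // {0, 2, 4, 6, 1, 3, 5, 7} (low to high), so vpermd in storeResult()
    // gathers the even source lanes (1st group) into the lower 128 bits and
    // the odd lanes (2nd group) into the upper 128 bits.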
    a->mov(permute_const_reg, static_cast<asmjit::Imm>(0x0705030106040200));
    a->movq(const_reg_xmm, permute_const_reg);
    // Zero extend 8 packed 8-bit integers in the low 8 bytes of const_reg_xmm
    // to 8 packed 32-bit integers in stPermReg_V_
    a->vpmovzxbd(stPermReg_V_, const_reg_xmm);
  } else {
    // this->C_per_G_ == 2
    x86::Gp permute_const_reg = a->gpz(12);
    x86::Xmm const_reg_xmm = x86::xmm11;
    // We have the 1st group in positions 0 and 4, the 2nd group in positions
    // 1 and 5, and so on. Permute to put the 1st group in the lowest 64 bits,
    // the 2nd group in the next 64 bits, and so on.
    // load 7, 3, 6, 2, 5, 1, 4, 0 in a 64-bit reg
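    // After vpmovzxbd below, stPermReg_V_ holds the dword indices
    // {0, 4, 1, 5, 2, 6, 3, 7} (low to high), so vpermd in storeResult()
    // pairs each group's two positions, making every group's outputs
    // contiguous in the result.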
    a->mov(permute_const_reg, static_cast<asmjit::Imm>(0x0703060205010400));
    a->movq(const_reg_xmm, permute_const_reg);
    a->vpmovzxbd(stPermReg_V_, const_reg_xmm);
  }
}

GCONV_INST_DEF_AVX2_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::genForLoadingWeights(x86::Emitter* a) {
  using WRegs = x86::Ymm;
  int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4;
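  // Per-group input channels are rounded up to a multiple of 4 because the
  // packed weight layout (and the vpmaddubsw/vpmaddwd FMA used later) consumes
  // 4 consecutive int8 values per 32-bit partial sum.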
  // load weights
  for (int r = 0; r < this->R_; ++r) {
    for (int s = 0; s < this->S_; ++s) {
      // Weights are kept in registers only when C_per_G_ is 2 or 4; in the
      // other cases they are too big to be kept in registers and are loaded
      // as they are used.
      if (this->C_per_G_ == 4 || this->C_per_G_ == 2) {
        a->vmovaps(
            WRegs(r * this->S_ + s),
            x86::dword_ptr(
                wghts_R_,
                (r * this->S_ + s) * this->K_per_G_ * GTogether_ *
                    paddedICPerG * sizeof(int8_t)));
      }
    }
  }
}

GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeResult(
    x86::Emitter* a) {
  if (GTogether_ > 1) {
    // store with permutation
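    // vpermd with the index vector prepared in genConstForPermutations() puts
    // each group's outputs contiguously in Ymm(9) before the store.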
    a->vpermd(x86::Ymm(9), stPermReg_V_, x86::Ymm(9));
    if (this->accum_) {
      a->vpaddd(x86::Ymm(9), x86::Ymm(9), x86::dword_ptr(out_acts_R_));
    }
    a->vmovups(x86::dword_ptr(out_acts_R_), x86::Ymm(9));
  } else {
    // horizontal add and store
    if (this->C_per_G_ == 8) {
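      // vphaddd adds adjacent dword pairs within each 128-bit lane, so its
      // result interleaves the two sources per lane; vpermq with 0xd8
      // (qword order 0, 2, 1, 3) restores contiguous output order.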
      a->vphaddd(x86::Ymm(9), x86::Ymm(9), x86::Ymm(8));
      a->vpermq(x86::Ymm(9), x86::Ymm(9), static_cast<asmjit::Imm>(0xd8));
      if (this->accum_) {
        a->vpaddd(x86::Ymm(9), x86::Ymm(9), x86::dword_ptr(out_acts_R_));
      }
      a->vmovups(x86::dword_ptr(out_acts_R_), x86::Ymm(9));
    } else if (this->K_per_G_ == 16) {
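      // C_per_G_ == 16 here: each accumulator Ymm(2)..Ymm(9) holds partial
      // sums for two output channels (4 dwords per channel). Two rounds of
      // vphaddd, each followed by vpermq(0xd8) to undo the per-lane
      // interleaving, reduce them to Ymm(9) = outputs 0..7 and
      // Ymm(5) = outputs 8..15.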
      a->vphaddd(x86::Ymm(9), x86::Ymm(9), x86::Ymm(8));
      a->vpermq(x86::Ymm(9), x86::Ymm(9), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(7), x86::Ymm(7), x86::Ymm(6));
      a->vpermq(x86::Ymm(7), x86::Ymm(7), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(5), x86::Ymm(5), x86::Ymm(4));
      a->vpermq(x86::Ymm(5), x86::Ymm(5), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(3), x86::Ymm(3), x86::Ymm(2));
      a->vpermq(x86::Ymm(3), x86::Ymm(3), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(9), x86::Ymm(9), x86::Ymm(7));
      a->vpermq(x86::Ymm(9), x86::Ymm(9), static_cast<asmjit::Imm>(0xd8));

      a->vphaddd(x86::Ymm(5), x86::Ymm(5), x86::Ymm(3));
      a->vpermq(x86::Ymm(5), x86::Ymm(5), static_cast<asmjit::Imm>(0xd8));

      if (this->accum_) {
        a->vpaddd(x86::Ymm(9), x86::Ymm(9), x86::dword_ptr(out_acts_R_));
        a->vpaddd(x86::Ymm(5), x86::Ymm(5), x86::dword_ptr(out_acts_R_, 32));
      }
      a->vmovups(x86::dword_ptr(out_acts_R_), x86::Ymm(9));
      a->vmovups(x86::dword_ptr(out_acts_R_, 32), x86::Ymm(5));
    }
  }
}

GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeOffset(
    x86::Emitter* a) {
  switch (this->C_per_G_) {
    case 2:
      // store 128 bits containing the row offsets for four groups
      if (this->accum_) {
        a->paddd(rowOffsetReg_V_.half(), x86::dword_ptr(row_offset_R_));
      }
      a->vmovdqu(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_.half());
      break;
    case 4:
      // store 64 bits containing the row offsets for two groups
      if (this->accum_) {
        a->vmovq(tmpReg1_V_.half(), x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_.half(), tmpReg1_V_.half());
      }
      a->vmovq(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_.half());
      break;
    case 8:
      if (this->accum_) {
        a->vmovd(tmpReg1_V_.half(), x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_.half(), tmpReg1_V_.half());
      }
      a->vmovd(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_.half());
      break;
    case 16:
      // rowOffsetReg_V_[0:63] has the sum for the first 8 elements and
      // rowOffsetReg_V_[64:127] has the sum for the second 8;
      // execute vphaddd twice to add the two together
      a->vphaddd(rowOffsetReg_V_, rowOffsetReg_V_, rowOffsetReg_V_);
      a->vphaddd(rowOffsetReg_V_, rowOffsetReg_V_, rowOffsetReg_V_);
      if (this->accum_) {
        a->vmovd(tmpReg1_V_.half(), x86::dword_ptr(row_offset_R_));
        a->paddd(rowOffsetReg_V_.half(), tmpReg1_V_.half());
      }
      a->vmovd(x86::dword_ptr(row_offset_R_), rowOffsetReg_V_.half());
      break;
    default:
      assert(0 && "not supported case");
  }
}

GCONV_INST_DEF_AVX2_HEADER
GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleFilterPoint(
    x86::Emitter* a,
    int r,
    int s,
    int act_s,
    bool use_zero_reg) {
  using WRegs = x86::Ymm;

  if (GTogether_ > 1) {
    if (this->C_per_G_ == 2) { // groups together = 4
      if (use_zero_reg) {
        a->vmovapd(actReg_V_, zeroPTReg_V_); // 32 * 8-bit zero points
      } else {
        // broadcast 64 bits = 2 (C_per_G) * 4 (groups together) activations,
        // replicated to align with the weight layout
        a->vbroadcastsd(
            actReg_V_,
            x86::word_ptr(in_acts_R_, (act_s * this->C_) * sizeof(uint8_t)));
      }
      // zero extend 8 * 16-bit activations to 8 * 32-bit (C_per_G = 2),
      // because vpmaddubsw and vpmaddwd together sum 4 consecutive elements
      a->vpmovzxwd(actReg_V_, actReg_V_.half());
    } else if (this->C_per_G_ == 4) { // groups together = 2
      if (use_zero_reg) {
        a->vmovapd(actReg_V_, zeroPTReg_V_); // 32 * 8-bit zero points
      } else {
        a->vbroadcastsd(
            actReg_V_,
            x86::dword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t)));
      }
    }
    // row offset
    if (this->needRowOffset_) {
      genU8Sum4<INST_SET>(
          a, actReg_V_, rowOffsetReg_V_, oneReg16Bit_V_, tmpReg1_V_);
    }
    // 32 int8 weights * 32 uint8 activations -> 8 outputs
    // (K_per_G * groups together)
    genU8I8S32FMA<INST_SET>(
        a,
        actReg_V_,
        WRegs(r * this->S_ + s),
        x86::Ymm(9),
        oneReg16Bit_V_,
        tmpReg1_V_);
  } else {
    if (this->C_per_G_ == 8) {
      if (use_zero_reg) {
        a->vmovapd(actReg_V_, zeroPTReg_V_);
      } else {
        a->vbroadcastsd(
            actReg_V_,
            x86::qword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t)));
      }
    } else {
      // this->C_per_G_ == 16
      if (use_zero_reg) {
        a->vmovapd(actReg_V_, zeroPTReg_V_);
      } else {
        a->vbroadcasti128(
            actReg_V_,
            x86::oword_ptr(in_acts_R_, act_s * this->C_ * sizeof(uint8_t)));
      }
    }
    // row offset
    if (this->needRowOffset_) {
      genU8Sum8(a, actReg_V_, rowOffsetReg_V_, tmpReg1_V_);
    }
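    // kLoopMultiplier: number of output channels covered by one 256-bit
    // weight load (32 int8 weights per load / C_per_G_ weights per channel)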
    int kLoopMultiplier = 32 / this->C_per_G_;
    for (int k = 0; k < kLoopIters_; ++k) {
      a->vmovaps(
          WRegs(0),
          x86::dword_ptr(
              wghts_R_,
              (((r * this->S_ + s) * this->K_per_G_) + k * kLoopMultiplier) *
                  this->C_per_G_ * sizeof(int8_t)));
      // The FMA result is not the final reduction over C_per_G: it produces
      // 8 outputs in which every 2 consecutive elements, when summed, form
      // one final output along the K_per_G dimension.
      genU8I8S32FMA<INST_SET>(
          a, actReg_V_, WRegs(0), x86::Ymm(9 - k), oneReg16Bit_V_, tmpReg1_V_);
    }
  }
}

#define GENCONVKERNEL_FUNCS(S, IN)                                        \
  template void GenConvKernel<S, IN>::genForLoadingWeights<IN>(           \
      x86::Emitter * a);                                                   \
  template void GenConvKernel<S, IN>::genConstForPermutations<IN>(        \
      x86::Emitter * a);                                                   \
  template void GenConvKernel<S, IN>::genForSingleFilterPoint<IN>(        \
      x86::Emitter * a, int r, int s, int act_s, bool use_zero_reg);      \
  template void GenConvKernel<S, IN>::storeResult<IN>(x86::Emitter * a);  \
  template void GenConvKernel<S, IN>::storeOffset<IN>(x86::Emitter * a);
GENCONVKERNEL_FUNCS(1, inst_set_t::avx2)
GENCONVKERNEL_FUNCS(2, inst_set_t::avx2)
GENCONVKERNEL_FUNCS(3, inst_set_t::avx2)
#undef GENCONVKERNEL_FUNCS

template class GenConvKernel<1, inst_set_t::avx2>;
template class GenConvKernel<2, inst_set_t::avx2>;
template class GenConvKernel<3, inst_set_t::avx2>;

} // namespace fbgemm