1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | #include <asmjit/asmjit.h> |
9 | #include "fbgemm/SimdUtils.h" |
10 | #include "fbgemm/Utils.h" |
11 | |
12 | namespace fbgemm { |
13 | |
14 | namespace x86 = asmjit::x86; |
15 | |
16 | /** |
17 | * @brief Create instruction sequence to generate 16-bit 1s |
18 | * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm |
19 | * |
20 | * @param dest Once the instruction sequence is executed, |
21 | * dest[0:15] will have 0x0001, dest[16:31] |
22 | * will have 0x0001 and so on |
23 | */ |
24 | template < |
25 | inst_set_t instSet, |
26 | typename T, |
27 | typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0> |
28 | void gen16BitVectorOne(x86::Emitter* a, T dest) { |
29 | a->vpcmpeqw(dest, dest, dest); |
30 | a->vpsrlw(dest, dest, 15); |
31 | } |
32 | |
33 | template < |
34 | inst_set_t instSet, |
35 | typename T, |
36 | typename std::enable_if< |
37 | instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm || |
38 | instSet == inst_set_t::avx512_vnni || |
39 | instSet == inst_set_t::avx512_vnni_ymm, |
40 | int>::type = 0> |
41 | void gen16BitVectorOne(x86::Emitter* a, T dest) { |
42 | a->vpternlogd(dest, dest, dest, 0xff); |
43 | a->vpsrlw(dest, dest, 15); |
44 | } |
45 | |
46 | /** |
47 | * @brief Emit instruction do load 32-bit integer. AVX512 has |
48 | * different instrunction to load registers with index >= 16 |
49 | * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm |
50 | * |
51 | * @param dest Destination vector register |
52 | */ |
53 | template < |
54 | inst_set_t instSet, |
55 | typename T, |
56 | typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0> |
57 | void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) { |
58 | a->vmovdqa(dest, ptr); |
59 | } |
60 | |
61 | template < |
62 | inst_set_t instSet, |
63 | typename T, |
64 | typename std::enable_if< |
65 | instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm || |
66 | instSet == inst_set_t::avx512_vnni || |
67 | instSet == inst_set_t::avx512_vnni_ymm, |
68 | int>::type = 0> |
69 | void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) { |
70 | a->vmovdqa32(dest, ptr); |
71 | } |
72 | |
73 | /** |
74 | * @brief Emit partial extract from Wide regiter to Half Register, eg. |
75 | * Zmm -> Ymm or Ymm -> Xmm |
76 | * @tparam instSet instruction set to be used |
77 | * |
78 | * @param half Destination (half) vector register |
79 | * @param vec Source (full) vector register |
80 | * @param idx Index of of the half vector 0 or 1 |
81 | */ |
82 | template < |
83 | inst_set_t instSet, |
84 | typename T, |
85 | typename std::enable_if< |
86 | instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm || |
87 | instSet == inst_set_t::avx512_vnni || |
88 | instSet == inst_set_t::avx512_vnni_ymm, |
89 | int>::type = 0> |
90 | void ( |
91 | x86::Emitter* a, |
92 | x86::Ymm half, |
93 | const x86::Zmm vec, |
94 | int idx) { |
95 | a->vextracti32x8(half, vec, idx); |
96 | } |
97 | |
98 | template < |
99 | inst_set_t instSet, |
100 | typename T, |
101 | typename std::enable_if< |
102 | instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm || |
103 | instSet == inst_set_t::avx512_vnni || |
104 | instSet == inst_set_t::avx512_vnni_ymm, |
105 | int>::type = 0> |
106 | void ( |
107 | x86::Emitter* a, |
108 | x86::Xmm half, |
109 | x86::Ymm vec, |
110 | int idx) { |
111 | a->vextracti32x4(half, vec, idx); |
112 | } |
113 | |
114 | template < |
115 | inst_set_t instSet, |
116 | typename T, |
117 | typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0> |
118 | void ( |
119 | x86::Emitter* a, |
120 | x86::Xmm half, |
121 | x86::Ymm vec, |
122 | int idx) { |
123 | a->vextracti128(half, vec, idx); |
124 | } |
125 | |
126 | /** |
127 | * @brief Create instruction sequence to generate 8-bit 1s |
128 | * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm |
129 | * |
130 | * @param dest Once the instruction sequence is executed, |
131 | * dest[0:7] will have 0x01, dest[8:15] |
132 | * will have 0x01 and so on |
133 | */ |
134 | template < |
135 | typename T, |
136 | typename std::enable_if<std::is_same<T, x86::Ymm>::value, int>::type = 0> |
137 | void gen8BitVectorOne(x86::Emitter* a, T dest) { |
138 | a->vpcmpeqw(dest, dest, dest); |
139 | a->vpabsb(dest, dest); |
140 | } |
141 | |
142 | template < |
143 | typename T, |
144 | typename std::enable_if<std::is_same<T, x86::Zmm>::value, int>::type = 0> |
145 | void gen8BitVectorOne(x86::Emitter* a, T dest) { |
146 | a->vpternlogd(dest, dest, dest, 0xff); |
147 | a->vpabsb(dest, dest); |
148 | } |
149 | |
150 | /** |
151 | * @brief Generates instruction sequence to compute s32 += U8 * I8 |
152 | * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm |
153 | * |
154 | * @param cReg contains result |
155 | * |
156 | */ |
157 | |
158 | template < |
159 | inst_set_t INST_SET, |
160 | typename std::enable_if< |
161 | INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512, |
162 | int>::type = 0> |
163 | void genU8I8S32FMA( |
164 | x86::Emitter* a, |
165 | typename simd_info<INST_SET>::vec_reg_t aReg, |
166 | typename simd_info<INST_SET>::vec_reg_t bReg, |
167 | typename simd_info<INST_SET>::vec_reg_t cReg, |
168 | typename simd_info<INST_SET>::vec_reg_t oneReg16Bit, |
169 | typename simd_info<INST_SET>::vec_reg_t tmpReg) { |
170 | a->vpmaddubsw(tmpReg, aReg, bReg); |
171 | a->vpmaddwd(tmpReg, oneReg16Bit, tmpReg); |
172 | a->vpaddd(cReg, tmpReg, cReg); |
173 | } |
174 | |
175 | template < |
176 | inst_set_t INST_SET, |
177 | typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0> |
178 | void genU8I8S32FMA( |
179 | x86::Emitter* a, |
180 | typename simd_info<INST_SET>::vec_reg_t aReg, |
181 | typename simd_info<INST_SET>::vec_reg_t bReg, |
182 | typename simd_info<INST_SET>::vec_reg_t cReg, |
183 | typename simd_info<INST_SET>::vec_reg_t /*oneReg16Bit*/, |
184 | typename simd_info<INST_SET>::vec_reg_t /*tmpReg*/) { |
185 | a->vpdpbusd(cReg, aReg, bReg); |
186 | } |
187 | |
188 | /** |
189 | * @brief Add 4 consecutive numbers of type uint8 |
190 | * and emit their sum as 32-bit numbers. |
191 | * i.e., dest[0:31] contains |
192 | * src[0:7] + src[8:15] + src[16:23] + src[24:31] |
193 | * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm |
194 | * |
195 | * @param dest contains result |
196 | * |
197 | */ |
198 | template < |
199 | inst_set_t INST_SET, |
200 | typename std::enable_if< |
201 | INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512, |
202 | int>::type = 0> |
203 | void genU8Sum4( |
204 | x86::Emitter* a, |
205 | typename simd_info<INST_SET>::vec_reg_t src, |
206 | typename simd_info<INST_SET>::vec_reg_t dest, |
207 | typename simd_info<INST_SET>::vec_reg_t oneReg16Bit, |
208 | typename simd_info<INST_SET>::vec_reg_t tmpReg) { |
209 | gen8BitVectorOne(a, tmpReg); |
210 | a->vpmaddubsw(tmpReg, src, tmpReg); |
211 | a->vpmaddwd(tmpReg, tmpReg, oneReg16Bit); |
212 | a->vpaddd(dest, tmpReg, dest); |
213 | /*a->vxorps(tmpReg, tmpReg, tmpReg);*/ |
214 | /*a->vmpsadbw(tmpReg, src, tmpReg, static_cast<asmjit::Imm>(0));*/ |
215 | /*a->vpermilps(tmpReg, tmpReg, static_cast<asmjit::Imm>(4));*/ |
216 | /*a->vpmovzxwd(tmpReg, tmpReg.half());*/ |
217 | /*a->vpaddd(dest, tmpReg, dest);*/ |
218 | } |
219 | |
220 | template < |
221 | inst_set_t INST_SET, |
222 | typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0> |
223 | void genU8Sum4( |
224 | x86::Emitter* a, |
225 | typename simd_info<INST_SET>::vec_reg_t src, |
226 | typename simd_info<INST_SET>::vec_reg_t dest, |
227 | typename simd_info<INST_SET>::vec_reg_t /*oneReg16Bit*/, |
228 | typename simd_info<INST_SET>::vec_reg_t tmpReg) { |
229 | gen8BitVectorOne(a, tmpReg); |
230 | a->vpdpbusd(dest, src, tmpReg); |
231 | } |
232 | |
233 | /** |
234 | * @brief Add 8 consecutive numbers of type uint8 |
235 | * and emit their sum as 16-bit numbers. |
236 | * i.e., dest[0:15] contains |
237 | * src[0:7] + src[8:15] + src[16:23] + src[24:31] |
238 | * src[32:39] + src[40:47] + src[48:55] + src[56:63] |
239 | * |
240 | * and |
241 | * |
242 | * dest[64:79] contains |
243 | * src[64:71] + src[71:79] + src[80:87] + src[88:95] |
244 | * src[96:103] + src[104:111] + src[112:119] + src[120:127] |
245 | * |
246 | * so on |
247 | * |
248 | * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm |
249 | * |
250 | * @param dest contains result |
251 | * |
252 | */ |
253 | template <typename T> |
254 | void genU8Sum8(x86::Emitter* a, T src, T dest, T tmpReg) { |
255 | a->vxorps(tmpReg, tmpReg, tmpReg); |
256 | a->vpsadbw(tmpReg, src, tmpReg); |
257 | a->vpaddd(dest, tmpReg, dest); |
258 | } |
259 | |
260 | /** |
261 | * @brief Broadcast lower 8-bits of src to destination vector |
262 | * register. |
263 | */ |
264 | template <typename T> |
265 | void broadcast8Bit(x86::Emitter* a, x86::Gp src, T dest) { |
266 | // move src to dest |
267 | auto xmm = dest.xmm(); |
268 | a->movq(xmm, src); |
269 | a->vpbroadcastb(dest, xmm); |
270 | } |
271 | |
272 | } // namespace fbgemm |
273 | |