1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#pragma once
8#include <asmjit/asmjit.h>
9#include "fbgemm/SimdUtils.h"
10#include "fbgemm/Utils.h"
11
12namespace fbgemm {
13
14namespace x86 = asmjit::x86;
15
16/**
17 * @brief Create instruction sequence to generate 16-bit 1s
18 * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
19 *
20 * @param dest Once the instruction sequence is executed,
21 * dest[0:15] will have 0x0001, dest[16:31]
22 * will have 0x0001 and so on
23 */
24template <
25 inst_set_t instSet,
26 typename T,
27 typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
28void gen16BitVectorOne(x86::Emitter* a, T dest) {
29 a->vpcmpeqw(dest, dest, dest);
30 a->vpsrlw(dest, dest, 15);
31}
32
33template <
34 inst_set_t instSet,
35 typename T,
36 typename std::enable_if<
37 instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
38 instSet == inst_set_t::avx512_vnni ||
39 instSet == inst_set_t::avx512_vnni_ymm,
40 int>::type = 0>
41void gen16BitVectorOne(x86::Emitter* a, T dest) {
42 a->vpternlogd(dest, dest, dest, 0xff);
43 a->vpsrlw(dest, dest, 15);
44}
45
46/**
 * @brief Emit instruction to load 32-bit integer. AVX512 has
 * different instruction to load registers with index >= 16
49 * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
50 *
51 * @param dest Destination vector register
52 */
53template <
54 inst_set_t instSet,
55 typename T,
56 typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
57void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
58 a->vmovdqa(dest, ptr);
59}
60
61template <
62 inst_set_t instSet,
63 typename T,
64 typename std::enable_if<
65 instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
66 instSet == inst_set_t::avx512_vnni ||
67 instSet == inst_set_t::avx512_vnni_ymm,
68 int>::type = 0>
69void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
70 a->vmovdqa32(dest, ptr);
71}
72
73/**
 * @brief Emit partial extract from Wide register to Half Register, e.g.
75 * Zmm -> Ymm or Ymm -> Xmm
76 * @tparam instSet instruction set to be used
77 *
78 * @param half Destination (half) vector register
79 * @param vec Source (full) vector register
 * @param idx Index of the half vector to extract, 0 or 1
81 */
82template <
83 inst_set_t instSet,
84 typename T,
85 typename std::enable_if<
86 instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
87 instSet == inst_set_t::avx512_vnni ||
88 instSet == inst_set_t::avx512_vnni_ymm,
89 int>::type = 0>
90void emitExtractHalfVector(
91 x86::Emitter* a,
92 x86::Ymm half,
93 const x86::Zmm vec,
94 int idx) {
95 a->vextracti32x8(half, vec, idx);
96}
97
98template <
99 inst_set_t instSet,
100 typename T,
101 typename std::enable_if<
102 instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
103 instSet == inst_set_t::avx512_vnni ||
104 instSet == inst_set_t::avx512_vnni_ymm,
105 int>::type = 0>
106void emitExtractHalfVector(
107 x86::Emitter* a,
108 x86::Xmm half,
109 x86::Ymm vec,
110 int idx) {
111 a->vextracti32x4(half, vec, idx);
112}
113
114template <
115 inst_set_t instSet,
116 typename T,
117 typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
118void emitExtractHalfVector(
119 x86::Emitter* a,
120 x86::Xmm half,
121 x86::Ymm vec,
122 int idx) {
123 a->vextracti128(half, vec, idx);
124}
125
126/**
127 * @brief Create instruction sequence to generate 8-bit 1s
128 * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
129 *
130 * @param dest Once the instruction sequence is executed,
131 * dest[0:7] will have 0x01, dest[8:15]
132 * will have 0x01 and so on
133 */
134template <
135 typename T,
136 typename std::enable_if<std::is_same<T, x86::Ymm>::value, int>::type = 0>
137void gen8BitVectorOne(x86::Emitter* a, T dest) {
138 a->vpcmpeqw(dest, dest, dest);
139 a->vpabsb(dest, dest);
140}
141
142template <
143 typename T,
144 typename std::enable_if<std::is_same<T, x86::Zmm>::value, int>::type = 0>
145void gen8BitVectorOne(x86::Emitter* a, T dest) {
146 a->vpternlogd(dest, dest, dest, 0xff);
147 a->vpabsb(dest, dest);
148}
149
150/**
151 * @brief Generates instruction sequence to compute s32 += U8 * I8
 * @tparam INST_SET Instruction set to generate code for; it selects the
 * vector register type via simd_info<INST_SET>::vec_reg_t
153 *
154 * @param cReg contains result
155 *
156 */
157
158template <
159 inst_set_t INST_SET,
160 typename std::enable_if<
161 INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512,
162 int>::type = 0>
163void genU8I8S32FMA(
164 x86::Emitter* a,
165 typename simd_info<INST_SET>::vec_reg_t aReg,
166 typename simd_info<INST_SET>::vec_reg_t bReg,
167 typename simd_info<INST_SET>::vec_reg_t cReg,
168 typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
169 typename simd_info<INST_SET>::vec_reg_t tmpReg) {
170 a->vpmaddubsw(tmpReg, aReg, bReg);
171 a->vpmaddwd(tmpReg, oneReg16Bit, tmpReg);
172 a->vpaddd(cReg, tmpReg, cReg);
173}
174
175template <
176 inst_set_t INST_SET,
177 typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0>
178void genU8I8S32FMA(
179 x86::Emitter* a,
180 typename simd_info<INST_SET>::vec_reg_t aReg,
181 typename simd_info<INST_SET>::vec_reg_t bReg,
182 typename simd_info<INST_SET>::vec_reg_t cReg,
183 typename simd_info<INST_SET>::vec_reg_t /*oneReg16Bit*/,
184 typename simd_info<INST_SET>::vec_reg_t /*tmpReg*/) {
185 a->vpdpbusd(cReg, aReg, bReg);
186}
187
188/**
189 * @brief Add 4 consecutive numbers of type uint8
190 * and emit their sum as 32-bit numbers.
191 * i.e., dest[0:31] contains
192 * src[0:7] + src[8:15] + src[16:23] + src[24:31]
 * @tparam INST_SET Instruction set to generate code for; it selects the
 * vector register type via simd_info<INST_SET>::vec_reg_t
194 *
195 * @param dest contains result
196 *
197 */
198template <
199 inst_set_t INST_SET,
200 typename std::enable_if<
201 INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512,
202 int>::type = 0>
203void genU8Sum4(
204 x86::Emitter* a,
205 typename simd_info<INST_SET>::vec_reg_t src,
206 typename simd_info<INST_SET>::vec_reg_t dest,
207 typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
208 typename simd_info<INST_SET>::vec_reg_t tmpReg) {
209 gen8BitVectorOne(a, tmpReg);
210 a->vpmaddubsw(tmpReg, src, tmpReg);
211 a->vpmaddwd(tmpReg, tmpReg, oneReg16Bit);
212 a->vpaddd(dest, tmpReg, dest);
213 /*a->vxorps(tmpReg, tmpReg, tmpReg);*/
214 /*a->vmpsadbw(tmpReg, src, tmpReg, static_cast<asmjit::Imm>(0));*/
215 /*a->vpermilps(tmpReg, tmpReg, static_cast<asmjit::Imm>(4));*/
216 /*a->vpmovzxwd(tmpReg, tmpReg.half());*/
217 /*a->vpaddd(dest, tmpReg, dest);*/
218}
219
220template <
221 inst_set_t INST_SET,
222 typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0>
223void genU8Sum4(
224 x86::Emitter* a,
225 typename simd_info<INST_SET>::vec_reg_t src,
226 typename simd_info<INST_SET>::vec_reg_t dest,
227 typename simd_info<INST_SET>::vec_reg_t /*oneReg16Bit*/,
228 typename simd_info<INST_SET>::vec_reg_t tmpReg) {
229 gen8BitVectorOne(a, tmpReg);
230 a->vpdpbusd(dest, src, tmpReg);
231}
232
233/**
234 * @brief Add 8 consecutive numbers of type uint8
235 * and emit their sum as 16-bit numbers.
236 * i.e., dest[0:15] contains
237 * src[0:7] + src[8:15] + src[16:23] + src[24:31]
238 * src[32:39] + src[40:47] + src[48:55] + src[56:63]
239 *
240 * and
241 *
242 * dest[64:79] contains
 * src[64:71] + src[72:79] + src[80:87] + src[88:95]
244 * src[96:103] + src[104:111] + src[112:119] + src[120:127]
245 *
246 * so on
247 *
248 * @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
249 *
250 * @param dest contains result
251 *
252 */
253template <typename T>
254void genU8Sum8(x86::Emitter* a, T src, T dest, T tmpReg) {
255 a->vxorps(tmpReg, tmpReg, tmpReg);
256 a->vpsadbw(tmpReg, src, tmpReg);
257 a->vpaddd(dest, tmpReg, dest);
258}
259
260/**
261 * @brief Broadcast lower 8-bits of src to destination vector
262 * register.
263 */
264template <typename T>
265void broadcast8Bit(x86::Emitter* a, x86::Gp src, T dest) {
266 // move src to dest
267 auto xmm = dest.xmm();
268 a->movq(xmm, src);
269 a->vpbroadcastb(dest, xmm);
270}
271
272} // namespace fbgemm
273