1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8#define FBGEMM_EXPORTS
9#include "./OptimizedKernelsAvx2.h"
10#if defined(__x86_64__) || defined(__i386__) || \
11 (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
12#include <immintrin.h>
13#endif
14
15namespace fbgemm {
16
17int32_t reduceAvx2(const uint8_t* A, int len) {
18 int32_t row_sum = 0;
19#if defined(__AVX2__)
20 __m256i sum_v = _mm256_setzero_si256();
21 __m256i one_epi16_v = _mm256_set1_epi16(1);
22 __m256i one_epi8_v = _mm256_set1_epi8(1);
23
24 int i;
25 // vectorized
26 for (i = 0; i < len / 32 * 32; i += 32) {
27 __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
28 sum_v = _mm256_add_epi32(
29 sum_v,
30 _mm256_madd_epi16(
31 _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v));
32 }
33
34 alignas(64) int32_t temp[8];
35 _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
36 for (int k = 0; k < 8; ++k) {
37 row_sum += temp[k];
38 }
39
40 // scalar
41 for (; i < len; ++i) {
42 row_sum += A[i];
43 }
44
45#else
46 for (int i = 0; i < len; ++i) {
47 row_sum += A[i];
48 }
49#endif
50 return row_sum;
51}
52
53void transpose_8rows(
54 int N,
55 const uint8_t* src,
56 int ld_src,
57 uint8_t* dst,
58 int ld_dst) {
59 constexpr int M = 8;
60 int j;
61 // vectorized loop
62 for (j = 0; j < N / 32 * 32; j += 32) {
63 // a : a0 a1 ... a31
64 // b : b0 b1 ... b31
65 // c : c0 c1 ... c31
66 // d : d0 d1 ... d31
67 __m256i a = _mm256_lddqu_si256(
68 reinterpret_cast<const __m256i*>(src + j + 0 * ld_src));
69 __m256i b = _mm256_lddqu_si256(
70 reinterpret_cast<const __m256i*>(src + j + 1 * ld_src));
71 __m256i c = _mm256_lddqu_si256(
72 reinterpret_cast<const __m256i*>(src + j + 2 * ld_src));
73 __m256i d = _mm256_lddqu_si256(
74 reinterpret_cast<const __m256i*>(src + j + 3 * ld_src));
75 __m256i e = _mm256_lddqu_si256(
76 reinterpret_cast<const __m256i*>(src + j + 4 * ld_src));
77 __m256i f = _mm256_lddqu_si256(
78 reinterpret_cast<const __m256i*>(src + j + 5 * ld_src));
79 __m256i g = _mm256_lddqu_si256(
80 reinterpret_cast<const __m256i*>(src + j + 6 * ld_src));
81 __m256i h = _mm256_lddqu_si256(
82 reinterpret_cast<const __m256i*>(src + j + 7 * ld_src));
83
84 // even-odd interleaving
85 // ab_lo : a0 b0 a1 b1 ... a7 b7 | a16 b16 ... a23 b23
86 // ab_hi : a8 b8 a9 b9 ... a15 b15 | a24 b24 ... a31 b31
87 // cd_lo : c0 d0 c1 d1 ... c7 d7 | c16 d16 ... c23 d23
88 // cd_hi : c8 d8 c9 d9 ... c15 d15 | c24 d24 ... c31 d31
89 __m256i ab_lo = _mm256_unpacklo_epi8(a, b);
90 __m256i ab_hi = _mm256_unpackhi_epi8(a, b);
91 __m256i cd_lo = _mm256_unpacklo_epi8(c, d);
92 __m256i cd_hi = _mm256_unpackhi_epi8(c, d);
93 __m256i ef_lo = _mm256_unpacklo_epi8(e, f);
94 __m256i ef_hi = _mm256_unpackhi_epi8(e, f);
95 __m256i gh_lo = _mm256_unpacklo_epi8(g, h);
96 __m256i gh_hi = _mm256_unpackhi_epi8(g, h);
97
98 // 4-row interleaving but permuted at 128-bit granularity
99 // abcd0 : a0 b0 c0 d0 ... a-d3 | a-d16 ... a-d19
100 // abcd1 : a4 b4 c4 d4 ... a-d7 | a-d20 ... a-d23
101 // abcd2 : a8 b8 c8 d8 ... a-d11 | a-d24 ... a-d27
102 // abcd3 : a12 b12 c12 d12 ... a-d15 | a-d28 ... a-d31
103 __m256i abcd0 = _mm256_unpacklo_epi16(ab_lo, cd_lo);
104 __m256i abcd1 = _mm256_unpackhi_epi16(ab_lo, cd_lo);
105 __m256i abcd2 = _mm256_unpacklo_epi16(ab_hi, cd_hi);
106 __m256i abcd3 = _mm256_unpackhi_epi16(ab_hi, cd_hi);
107 __m256i efgh0 = _mm256_unpacklo_epi16(ef_lo, gh_lo);
108 __m256i efgh1 = _mm256_unpackhi_epi16(ef_lo, gh_lo);
109 __m256i efgh2 = _mm256_unpacklo_epi16(ef_hi, gh_hi);
110 __m256i efgh3 = _mm256_unpackhi_epi16(ef_hi, gh_hi);
111
112 // 8-row interleaving
113 __m256i y0 = _mm256_unpacklo_epi32(abcd0, efgh0);
114 __m256i y1 = _mm256_unpackhi_epi32(abcd0, efgh0);
115 __m256i y2 = _mm256_unpacklo_epi32(abcd1, efgh1);
116 __m256i y3 = _mm256_unpackhi_epi32(abcd1, efgh1);
117 __m256i y4 = _mm256_unpacklo_epi32(abcd2, efgh2);
118 __m256i y5 = _mm256_unpackhi_epi32(abcd2, efgh2);
119 __m256i y6 = _mm256_unpacklo_epi32(abcd3, efgh3);
120 __m256i y7 = _mm256_unpackhi_epi32(abcd3, efgh3);
121
122 // Storing with 128-bit lanes are permuted so that everything is in order
123 _mm_storel_epi64(
124 reinterpret_cast<__m128i*>(dst + (j + 0) * ld_dst),
125 _mm256_castsi256_si128(y0));
126 *reinterpret_cast<int64_t*>(dst + (j + 1) * ld_dst) =
127 _mm256_extract_epi64(y0, 1);
128 _mm_storel_epi64(
129 reinterpret_cast<__m128i*>(dst + (j + 2) * ld_dst),
130 _mm256_castsi256_si128(y1));
131 *reinterpret_cast<int64_t*>(dst + (j + 3) * ld_dst) =
132 _mm256_extract_epi64(y1, 1);
133 _mm_storel_epi64(
134 reinterpret_cast<__m128i*>(dst + (j + 4) * ld_dst),
135 _mm256_castsi256_si128(y2));
136 *reinterpret_cast<int64_t*>(dst + (j + 5) * ld_dst) =
137 _mm256_extract_epi64(y2, 1);
138 _mm_storel_epi64(
139 reinterpret_cast<__m128i*>(dst + (j + 6) * ld_dst),
140 _mm256_castsi256_si128(y3));
141 *reinterpret_cast<int64_t*>(dst + (j + 7) * ld_dst) =
142 _mm256_extract_epi64(y3, 1);
143 _mm_storel_epi64(
144 reinterpret_cast<__m128i*>(dst + (j + 8) * ld_dst),
145 _mm256_castsi256_si128(y4));
146 *reinterpret_cast<int64_t*>(dst + (j + 9) * ld_dst) =
147 _mm256_extract_epi64(y4, 1);
148 _mm_storel_epi64(
149 reinterpret_cast<__m128i*>(dst + (j + 10) * ld_dst),
150 _mm256_castsi256_si128(y5));
151 *reinterpret_cast<int64_t*>(dst + (j + 11) * ld_dst) =
152 _mm256_extract_epi64(y5, 1);
153 _mm_storel_epi64(
154 reinterpret_cast<__m128i*>(dst + (j + 12) * ld_dst),
155 _mm256_castsi256_si128(y6));
156 *reinterpret_cast<int64_t*>(dst + (j + 13) * ld_dst) =
157 _mm256_extract_epi64(y6, 1);
158 _mm_storel_epi64(
159 reinterpret_cast<__m128i*>(dst + (j + 14) * ld_dst),
160 _mm256_castsi256_si128(y7));
161 *reinterpret_cast<int64_t*>(dst + (j + 15) * ld_dst) =
162 _mm256_extract_epi64(y7, 1);
163 *reinterpret_cast<int64_t*>(dst + (j + 16) * ld_dst) =
164 _mm256_extract_epi64(y0, 2);
165 *reinterpret_cast<int64_t*>(dst + (j + 17) * ld_dst) =
166 _mm256_extract_epi64(y0, 3);
167 *reinterpret_cast<int64_t*>(dst + (j + 18) * ld_dst) =
168 _mm256_extract_epi64(y1, 2);
169 *reinterpret_cast<int64_t*>(dst + (j + 19) * ld_dst) =
170 _mm256_extract_epi64(y1, 3);
171 *reinterpret_cast<int64_t*>(dst + (j + 20) * ld_dst) =
172 _mm256_extract_epi64(y2, 2);
173 *reinterpret_cast<int64_t*>(dst + (j + 21) * ld_dst) =
174 _mm256_extract_epi64(y2, 3);
175 *reinterpret_cast<int64_t*>(dst + (j + 22) * ld_dst) =
176 _mm256_extract_epi64(y3, 2);
177 *reinterpret_cast<int64_t*>(dst + (j + 23) * ld_dst) =
178 _mm256_extract_epi64(y3, 3);
179 *reinterpret_cast<int64_t*>(dst + (j + 24) * ld_dst) =
180 _mm256_extract_epi64(y4, 2);
181 *reinterpret_cast<int64_t*>(dst + (j + 25) * ld_dst) =
182 _mm256_extract_epi64(y4, 3);
183 *reinterpret_cast<int64_t*>(dst + (j + 26) * ld_dst) =
184 _mm256_extract_epi64(y5, 2);
185 *reinterpret_cast<int64_t*>(dst + (j + 27) * ld_dst) =
186 _mm256_extract_epi64(y5, 3);
187 *reinterpret_cast<int64_t*>(dst + (j + 28) * ld_dst) =
188 _mm256_extract_epi64(y6, 2);
189 *reinterpret_cast<int64_t*>(dst + (j + 29) * ld_dst) =
190 _mm256_extract_epi64(y6, 3);
191 *reinterpret_cast<int64_t*>(dst + (j + 30) * ld_dst) =
192 _mm256_extract_epi64(y7, 2);
193 *reinterpret_cast<int64_t*>(dst + (j + 31) * ld_dst) =
194 _mm256_extract_epi64(y7, 3);
195 }
196
197 // scalar loop for remainder
198 for (; j < N; ++j) {
199 for (int i = 0; i < M; ++i) {
200 dst[j * ld_dst + i] = src[j + i * ld_src];
201 }
202 }
203}
204
205void spmdmKernelAvx2(
206 int N,
207 const uint8_t* A_buffer,
208 const int32_t* colptr,
209 const int8_t* values,
210 const int16_t* rowidx,
211 int32_t* C_buffer) {
212 for (int j = 0; j < N; ++j) {
213 int k = colptr[j];
214 int k_end_aligned = colptr[j] + (colptr[j + 1] - colptr[j]) / 4 * 4;
215
216 for (; k < k_end_aligned; k += 4) {
217 __m256i w =
218 _mm256_set1_epi32(*(reinterpret_cast<const int32_t*>(&values[k])));
219 __m256i a[4];
220 a[0] = _mm256_load_si256(
221 reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 0] * 32]));
222 a[1] = _mm256_load_si256(
223 reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 1] * 32]));
224 a[2] = _mm256_load_si256(
225 reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 2] * 32]));
226 a[3] = _mm256_load_si256(
227 reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 3] * 32]));
228
229 __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]);
230 __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]);
231 __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]);
232 __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]);
233
234 a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo);
235 a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo);
236 a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi);
237 a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi);
238
239 __m256i ab[4];
240 ab[0] = _mm256_maddubs_epi16(a[0], w);
241 ab[1] = _mm256_maddubs_epi16(a[1], w);
242 ab[2] = _mm256_maddubs_epi16(a[2], w);
243 ab[3] = _mm256_maddubs_epi16(a[3], w);
244
245 __m256i one = _mm256_set1_epi16(1);
246 ab[0] = _mm256_madd_epi16(ab[0], one);
247 ab[1] = _mm256_madd_epi16(ab[1], one);
248 ab[2] = _mm256_madd_epi16(ab[2], one);
249 ab[3] = _mm256_madd_epi16(ab[3], one);
250
251 __m256i t[4];
252 t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20);
253 t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20);
254 t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31);
255 t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31);
256
257 _mm256_store_si256(
258 reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]),
259 _mm256_add_epi32(
260 _mm256_load_si256(
261 reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 0 * 8])),
262 t[0]));
263 _mm256_store_si256(
264 reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]),
265 _mm256_add_epi32(
266 _mm256_load_si256(
267 reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 1 * 8])),
268 t[1]));
269 _mm256_store_si256(
270 reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]),
271 _mm256_add_epi32(
272 _mm256_load_si256(
273 reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 2 * 8])),
274 t[2]));
275 _mm256_store_si256(
276 reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]),
277 _mm256_add_epi32(
278 _mm256_load_si256(
279 reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 3 * 8])),
280 t[3]));
281 }
282
283 int remainder = colptr[j + 1] - k;
284 if (remainder > 0) {
285 int32_t temp_w = 0;
286 for (int r = 0; r < remainder; ++r) {
287 (reinterpret_cast<int8_t*>(&temp_w))[r] = values[k + r];
288 }
289 __m256i w = _mm256_set1_epi32(temp_w);
290 __m256i a[4];
291 a[0] = _mm256_load_si256(
292 reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 0] * 32]));
293 a[1] = remainder > 1 ? _mm256_load_si256(reinterpret_cast<const __m256i*>(
294 &A_buffer[rowidx[k + 1] * 32]))
295 : _mm256_setzero_si256();
296 a[2] = remainder > 2 ? _mm256_load_si256(reinterpret_cast<const __m256i*>(
297 &A_buffer[rowidx[k + 2] * 32]))
298 : _mm256_setzero_si256();
299 a[3] = _mm256_setzero_si256();
300
301 __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]);
302 __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]);
303 __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]);
304 __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]);
305
306 a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo);
307 a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo);
308 a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi);
309 a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi);
310
311 __m256i ab[4];
312 ab[0] = _mm256_maddubs_epi16(a[0], w);
313 ab[1] = _mm256_maddubs_epi16(a[1], w);
314 ab[2] = _mm256_maddubs_epi16(a[2], w);
315 ab[3] = _mm256_maddubs_epi16(a[3], w);
316
317 __m256i one = _mm256_set1_epi16(1);
318 ab[0] = _mm256_madd_epi16(ab[0], one);
319 ab[1] = _mm256_madd_epi16(ab[1], one);
320 ab[2] = _mm256_madd_epi16(ab[2], one);
321 ab[3] = _mm256_madd_epi16(ab[3], one);
322
323 __m256i t[4];
324 t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20);
325 t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20);
326 t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31);
327 t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31);
328
329 _mm256_store_si256(
330 reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]),
331 _mm256_add_epi32(
332 _mm256_load_si256(
333 reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 0 * 8])),
334 t[0]));
335 _mm256_store_si256(
336 reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]),
337 _mm256_add_epi32(
338 _mm256_load_si256(
339 reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 1 * 8])),
340 t[1]));
341 _mm256_store_si256(
342 reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]),
343 _mm256_add_epi32(
344 _mm256_load_si256(
345 reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 2 * 8])),
346 t[2]));
347 _mm256_store_si256(
348 reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]),
349 _mm256_add_epi32(
350 _mm256_load_si256(
351 reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 3 * 8])),
352 t[3]));
353 }
354 } // for each column of B
355}
356
357} // namespace fbgemm
358