/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <cassert>
#include <cstdint>

#include "./MaskAvx2.h"

namespace fbgemm {

namespace internal {

#ifdef __AVX2__
// NOTE: Make sure every function defined in here has static linkage because
// this header file is included by UtilsAvx512.cc compiled with -mavx512f option

// 4 * 4 = 16 instructions
static inline void transpose_kernel_4x4_sse(
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  // load from src to registers
  // a : a0 a1 a2 a3
  // b : b0 b1 b2 b3
  // c : c0 c1 c2 c3
  // d : d0 d1 d2 d3
  __m128 a = _mm_loadu_ps(&src[0 * ld_src]);
  __m128 b = _mm_loadu_ps(&src[1 * ld_src]);
  __m128 c = _mm_loadu_ps(&src[2 * ld_src]);
  __m128 d = _mm_loadu_ps(&src[3 * ld_src]);

  // transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE
  // a : a0 b0 c0 d0
  // b : a1 b1 c1 d1
  // c : a2 b2 c2 d2
  // d : a3 b3 c3 d3
  _MM_TRANSPOSE4_PS(a, b, c, d);

  // store from registers to dst
  _mm_storeu_ps(&dst[0 * ld_dst], a);
  _mm_storeu_ps(&dst[1 * ld_dst], b);
  _mm_storeu_ps(&dst[2 * ld_dst], c);
  _mm_storeu_ps(&dst[3 * ld_dst], d);
}
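
// Illustrative usage (a sketch, not part of the library API): transposing one
// 4x4 tile of a row-major matrix. `rows`, `cols`, and the tile origin below
// are hypothetical names chosen only for this example.
//
//   // src is a rows x cols row-major float matrix, dst is cols x rows;
//   // copy the 4x4 tile whose top-left corner is (r, c) into its transposed
//   // position (c, r).
//   void transpose_4x4_tile(
//       const float* src, int64_t rows, int64_t cols, float* dst,
//       int64_t r, int64_t c) {
//     fbgemm::internal::transpose_kernel_4x4_sse(
//         src + r * cols + c, cols, dst + c * rows + r, rows);
//   }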

// kernel for transposing mxn where m, n <= 4
// M + (M + 1) / 2 * 2 + 2 * N instructions
template <unsigned M>
static void transpose_kernel_mxn_sse(
    unsigned N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  // clang-format off
  alignas(64) static const int masks[5][4] = {
    {  0,  0,  0,  0, },
    { -1,  0,  0,  0, },
    { -1, -1,  0,  0, },
    { -1, -1, -1,  0, },
    { -1, -1, -1, -1, },
  };
  // clang-format on

  // load from src to registers
  __m128i mask_v = _mm_load_si128(reinterpret_cast<const __m128i*>(masks[N]));
  __m128 input[4];
  unsigned i;
  for (i = 0; i < M; ++i) {
    input[i] = _mm_maskload_ps(&src[i * ld_src], mask_v);
  }
  for (; i < 4; ++i) {
    // Not really needed, but avoids an uninitialized variable warning.
    // Shouldn't be much overhead because xor can be executed in parallel with
    // other instructions.
    input[i] = _mm_setzero_ps();
  }

  __m128 temp[4];
  for (i = 0; i < (M + 1) / 2; ++i) {
    temp[2 * i] = _mm_unpacklo_ps(input[2 * i], input[2 * i + 1]);
    temp[2 * i + 1] = _mm_unpackhi_ps(input[2 * i], input[2 * i + 1]);
  }
  for (i = i * 2; i < 4; ++i) {
    temp[i] = _mm_setzero_ps();
  }

  mask_v = _mm_load_si128(reinterpret_cast<const __m128i*>(masks[M]));
  for (i = 0; i < N; ++i) {
    if (i % 2 == 0) {
      input[i] = _mm_movelh_ps(temp[i / 2], temp[2 + i / 2]);
    } else {
      input[i] = _mm_movehl_ps(temp[2 + i / 2], temp[i / 2]);
    }
    _mm_maskstore_ps(&dst[i * ld_dst], mask_v, input[i]);
  }
}
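
// Illustrative usage (a sketch, not part of the library API): handling a
// leftover tile whose dimensions are known only at run time. M must be a
// compile-time constant, so callers typically dispatch on the row remainder;
// the switch below is a hypothetical example of that pattern.
//
//   void transpose_remainder_tile_sse(
//       unsigned m, unsigned n, const float* src, int64_t ld_src, float* dst,
//       int64_t ld_dst) {
//     assert(m >= 1 && m <= 4 && n >= 1 && n <= 4);
//     switch (m) {
//       case 1:
//         fbgemm::internal::transpose_kernel_mxn_sse<1>(n, src, ld_src, dst, ld_dst);
//         break;
//       case 2:
//         fbgemm::internal::transpose_kernel_mxn_sse<2>(n, src, ld_src, dst, ld_dst);
//         break;
//       case 3:
//         fbgemm::internal::transpose_kernel_mxn_sse<3>(n, src, ld_src, dst, ld_dst);
//         break;
//       case 4:
//         fbgemm::internal::transpose_kernel_mxn_sse<4>(n, src, ld_src, dst, ld_dst);
//         break;
//     }
//   }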

// 8 * 5 = 40 instructions
static inline void transpose_kernel_8x8_avx2(
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  // load from src to registers
  // a : a0 a1 a2 a3 a4 a5 a6 a7
  // b : b0 b1 b2 b3 b4 b5 b6 b7
  // c : c0 c1 c2 c3 c4 c5 c6 c7
  // d : d0 d1 d2 d3 d4 d5 d6 d7
  // e : e0 e1 e2 e3 e4 e5 e6 e7
  // f : f0 f1 f2 f3 f4 f5 f6 f7
  // g : g0 g1 g2 g3 g4 g5 g6 g7
  // h : h0 h1 h2 h3 h4 h5 h6 h7
  __m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
  __m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
  __m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
  __m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
  __m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
  __m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
  __m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
  __m256 h = _mm256_loadu_ps(&src[7 * ld_src]);

  __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
  __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
  // unpacking and interleaving 32-bit elements
  // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5
  // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7
  // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5
  // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7
  // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5
  // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7
  // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5
  // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7
  ab0145 = _mm256_unpacklo_ps(a, b);
  ab2367 = _mm256_unpackhi_ps(a, b);
  cd0145 = _mm256_unpacklo_ps(c, d);
  cd2367 = _mm256_unpackhi_ps(c, d);
  ef0145 = _mm256_unpacklo_ps(e, f);
  ef2367 = _mm256_unpackhi_ps(e, f);
  gh0145 = _mm256_unpacklo_ps(g, h);
  gh2367 = _mm256_unpackhi_ps(g, h);

  // shuffling the 32-bit elements
  // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4
  // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5
  // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4
  // efgh15 : e1 f1 g1 h1 e5 f5 g5 h5
  // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6
  // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7
  // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6
  // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7
  abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44);
  abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee);
  efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44);
  efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee);
  abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44);
  abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee);
  efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44);
  efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee);

  // shuffling 128-bit elements
  // a : a0 b0 c0 d0 e0 f0 g0 h0
  // b : a1 b1 c1 d1 e1 f1 g1 h1
  // c : a2 b2 c2 d2 e2 f2 g2 h2
  // d : a3 b3 c3 d3 e3 f3 g3 h3
  // e : a4 b4 c4 d4 e4 f4 g4 h4
  // f : a5 b5 c5 d5 e5 f5 g5 h5
  // g : a6 b6 c6 d6 e6 f6 g6 h6
  // h : a7 b7 c7 d7 e7 f7 g7 h7
  a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02);
  b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02);
  c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02);
  d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02);
  e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13);
  f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13);
  g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13);
  h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13);

  // store from registers to dst
  _mm256_storeu_ps(&dst[0 * ld_dst], a);
  _mm256_storeu_ps(&dst[1 * ld_dst], b);
  _mm256_storeu_ps(&dst[2 * ld_dst], c);
  _mm256_storeu_ps(&dst[3 * ld_dst], d);
  _mm256_storeu_ps(&dst[4 * ld_dst], e);
  _mm256_storeu_ps(&dst[5 * ld_dst], f);
  _mm256_storeu_ps(&dst[6 * ld_dst], g);
  _mm256_storeu_ps(&dst[7 * ld_dst], h);
}
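
// Illustrative usage (a sketch, not part of the library API): transposing a
// row-major float matrix whose dimensions are multiples of 8 by walking it in
// 8x8 tiles. `transpose_multiple_of_8` is a hypothetical wrapper written only
// for this example.
//
//   void transpose_multiple_of_8(
//       int64_t rows, int64_t cols, const float* src, float* dst) {
//     for (int64_t r = 0; r < rows; r += 8) {
//       for (int64_t c = 0; c < cols; c += 8) {
//         fbgemm::internal::transpose_kernel_8x8_avx2(
//             src + r * cols + c, cols, dst + c * rows + r, rows);
//       }
//     }
//   }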

// kernel for transposing mxn where m, n <= 8
// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + 2 * N instructions
template <unsigned M>
static void transpose_kernel_mxn_avx2(
    unsigned N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  // load from src to registers
  __m256i mask_v = _mm256_load_si256(
      reinterpret_cast<const __m256i*>(internal::avx2_ps_or_epi32_masks[N]));
  __m256 input[8];
  unsigned i;
  for (i = 0; i < M; ++i) {
    input[i] = _mm256_maskload_ps(&src[i * ld_src], mask_v);
  }
  for (; i < 8; ++i) {
    // Not really needed, but avoids an uninitialized variable warning.
    // Shouldn't be much overhead because xor can be executed in parallel with
    // other instructions.
    input[i] = _mm256_setzero_ps();
  }

  // unpacking and interleaving 32-bit elements
  __m256 temp[8];
  for (i = 0; i < (M + 1) / 2; ++i) {
    temp[2 * i] = _mm256_unpacklo_ps(input[2 * i], input[2 * i + 1]);
    temp[2 * i + 1] = _mm256_unpackhi_ps(input[2 * i], input[2 * i + 1]);
  }
  for (i = i * 2; i < 8; ++i) {
    temp[i] = _mm256_setzero_ps();
  }

  // shuffling the 32-bit elements
  for (i = 0; i < (M + 3) / 4; ++i) {
    input[4 * i] = _mm256_shuffle_ps(temp[4 * i], temp[4 * i + 2], 0x44);
    input[4 * i + 1] = _mm256_shuffle_ps(temp[4 * i], temp[4 * i + 2], 0xee);
    input[4 * i + 2] =
        _mm256_shuffle_ps(temp[4 * i + 1], temp[4 * i + 3], 0x44);
    input[4 * i + 3] =
        _mm256_shuffle_ps(temp[4 * i + 1], temp[4 * i + 3], 0xee);
  }

  // shuffling 128-bit elements
  // store from registers to dst
  mask_v = _mm256_load_si256(
      reinterpret_cast<const __m256i*>(internal::avx2_ps_or_epi32_masks[M]));
  for (i = 0; i < N; ++i) {
    if (i < 4) {
      temp[i] = _mm256_permute2f128_ps(input[4 + i], input[i], 0x02);
    } else {
      temp[i] = _mm256_permute2f128_ps(input[i], input[i - 4], 0x13);
    }
    _mm256_maskstore_ps(&dst[i * ld_dst], mask_v, temp[i]);
  }
}
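
// Illustrative usage (a sketch, not part of the library API): M is a template
// parameter, so a run-time row remainder is usually handled by dispatching to
// the matching instantiation. `dispatch_mxn_avx2` below is a hypothetical
// helper written only for this example.
//
//   void dispatch_mxn_avx2(
//       unsigned m, unsigned n, const float* src, int64_t ld_src, float* dst,
//       int64_t ld_dst) {
//     assert(m >= 1 && m <= 8 && n >= 1 && n <= 8);
//     switch (m) {
//       case 1: fbgemm::internal::transpose_kernel_mxn_avx2<1>(n, src, ld_src, dst, ld_dst); break;
//       case 2: fbgemm::internal::transpose_kernel_mxn_avx2<2>(n, src, ld_src, dst, ld_dst); break;
//       // ... cases 3 through 7 follow the same pattern ...
//       case 8: fbgemm::internal::transpose_kernel_mxn_avx2<8>(n, src, ld_src, dst, ld_dst); break;
//     }
//   }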

inline __m256i permute_row(__m256i row) {
  // clang-format off
  row = _mm256_shuffle_epi8(
      row,
      _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
                      15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0));
  // clang-format on
  return row;
}
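
// permute_row regroups the bytes within each 128-bit lane: the i-th byte of
// each 4-byte group is gathered, so bytes (0, 4, 8, 12) land in positions
// (0, 1, 2, 3), bytes (1, 5, 9, 13) in (4, 5, 6, 7), and so on. For a lane
// holding a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3, the result is
// a0 b0 c0 d0 a1 b1 c1 d1 a2 b2 c2 d2 a3 b3 c3 d3, which is the regrouping
// the 8x32 kernel below relies on after its epi32/epi64 unpacks.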

// template <>
inline static void transpose_kernel_8x32_avx2(
    const uint8_t* src,
    int64_t ld_src,
    uint8_t* dst,
    int64_t ld_dst) {
  // load from src to registers
  // a : a0 a1 a2 a3 a4 a5 a6 a7 ... a31
  // b : b0 b1 b2 b3 b4 b5 b6 b7 ... b31
  // c : c0 c1 c2 c3 c4 c5 c6 c7 ... c31
  // d : d0 d1 d2 d3 d4 d5 d6 d7 ... d31
  // e : e0 e1 e2 e3 e4 e5 e6 e7 ... e31
  // f : f0 f1 f2 f3 f4 f5 f6 f7 ... f31
  // g : g0 g1 g2 g3 g4 g5 g6 g7 ... g31
  // h : h0 h1 h2 h3 h4 h5 h6 h7 ... h31

  // load from src
  __m256i a = _mm256_loadu_si256(
      reinterpret_cast<const __m256i*>((src) + (0 * ld_src)));
  __m256i b = _mm256_loadu_si256(
      reinterpret_cast<const __m256i*>((src) + (1 * ld_src)));
  __m256i c = _mm256_loadu_si256(
      reinterpret_cast<const __m256i*>((src) + (2 * ld_src)));
  __m256i d = _mm256_loadu_si256(
      reinterpret_cast<const __m256i*>((src) + (3 * ld_src)));
  __m256i e = _mm256_loadu_si256(
      reinterpret_cast<const __m256i*>((src) + (4 * ld_src)));
  __m256i f = _mm256_loadu_si256(
      reinterpret_cast<const __m256i*>((src) + (5 * ld_src)));
  __m256i g = _mm256_loadu_si256(
      reinterpret_cast<const __m256i*>((src) + (6 * ld_src)));
  __m256i h = _mm256_loadu_si256(
      reinterpret_cast<const __m256i*>((src) + (7 * ld_src)));

  // shuffle in stride of one:
  // t0 : a0 -- a3, b0 -- b3, a4 -- a7, b4 -- b7,
  //      a16 -- a19, b16 -- b19, a20 -- a23, b20 -- b23

  // t1 : a8 -- a11, b8 -- b11, a12 -- a15, b12 -- b15,
  //      a24 -- a27, b24 -- b27, a28 -- a31, b28 -- b31

  // t2 : c0 -- c3, d0 -- d3, c4 -- c7, d4 -- d7,
  //      c16 -- c19, d16 -- d19, c20 -- c23, d20 -- d23

  __m256i __t0 = _mm256_unpacklo_epi32(a, b);
  __m256i __t1 = _mm256_unpackhi_epi32(a, b);
  __m256i __t2 = _mm256_unpacklo_epi32(c, d);
  __m256i __t3 = _mm256_unpackhi_epi32(c, d);
  __m256i __t4 = _mm256_unpacklo_epi32(e, f);
  __m256i __t5 = _mm256_unpackhi_epi32(e, f);
  __m256i __t6 = _mm256_unpacklo_epi32(g, h);
  __m256i __t7 = _mm256_unpackhi_epi32(g, h);

  // shuffle in stride of two:
  // tt0: a0--a3, b0--b3, c0--c3, d0--d3,
  //      a16--a19, b16--b19, c16--c19, d16--d19

  // tt1: a4--a7, b4--b7, c4--c7, d4--d7,
  //      a20--a23, b20--b23, c20--c23, d20--d23

  // tt2: a8--a11, b8--b11, c8--c11, d8--d11,
  //      a24--a27, b24--b27, c24--c27, d24--d27

  // tt3: a12--a15, b12--b15, c12--c15, d12--d15,
  //      a28--a31, b28--b31, c28--c31, d28--d31

  // tt4: e0--e3, f0--f3, g0--g3, h0--h3,
  //      e16--e19, f16--f19, g16--g19, h16--h19
  __m256i __tt0 = _mm256_unpacklo_epi64(__t0, __t2);
  __m256i __tt1 = _mm256_unpackhi_epi64(__t0, __t2);
  __m256i __tt2 = _mm256_unpacklo_epi64(__t1, __t3);
  __m256i __tt3 = _mm256_unpackhi_epi64(__t1, __t3);
  __m256i __tt4 = _mm256_unpacklo_epi64(__t4, __t6);
  __m256i __tt5 = _mm256_unpackhi_epi64(__t4, __t6);
  __m256i __tt6 = _mm256_unpacklo_epi64(__t5, __t7);
  __m256i __tt7 = _mm256_unpackhi_epi64(__t5, __t7);

  // permute: pack consecutive elements (0-3) together
  // ttt0: a0--d0 a1--d1 a2--d2 a3--d3 a16-d16 a17-d17 a18-d18 a19-d19

  // ttt1: a4--d4 a5--d5 a6--d6 a7--d7 a20-d20 a21-d21 a22-d22 a23-d23

  // ttt2: a8--d8 a9--d9 a10--d10 a11--d11 a24-d24 a25-d25 a26-d26 a27-d27
  __m256i __ttt0 = permute_row(__tt0);
  __m256i __ttt1 = permute_row(__tt1);
  __m256i __ttt2 = permute_row(__tt2);
  __m256i __ttt3 = permute_row(__tt3);
  __m256i __ttt4 = permute_row(__tt4);
  __m256i __ttt5 = permute_row(__tt5);
  __m256i __ttt6 = permute_row(__tt6);
  __m256i __ttt7 = permute_row(__tt7);

  // unpacking and interleaving 32-bit elements again to pair up all rows:
  // a: a0-h0   a1-h1   a16-h16 a17-h17
  // b: a2-h2   a3-h3   a18-h18 a19-h19

  // c: a4-h4   a5-h5   a20-h20 a21-h21
  // d: a6-h6   a7-h7   a22-h22 a23-h23

  // e: a8-h8   a9-h9   a24-h24 a25-h25
  // f: a10-h10 a11-h11 a26-h26 a27-h27

  // g: a12-h12 a13-h13 a28-h28 a29-h29
  // h: a14-h14 a15-h15 a30-h30 a31-h31
  a = _mm256_unpacklo_epi32(__ttt0, __ttt4);
  b = _mm256_unpackhi_epi32(__ttt0, __ttt4);
  c = _mm256_unpacklo_epi32(__ttt1, __ttt5);
  d = _mm256_unpackhi_epi32(__ttt1, __ttt5);
  e = _mm256_unpacklo_epi32(__ttt2, __ttt6);
  f = _mm256_unpackhi_epi32(__ttt2, __ttt6);
  g = _mm256_unpacklo_epi32(__ttt3, __ttt7);
  h = _mm256_unpackhi_epi32(__ttt3, __ttt7);

  // stores back 32 rows:

  reinterpret_cast<uint64_t*>(dst)[0] = _mm256_extract_epi64(a, 0);
  reinterpret_cast<uint64_t*>((dst) + ld_dst)[0] = _mm256_extract_epi64(a, 1);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 2)[0] =
      _mm256_extract_epi64(b, 0);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 3)[0] =
      _mm256_extract_epi64(b, 1);

  reinterpret_cast<uint64_t*>((dst) + ld_dst * 4)[0] =
      _mm256_extract_epi64(c, 0);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 5)[0] =
      _mm256_extract_epi64(c, 1);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 6)[0] =
      _mm256_extract_epi64(d, 0);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 7)[0] =
      _mm256_extract_epi64(d, 1);

  reinterpret_cast<uint64_t*>((dst) + ld_dst * 8)[0] =
      _mm256_extract_epi64(e, 0);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 9)[0] =
      _mm256_extract_epi64(e, 1);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 10)[0] =
      _mm256_extract_epi64(f, 0);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 11)[0] =
      _mm256_extract_epi64(f, 1);

  reinterpret_cast<uint64_t*>((dst) + ld_dst * 12)[0] =
      _mm256_extract_epi64(g, 0);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 13)[0] =
      _mm256_extract_epi64(g, 1);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 14)[0] =
      _mm256_extract_epi64(h, 0);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 15)[0] =
      _mm256_extract_epi64(h, 1);

  reinterpret_cast<uint64_t*>((dst) + ld_dst * 16)[0] =
      _mm256_extract_epi64(a, 2);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 17)[0] =
      _mm256_extract_epi64(a, 3);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 18)[0] =
      _mm256_extract_epi64(b, 2);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 19)[0] =
      _mm256_extract_epi64(b, 3);

  reinterpret_cast<uint64_t*>((dst) + ld_dst * 20)[0] =
      _mm256_extract_epi64(c, 2);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 21)[0] =
      _mm256_extract_epi64(c, 3);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 22)[0] =
      _mm256_extract_epi64(d, 2);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 23)[0] =
      _mm256_extract_epi64(d, 3);

  reinterpret_cast<uint64_t*>((dst) + ld_dst * 24)[0] =
      _mm256_extract_epi64(e, 2);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 25)[0] =
      _mm256_extract_epi64(e, 3);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 26)[0] =
      _mm256_extract_epi64(f, 2);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 27)[0] =
      _mm256_extract_epi64(f, 3);

  reinterpret_cast<uint64_t*>((dst) + ld_dst * 28)[0] =
      _mm256_extract_epi64(g, 2);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 29)[0] =
      _mm256_extract_epi64(g, 3);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 30)[0] =
      _mm256_extract_epi64(h, 2);
  reinterpret_cast<uint64_t*>((dst) + ld_dst * 31)[0] =
      _mm256_extract_epi64(h, 3);
}
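
// Illustrative usage (a sketch, not part of the library API): transposing one
// 8x32 uint8_t block, e.g. a packed quantized tile. The buffers and strides
// below are hypothetical and exist only for this example.
//
//   // src holds 8 rows of 32 bytes (ld_src >= 32); dst receives 32 rows of
//   // 8 bytes each, written at stride ld_dst (ld_dst >= 8).
//   alignas(32) uint8_t src[8 * 32];
//   alignas(32) uint8_t dst[32 * 8];
//   // ... fill src ...
//   fbgemm::internal::transpose_kernel_8x32_avx2(src, 32, dst, 8);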

static inline void load_with_remainders_i16(
    const uint16_t* src,
    int64_t ld_src,
    __m256i r[],
    unsigned mrem,
    unsigned nrem) {
  if (nrem < 16) {
    uint16_t local_buffer[16] = {0};
    __m256i mask_nrem_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
        internal::avx2_ps_or_epi32_masks[nrem / 2]));
    unsigned half = nrem % 2;
    for (unsigned i = 0; i < mrem; ++i) {
      // mask load
      r[i] = _mm256_maskload_epi32(
          reinterpret_cast<const int*>(&src[i * ld_src]), mask_nrem_v);
      if (half == 1) {
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(&local_buffer[0]), r[i]);
        local_buffer[nrem - 1] = src[i * ld_src + nrem - 1];
        r[i] = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(&local_buffer[0]));
      }
    }
  } else {
    for (unsigned i = 0; i < mrem; ++i) {
      // normal load
      r[i] = _mm256_loadu_si256(
          reinterpret_cast<const __m256i*>(src + i * ld_src));
    }
  }
}
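
// Note on the helper above: _mm256_maskload_epi32 works at 32-bit granularity,
// so an odd uint16_t remainder cannot be expressed in the mask alone. For
// example, with nrem = 7 the mask covers only 7 / 2 = 3 dwords (6 uint16
// elements); the row is then spilled to local_buffer, the 7th element is
// copied in scalar code, and the register is reloaded from the buffer.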

static inline void store_with_remainders_i16(
    uint16_t* dst,
    int64_t ld_dst,
    __m256i u[],
    unsigned mrem,
    unsigned nrem) {
  if (mrem < 8) {
    uint16_t local_buffer[8] = {0};
    __m256i mask_mrem_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
        internal::avx2_ps_or_epi32_masks[mrem / 2]));
    unsigned half = mrem % 2;
    unsigned i = 0;
    for (; i < nrem; i += 1) {
      // mask store
      int reg_idx = i % 8;
      __m128i d;
      if (i >= 8) {
        d = _mm256_extractf128_si256(u[reg_idx], 1);
      } else {
        d = _mm256_extractf128_si256(u[reg_idx], 0);
      }
      _mm256_maskstore_epi32(
          reinterpret_cast<int*>(dst + i * ld_dst),
          mask_mrem_v,
          _mm256_castsi128_si256(d));
      if (half == 1) {
        _mm_storeu_si128(reinterpret_cast<__m128i*>(local_buffer), d);
        (dst + i * ld_dst)[mrem - 1] = local_buffer[mrem - 1];
      }
    }

  } else {
    unsigned i = 0;
    for (; i < nrem; i += 1) {
      // normal store
      unsigned reg_idx = i % 8;
      if (i >= 8) {
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(dst + i * ld_dst),
            _mm256_extractf128_si256(u[reg_idx], 1));
      } else {
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(dst + i * ld_dst),
            _mm256_extractf128_si256(u[reg_idx], 0));
      }
    }
  }
}
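
// The store helper mirrors the load helper: each transposed output row holds
// at most 8 uint16_t (one 128-bit half of a register), the mask covers the
// even part (mrem / 2 dwords), and an odd mrem is finished with one scalar
// store routed through local_buffer.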

template <bool MREM = false, bool NREM = false>
inline static void transpose_kernel_8x16_avx2(
    const uint16_t* src,
    int64_t ld_src,
    uint16_t* dst,
    int64_t ld_dst,
    unsigned mrem = 8,
    unsigned nrem = 16) {
  __m256i r[8];
  // load from src to registers
  // a : a0 a1 a2 a3 a4 a5 a6 a7 ... a15
  // b : b0 b1 b2 b3 b4 b5 b6 b7 ... b15
  // c : c0 c1 c2 c3 c4 c5 c6 c7 ... c15
  // d : d0 d1 d2 d3 d4 d5 d6 d7 ... d15
  // e : e0 e1 e2 e3 e4 e5 e6 e7 ... e15
  // f : f0 f1 f2 f3 f4 f5 f6 f7 ... f15
  // g : g0 g1 g2 g3 g4 g5 g6 g7 ... g15
  // h : h0 h1 h2 h3 h4 h5 h6 h7 ... h15
  if (MREM || NREM) {
    load_with_remainders_i16(src, ld_src, r, mrem, nrem);
  } else {
    r[0] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>((src) + (0 * ld_src)));
    r[1] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>((src) + (1 * ld_src)));
    r[2] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>((src) + (2 * ld_src)));
    r[3] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>((src) + (3 * ld_src)));
    r[4] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>((src) + (4 * ld_src)));
    r[5] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>((src) + (5 * ld_src)));
    r[6] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>((src) + (6 * ld_src)));
    r[7] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>((src) + (7 * ld_src)));
  }
  // t0 : a0a1, b0b1, a2a3, b2b3,
  //      a8a9, b8b9, a10a11, b10b11

  // t1 : a4a5, b4b5, a6a7, b6b7,
  //      a12a13, b12b13, a14a15, b14b15

  // t2 : c0c1, d0d1, c2c3, d2d3,
  //      c8c9, d8d9, c10c11, d10d11

  __m256i __t0 = _mm256_unpacklo_epi32(r[0], r[1]);
  __m256i __t1 = _mm256_unpackhi_epi32(r[0], r[1]);
  __m256i __t2 = _mm256_unpacklo_epi32(r[2], r[3]);
  __m256i __t3 = _mm256_unpackhi_epi32(r[2], r[3]);
  __m256i __t4 = _mm256_unpacklo_epi32(r[4], r[5]);
  __m256i __t5 = _mm256_unpackhi_epi32(r[4], r[5]);
  __m256i __t6 = _mm256_unpacklo_epi32(r[6], r[7]);
  __m256i __t7 = _mm256_unpackhi_epi32(r[6], r[7]);

  // tt0: a0a1, b0b1, c0c1, d0d1,
  //      a8a9, b8b9, c8c9, d8d9

  // tt1: a2a3, b2b3, c2c3, d2d3,
  //      a10a11, b10b11, c10c11, d10d11

  // tt2: a4a5, b4b5, c4c5, d4d5,
  //      a12a13, b12b13, c12c13, d12d13

  // tt3: a6a7, b6b7, c6c7, d6d7,
  //      a14a15, b14b15, c14c15, d14d15

  // tt4: e0e1, f0f1, g0g1, h0h1,
  //      e8e9, f8f9, g8g9, h8h9
  __m256i __tt0 = _mm256_unpacklo_epi64(__t0, __t2);
  __m256i __tt1 = _mm256_unpackhi_epi64(__t0, __t2);
  __m256i __tt2 = _mm256_unpacklo_epi64(__t1, __t3);
  __m256i __tt3 = _mm256_unpackhi_epi64(__t1, __t3);
  __m256i __tt4 = _mm256_unpacklo_epi64(__t4, __t6);
  __m256i __tt5 = _mm256_unpackhi_epi64(__t4, __t6);
  __m256i __tt6 = _mm256_unpacklo_epi64(__t5, __t7);
  __m256i __tt7 = _mm256_unpackhi_epi64(__t5, __t7);

  // t0: a0b0, a1b1, c0c1, d0d1,
  //     a8b8, a9b9, c8c9, d8d9
  __t0 = _mm256_shufflelo_epi16(__tt0, 0xD8);
  __t1 = _mm256_shufflelo_epi16(__tt1, 0xD8);
  __t2 = _mm256_shufflelo_epi16(__tt2, 0xD8);
  __t3 = _mm256_shufflelo_epi16(__tt3, 0xD8);
  __t4 = _mm256_shufflelo_epi16(__tt4, 0xD8);
  __t5 = _mm256_shufflelo_epi16(__tt5, 0xD8);
  __t6 = _mm256_shufflelo_epi16(__tt6, 0xD8);
  __t7 = _mm256_shufflelo_epi16(__tt7, 0xD8);

  // tt0: a0b0, a1b1, c0d0, c1d1,
  //      a8b8, a9b9, c8d8, c9d9
  __tt0 = _mm256_shufflehi_epi16(__t0, 0xD8);
  __tt1 = _mm256_shufflehi_epi16(__t1, 0xD8);
  __tt2 = _mm256_shufflehi_epi16(__t2, 0xD8);
  __tt3 = _mm256_shufflehi_epi16(__t3, 0xD8);
  __tt4 = _mm256_shufflehi_epi16(__t4, 0xD8);
  __tt5 = _mm256_shufflehi_epi16(__t5, 0xD8);
  __tt6 = _mm256_shufflehi_epi16(__t6, 0xD8);
  __tt7 = _mm256_shufflehi_epi16(__t7, 0xD8);

  // t0: a0b0, c0d0, a1b1, c1d1,
  //     a8b8, c8d8, a9b9, c9d9
  __t0 = _mm256_shuffle_epi32(__tt0, 0xD8);
  __t1 = _mm256_shuffle_epi32(__tt1, 0xD8);
  __t2 = _mm256_shuffle_epi32(__tt2, 0xD8);
  __t3 = _mm256_shuffle_epi32(__tt3, 0xD8);
  // t4: e0f0, g0h0, e1f1, g1h1,
  //     e8f8, g8h8, e9f9, g9h9
  __t4 = _mm256_shuffle_epi32(__tt4, 0xD8);
  __t5 = _mm256_shuffle_epi32(__tt5, 0xD8);
  __t6 = _mm256_shuffle_epi32(__tt6, 0xD8);
  __t7 = _mm256_shuffle_epi32(__tt7, 0xD8);

  // r0: a0b0, c0d0, e0f0, g0h0,
  //     a8b8, c8d8, e8f8, g8h8
  r[0] = _mm256_unpacklo_epi64(__t0, __t4); // 0, 8
  // r1: a1b1, c1d1, e1f1, g1h1,
  //     a9b9, c9d9, e9f9, g9h9
  r[1] = _mm256_unpackhi_epi64(__t0, __t4); // 1, 9
  r[2] = _mm256_unpacklo_epi64(__t1, __t5); // 2, 10
  r[3] = _mm256_unpackhi_epi64(__t1, __t5); // 3, 11
  r[4] = _mm256_unpacklo_epi64(__t2, __t6); // 4, 12
  r[5] = _mm256_unpackhi_epi64(__t2, __t6); // 5, 13
  r[6] = _mm256_unpacklo_epi64(__t3, __t7); // 6, 14
  r[7] = _mm256_unpackhi_epi64(__t3, __t7); // 7, 15

  // stores back 16 rows:
  if (MREM || NREM) {
    store_with_remainders_i16(dst, ld_dst, r, mrem, nrem);
  } else {
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst), _mm256_extractf128_si256(r[0], 0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst),
        _mm256_extractf128_si256(r[1], 0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 2),
        _mm256_extractf128_si256(r[2], 0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 3),
        _mm256_extractf128_si256(r[3], 0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 4),
        _mm256_extractf128_si256(r[4], 0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 5),
        _mm256_extractf128_si256(r[5], 0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 6),
        _mm256_extractf128_si256(r[6], 0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 7),
        _mm256_extractf128_si256(r[7], 0));

    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 8),
        _mm256_extractf128_si256(r[0], 1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 9),
        _mm256_extractf128_si256(r[1], 1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 10),
        _mm256_extractf128_si256(r[2], 1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 11),
        _mm256_extractf128_si256(r[3], 1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 12),
        _mm256_extractf128_si256(r[4], 1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 13),
        _mm256_extractf128_si256(r[5], 1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 14),
        _mm256_extractf128_si256(r[6], 1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>((dst) + ld_dst * 15),
        _mm256_extractf128_si256(r[7], 1));
  }
}
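
// Illustrative usage (a sketch, not part of the library API): one full tile
// versus a partial tile. The MREM/NREM template flags select the remainder
// paths; mrem/nrem give the valid row and column counts of the partial tile.
//
//   // One full 8x16 uint16_t tile:
//   fbgemm::internal::transpose_kernel_8x16_avx2(src, ld_src, dst, ld_dst);
//
//   // A partial tile with 5 valid rows and 11 valid columns:
//   fbgemm::internal::transpose_kernel_8x16_avx2</*MREM=*/true, /*NREM=*/true>(
//       src, ld_src, dst, ld_dst, /*mrem=*/5, /*nrem=*/11);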

// kernel for transposing mxn where m <= 8 and n <= 32
template <unsigned M>
static void transpose_kernel_mxn_avx2_uint8(
    unsigned N,
    const uint8_t* src,
    int64_t ld_src,
    uint8_t* dst,
    int64_t ld_dst) {
  // load from src to registers
  // first load masks
  __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
      internal::avx2_ps_or_epi32_masks[N / 4]));

  __m256i input[8];
  unsigned i, j;
  for (i = 0; i < M; ++i) {
    uint8_t local_buffer[32] = {0};

    // first load into local buffer with mask
    input[i] = _mm256_maskload_epi32(
        reinterpret_cast<const int*>(src + i * ld_src), mask_v);

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&local_buffer[0]), input[i]);

    // fill in the local buffer with the remainder elements
    for (j = N / 4 * 4; j < N; j++)
      local_buffer[j] = src[i * ld_src + j];

    // from local buffer to input registers
    input[i] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(&local_buffer[0]));
  }

  // for (; i < 8; ++i) {
  //   input[i] = _mm256_setzero_si256();
  // }

  // interleaving 8-bit elements
  // e.g., temp[0] now becomes: a0 b0 a1 b1 a2 b2 ...
  __m256i temp[8];
  for (i = 0; i < (M + 1) / 2; ++i) {
    temp[2 * i] = _mm256_unpacklo_epi8(input[2 * i], input[2 * i + 1]);
    temp[2 * i + 1] = _mm256_unpackhi_epi8(input[2 * i], input[2 * i + 1]);
  }
  for (i = i * 2; i < 8; ++i) {
    temp[i] = _mm256_setzero_si256();
  }

  // interleaving 16-bit elements
  // e.g., input[0] now becomes: a0 b0 c0 d0 a1 b1 c1 d1 ...
  for (i = 0; i < (M + 3) / 4; ++i) {
    input[4 * i] = _mm256_unpacklo_epi16(temp[i * 4], temp[i * 4 + 2]);
    input[4 * i + 1] = _mm256_unpackhi_epi16(temp[i * 4], temp[i * 4 + 2]);
    input[4 * i + 2] = _mm256_unpacklo_epi16(temp[i * 4 + 1], temp[i * 4 + 3]);
    input[4 * i + 3] = _mm256_unpackhi_epi16(temp[i * 4 + 1], temp[i * 4 + 3]);
  }

  // interleaving 32-bit elements
  // e.g., temp[0] now becomes a0 b0 c0 d0 e0 f0 g0 h0 ...
  for (i = 0; i < 4 /*(M + 1) / 2*/; ++i) {
    temp[2 * i] = _mm256_unpacklo_epi32(input[i], input[(i + 4)]);
    temp[2 * i + 1] = _mm256_unpackhi_epi32(input[i], input[(i + 4)]);
  }

  // retrieve the final result by extracting one 64-bit element at a time;
  // taking the 256-bit temp[0] as an example, it holds
  //   bits 0-63   : a0 -- h0,
  //   bits 64-127 : a1 -- h1,
  //   bits 128-191: a16 -- h16,
  //   bits 192-255: a17 -- h17
  uint64_t t;
  mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
      internal::avx2_ps_or_epi32_masks[M / 4]));
  for (i = 0; i < N; ++i) {
    if (i < 16) {
      if (i % 2 == 0)
        t = _mm256_extract_epi64(temp[i / 2], 0);
      else
        t = _mm256_extract_epi64(temp[i / 2], 1);

    } else {
      if (i % 2 == 0)
        t = _mm256_extract_epi64(temp[(i - 16) / 2], 2);
      else
        t = _mm256_extract_epi64(temp[(i - 16) / 2], 3);
    }
    __m256i t_vec = _mm256_set_epi64x(0, 0, 0, t);
    _mm256_maskstore_epi32(
        reinterpret_cast<int*>(dst + i * ld_dst), mask_v, t_vec);
    for (j = M / 4 * 4; j < M; j++) {
      dst[ld_dst * i + j] = ((t >> (8 * j)) & 255);
    }
  }
}
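
// Illustrative usage (a sketch, not part of the library API): a 3x10 uint8_t
// remainder block. M = 3 is the compile-time row count and N = 10 is passed
// at run time; larger matrices would dispatch on the row remainder as in the
// float kernels above.
//
//   fbgemm::internal::transpose_kernel_mxn_avx2_uint8<3>(
//       10, src, ld_src, dst, ld_dst);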

#endif // __AVX2__

} // namespace internal

} // namespace fbgemm