1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8#if defined(__x86_64__) || defined(__i386__) || \
9 (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
10#include <immintrin.h>
11#endif
12#include "./TransposeUtils.h"
13#include "./TransposeUtilsAvx2.h"
14namespace fbgemm {
15
16namespace {
17
18// 16 * 6 = 96 instructions
19inline void transpose_kernel_16x16_avx512(
20 const float* src,
21 int64_t ld_src,
22 float* dst,
23 int64_t ld_dst) {
24 // load from src to registers
25 // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
26 // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
27 // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
28 // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
29 // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
30 // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
31 // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
32 // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
33 // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
34 // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
35 // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
36 // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
37 // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
38 // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
39 // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
40 // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
41 __m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
42 __m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
43 __m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
44 __m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
45 __m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
46 __m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
47 __m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
48 __m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
49 __m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
50 __m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
51 __m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
52 __m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
53 __m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
54 __m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
55 __m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
56 __m512 p = _mm512_loadu_ps(&src[15 * ld_src]);
57
58 __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;
59 // unpacking and interleaving 32-bit elements
60 // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13
61 // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15
62 // c0 d0 c1 d1 ...
63 // c2 d2 c3 d3 ...
64 // e0 f0 e1 f1 ...
65 // e2 f2 e3 f3 ...
66 // g0 h0 g1 h1 ...
67 // g2 h2 g3 h3 ...
68 // i0 ...
69 // i2 ...
70 // k0 ...
71 // k2 ...
72 // m0 ...
73 // m2 ...
74 // o0 ...
  // o2 ...
76 ta = _mm512_unpacklo_ps(a, b);
77 tb = _mm512_unpackhi_ps(a, b);
78 tc = _mm512_unpacklo_ps(c, d);
79 td = _mm512_unpackhi_ps(c, d);
80 te = _mm512_unpacklo_ps(e, f);
81 tf = _mm512_unpackhi_ps(e, f);
82 tg = _mm512_unpacklo_ps(g, h);
83 th = _mm512_unpackhi_ps(g, h);
84 ti = _mm512_unpacklo_ps(i, j);
85 tj = _mm512_unpackhi_ps(i, j);
86 tk = _mm512_unpacklo_ps(k, l);
87 tl = _mm512_unpackhi_ps(k, l);
88 tm = _mm512_unpacklo_ps(m, n);
89 tn = _mm512_unpackhi_ps(m, n);
90 to = _mm512_unpacklo_ps(o, p);
91 tq = _mm512_unpackhi_ps(o, p);
92
93 // unpacking and interleaving 64-bit elements
94 // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12
95 // a1 b1 c1 d1 ...
96 // a2 b2 c2 d2 ...
97 // a3 b3 c3 d3 ...
98 // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12
99 // e1 f1 g1 h1 ...
100 // e2 f2 g2 h2 ...
101 // e3 f3 g3 h3 ...
102 // i0 j0 k0 l0 ...
103 // i1 j1 k1 l1 ...
104 // i2 j2 k2 l2 ...
105 // i3 j3 k3 l3 ...
106 // m0 n0 o0 p0 ...
107 // m1 n1 o1 p1 ...
108 // m2 n2 o2 p2 ...
109 // m3 n3 o3 p3 ...
110 a = _mm512_castpd_ps(
111 _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
112 b = _mm512_castpd_ps(
113 _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
114 c = _mm512_castpd_ps(
115 _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
116 d = _mm512_castpd_ps(
117 _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
118 e = _mm512_castpd_ps(
119 _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
120 f = _mm512_castpd_ps(
121 _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
122 g = _mm512_castpd_ps(
123 _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
124 h = _mm512_castpd_ps(
125 _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
126 i = _mm512_castpd_ps(
127 _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
128 j = _mm512_castpd_ps(
129 _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
130 k = _mm512_castpd_ps(
131 _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
132 l = _mm512_castpd_ps(
133 _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
134 m = _mm512_castpd_ps(
135 _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
136 n = _mm512_castpd_ps(
137 _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
138 o = _mm512_castpd_ps(
139 _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
140 p = _mm512_castpd_ps(
141 _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
142
143 // shuffle 128-bits (composed of 4 32-bit elements)
144 // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8
145 // a1 b1 c1 d1 ...
146 // a2 b2 c2 d2 ...
147 // a3 b3 c3 d3 ...
148 // a4 b4 c4 d4 ...
149 // a5 b5 c5 d5 ...
150 // a6 b6 c6 d6 ...
151 // a7 b7 c7 d7 ...
152 // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8
153 // i1 j1 k1 l1 ...
154 // i2 j2 k2 l2 ...
155 // i3 j3 k3 l3 ...
156 // i4 j4 k4 l4 ...
157 // i5 j5 k5 l5 ...
158 // i6 j6 k6 l6 ...
159 // i7 j7 k7 l7 ...
160 ta = _mm512_shuffle_f32x4(a, e, 0x88);
161 tb = _mm512_shuffle_f32x4(b, f, 0x88);
162 tc = _mm512_shuffle_f32x4(c, g, 0x88);
163 td = _mm512_shuffle_f32x4(d, h, 0x88);
164 te = _mm512_shuffle_f32x4(a, e, 0xdd);
165 tf = _mm512_shuffle_f32x4(b, f, 0xdd);
166 tg = _mm512_shuffle_f32x4(c, g, 0xdd);
167 th = _mm512_shuffle_f32x4(d, h, 0xdd);
168 ti = _mm512_shuffle_f32x4(i, m, 0x88);
169 tj = _mm512_shuffle_f32x4(j, n, 0x88);
170 tk = _mm512_shuffle_f32x4(k, o, 0x88);
171 tl = _mm512_shuffle_f32x4(l, p, 0x88);
172 tm = _mm512_shuffle_f32x4(i, m, 0xdd);
173 tn = _mm512_shuffle_f32x4(j, n, 0xdd);
174 to = _mm512_shuffle_f32x4(k, o, 0xdd);
175 tq = _mm512_shuffle_f32x4(l, p, 0xdd);
176
177 // shuffle 128-bits (composed of 4 32-bit elements)
  // a0 b0 c0 d0 ... p0
  // a1 b1 c1 d1 ... p1
  // a2 b2 c2 d2 ... p2
  // a3 b3 c3 d3 ... p3
182 // a4 ...
183 // a5 ...
184 // a6 ...
185 // a7 ...
186 // a8 ...
187 // a9 ...
188 // a10 ...
189 // a11 ...
190 // a12 ...
191 // a13 ...
192 // a14 ...
  // a15 b15 c15 d15 ... p15
194 a = _mm512_shuffle_f32x4(ta, ti, 0x88);
195 b = _mm512_shuffle_f32x4(tb, tj, 0x88);
196 c = _mm512_shuffle_f32x4(tc, tk, 0x88);
197 d = _mm512_shuffle_f32x4(td, tl, 0x88);
198 e = _mm512_shuffle_f32x4(te, tm, 0x88);
199 f = _mm512_shuffle_f32x4(tf, tn, 0x88);
200 g = _mm512_shuffle_f32x4(tg, to, 0x88);
201 h = _mm512_shuffle_f32x4(th, tq, 0x88);
202 i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
203 j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
204 k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
205 l = _mm512_shuffle_f32x4(td, tl, 0xdd);
206 m = _mm512_shuffle_f32x4(te, tm, 0xdd);
207 n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
208 o = _mm512_shuffle_f32x4(tg, to, 0xdd);
209 p = _mm512_shuffle_f32x4(th, tq, 0xdd);
210
211 // store from registers to dst
212 _mm512_storeu_ps(&dst[0 * ld_dst], a);
213 _mm512_storeu_ps(&dst[1 * ld_dst], b);
214 _mm512_storeu_ps(&dst[2 * ld_dst], c);
215 _mm512_storeu_ps(&dst[3 * ld_dst], d);
216 _mm512_storeu_ps(&dst[4 * ld_dst], e);
217 _mm512_storeu_ps(&dst[5 * ld_dst], f);
218 _mm512_storeu_ps(&dst[6 * ld_dst], g);
219 _mm512_storeu_ps(&dst[7 * ld_dst], h);
220 _mm512_storeu_ps(&dst[8 * ld_dst], i);
221 _mm512_storeu_ps(&dst[9 * ld_dst], j);
222 _mm512_storeu_ps(&dst[10 * ld_dst], k);
223 _mm512_storeu_ps(&dst[11 * ld_dst], l);
224 _mm512_storeu_ps(&dst[12 * ld_dst], m);
225 _mm512_storeu_ps(&dst[13 * ld_dst], n);
226 _mm512_storeu_ps(&dst[14 * ld_dst], o);
227 _mm512_storeu_ps(&dst[15 * ld_dst], p);
228}
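
// For reference, the kernel above computes the same result as the scalar loop
// below for a full 16x16 tile. This helper is only an illustrative sketch kept
// as documentation (the name is ours and nothing in this file calls it); the
// vectorized kernel above is what the library actually uses.
[[maybe_unused]] inline void transpose_kernel_16x16_ref(
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  for (int row = 0; row < 16; ++row) {
    for (int col = 0; col < 16; ++col) {
      // element (row, col) of the source tile lands at (col, row) of dst
      dst[col * ld_dst + row] = src[row * ld_src + col];
    }
  }
}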
229
230// kernel for transposing mxn where m, n <= 16
231// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions
232template <int M>
233void transpose_kernel_mxn_avx512(
234 int N,
235 const float* src,
236 int64_t ld_src,
237 float* dst,
238 int64_t ld_dst) {
239 // load from src to registers
240 __mmask16 src_mask = (1 << N) - 1;
241 __m512 input[16];
242 int i;
243 for (i = 0; i < M; ++i) {
244 input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]);
245 }
246 for (; i < 16; ++i) {
    // Not strictly needed, but zeroing avoids an uninitialized-variable
    // warning. The overhead should be small because the xors can execute in
    // parallel with other instructions.
250 input[i] = _mm512_setzero_ps();
251 }
252
253 // unpacking and interleaving 32-bit elements
254 __m512 temp[16];
255 for (i = 0; i < (M + 1) / 2; ++i) {
256 temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]);
257 temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]);
258 }
259 for (i = i * 2; i < 16; ++i) {
260 temp[i] = _mm512_setzero_ps();
261 }
262
263 // unpacking and interleaving 64-bit elements
264 for (i = 0; i < (M + 3) / 4; ++i) {
265 input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd(
266 _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2])));
267 input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd(
268 _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2])));
269 input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd(
270 _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3])));
271 input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd(
272 _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3])));
273 }
274
275 // shuffle 128-bits (composed of 4 32-bit elements)
276 for (i = 0; i < (M + 7) / 8; ++i) {
277 temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88);
278 temp[8 * i + 1] =
279 _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88);
280 temp[8 * i + 2] =
281 _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88);
282 temp[8 * i + 3] =
283 _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88);
284 temp[8 * i + 4] =
285 _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd);
286 temp[8 * i + 5] =
287 _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd);
288 temp[8 * i + 6] =
289 _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd);
290 temp[8 * i + 7] =
291 _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd);
292 }
293
294 // store from registers to dst
295 __mmask16 dst_mask = (1 << M) - 1;
296 for (i = 0; i < N; ++i) {
297 if (i < 8) {
298 input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88);
299 } else {
300 input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd);
301 }
302 _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]);
303 }
304}
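
// Worked example of the instruction-count formula above: for a full tile
// (M = 16, N = 16) it gives 16 + 16 + 16 + 16 + 2 * 16 = 96, matching
// transpose_kernel_16x16_avx512; for, say, a 5 x 11 remainder tile,
// transpose_kernel_mxn_avx512<5>(11, ...) costs 5 + 6 + 8 + 8 + 22 = 49
// instructions.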
305
306} // namespace
307
308namespace internal {
309
310template <typename T>
311void transpose_avx512_contiguous_thin(
312 const int64_t M,
313 const int64_t N,
314 const T* src,
315 int64_t ld_src,
316 T* dst,
317 int64_t ld_dst);
318
319template <typename T>
320void transpose_avx512_contiguous_wide(
321 const int64_t M,
322 const int64_t N,
323 const T* src,
324 int64_t ld_src,
325 T* dst,
326 int64_t ld_dst);
327
// Permute elements within each 128-bit lane.
// e.g., if a 128-bit lane contains the bytes
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
//
// then after this function call it becomes
// 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15
// The same permutation is applied to the other 3 lanes.
335static inline __m512i permute_row(__m512i row) {
336 // clang-format off
337 __m256i shuffle_256v0 = _mm256_set_epi8(
338 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
339 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
340 // clang-format on
341 __m512i shuffle_512v = _mm512_castsi256_si512(shuffle_256v0);
342 row = _mm512_shuffle_epi8(
343 row, _mm512_inserti64x4(shuffle_512v, shuffle_256v0, 1));
344 return row;
345}
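
// In scalar terms, permute_row transposes each 16-byte lane viewed as a 4x4
// byte matrix: for byte positions p = 0..15 within a lane,
//   out[p] = in[(p % 4) * 4 + p / 4];
// which is exactly the 0, 4, 8, 12, 1, 5, ... ordering described above.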
346
347static inline void core_transpose_16x32_block_i8(__m512i r[], __m512i u[]) {
348 // Result after this operation; Read in conjunction with comments in
349 // transpose_16x32_block
350 // 00_00 00_01 01_00 01_01 00_04 00_05 01_04 01_05 04_00 04_01 05_00 05_01
351 // 04_04 04_05 05_04 05_05
352 u[0] = _mm512_unpacklo_epi64(r[0], r[1]);
353 // 00_02 00_03 01_02 01_03 00_06 00_07 01_06 01_07 04_02 04_03 05_02 05_03
354 // 04_06 04_07 05_06 05_07
355 u[1] = _mm512_unpackhi_epi64(r[0], r[1]);
356 // 02_00 02_01 03_00 03_01 02_04 02_05 03_04 03_05 06_00 06_01 07_00 07_01
357 // 06_04 06_05 07_04 07_05
358 u[2] = _mm512_unpacklo_epi64(r[2], r[3]);
359 // 02_02 02_03 03_02 03_03 02_06 02_07 03_06 03_07 06_02 06_03 07_02 07_03
360 // 06_06 06_07 07_06 07_07
361 u[3] = _mm512_unpackhi_epi64(r[2], r[3]);
362 // 08_00 08_01 09_00 09_01 08_04 08_05 09_04 09_05 12_00 12_01 13_00 13_01
363 // 12_04 12_05 13_04 13_05
364 u[4] = _mm512_unpacklo_epi64(r[4], r[5]);
365 u[5] = _mm512_unpackhi_epi64(r[4], r[5]);
366 u[6] = _mm512_unpacklo_epi64(r[6], r[7]);
367 u[7] = _mm512_unpackhi_epi64(r[6], r[7]);
368
  // This shuffle has no epi32 form, so cast to ps and use _mm512_shuffle_ps
370 // 00_00 01_00 02_00 03_00 00_04 01_04 02_04 03_04 04_00 05_00 06_00 07_00
371 // 04_04 05_04 06_04 07_04
372 r[0] = _mm512_castps_si512(_mm512_shuffle_ps(
373 _mm512_castsi512_ps(u[0]), _mm512_castsi512_ps(u[2]), 0x88));
374 // 00_01 01_01 02_01 03_01 00_05 01_05 02_05 03_05 04_01 05_01 06_01 07_01
375 // 04_05 05_05 06_05 07_05
376 r[1] = _mm512_castps_si512(_mm512_shuffle_ps(
377 _mm512_castsi512_ps(u[0]), _mm512_castsi512_ps(u[2]), 0xDD));
378 r[2] = _mm512_castps_si512(_mm512_shuffle_ps(
379 _mm512_castsi512_ps(u[1]), _mm512_castsi512_ps(u[3]), 0x88));
380 r[3] = _mm512_castps_si512(_mm512_shuffle_ps(
381 _mm512_castsi512_ps(u[1]), _mm512_castsi512_ps(u[3]), 0xDD));
382 // 08_00 09_00 10_00 11_00 08_04 09_04 10_04 11_04 12_00 13_00 14_00 15_00
383 // 12_04 13_04 14_04 15_04
384 r[4] = _mm512_castps_si512(_mm512_shuffle_ps(
385 _mm512_castsi512_ps(u[4]), _mm512_castsi512_ps(u[6]), 0x88));
386 r[5] = _mm512_castps_si512(_mm512_shuffle_ps(
387 _mm512_castsi512_ps(u[4]), _mm512_castsi512_ps(u[6]), 0xDD));
388 r[6] = _mm512_castps_si512(_mm512_shuffle_ps(
389 _mm512_castsi512_ps(u[5]), _mm512_castsi512_ps(u[7]), 0x88));
390 r[7] = _mm512_castps_si512(_mm512_shuffle_ps(
391 _mm512_castsi512_ps(u[5]), _mm512_castsi512_ps(u[7]), 0xDD));
392
393 // permute among 128-bit lanes
394 r[0] = permute_row(r[0]);
395 r[1] = permute_row(r[1]);
396 r[2] = permute_row(r[2]);
397 r[3] = permute_row(r[3]);
398 r[4] = permute_row(r[4]);
399 r[5] = permute_row(r[5]);
400 r[6] = permute_row(r[6]);
401 r[7] = permute_row(r[7]);
402
403 __m512i const1 = _mm512_set_epi32(
404 27, 19, 11, 3, 26, 18, 10, 2, 25, 17, 9, 1, 24, 16, 8, 0);
405 __m512i const2 = _mm512_set_epi32(
406 31, 23, 15, 7, 30, 22, 14, 6, 29, 21, 13, 5, 28, 20, 12, 4);
407
408 // merge 128-bit values from two regs
409 u[0] = _mm512_permutex2var_epi32(r[0], const1, r[4]);
410 u[1] = _mm512_permutex2var_epi32(r[0], const2, r[4]);
411 u[2] = _mm512_permutex2var_epi32(r[1], const1, r[5]);
412 u[3] = _mm512_permutex2var_epi32(r[1], const2, r[5]);
413 u[4] = _mm512_permutex2var_epi32(r[2], const1, r[6]);
414 u[5] = _mm512_permutex2var_epi32(r[2], const2, r[6]);
415 u[6] = _mm512_permutex2var_epi32(r[3], const1, r[7]);
416 u[7] = _mm512_permutex2var_epi32(r[3], const2, r[7]);
417}
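
// Net result of core_transpose_16x32_block_i8: u[2k] holds transposed rows
// 4k .. 4k+3 (one 16-byte row per 128-bit lane) and u[2k+1] holds rows
// 16+4k .. 16+4k+3, which is why the stores in transpose_16x32_block and
// store_with_remainders_i8 walk the registers in that interleaved order.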
418
419static inline void core_transpose_16x16_block(__m512i r[], __m512i u[]) {
420 // a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 f8f9
421 // e10e11 f10f11
422 u[0] = _mm512_unpacklo_epi32(r[0], r[1]);
423 // a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 f4f5 e6e7 f6f7
424 // e12e13 f12f13 e14e15 f14f15
425 u[1] = _mm512_unpackhi_epi32(r[0], r[1]);
426 // c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9
427 // g10g11 h10h11
428 u[2] = _mm512_unpacklo_epi32(r[2], r[3]);
  // c4c5 d4d5 c6c7 d6d7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7
430 // g12g13 h12h13 g14g15 h14h15
431 u[3] = _mm512_unpackhi_epi32(r[2], r[3]);
432 // i j m n
433 u[4] = _mm512_unpacklo_epi32(r[4], r[5]);
434 u[5] = _mm512_unpackhi_epi32(r[4], r[5]);
435 // k l o p
436 u[6] = _mm512_unpacklo_epi32(r[6], r[7]);
437 u[7] = _mm512_unpackhi_epi32(r[6], r[7]);
438
439 // a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9
440 // h8h9
441 r[0] = _mm512_unpacklo_epi64(u[0], u[2]);
442 // a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 g2g3 h2h3 e10e11
443 // f10f11 g10g11 h10h11
444 r[1] = _mm512_unpackhi_epi64(u[0], u[2]);
  // a4a5 b4b5 c4c5 d4d5 a12a13 b12b13 c12c13 d12d13
  r[2] = _mm512_unpacklo_epi64(u[1], u[3]);
  // a6a7 b6b7 c6c7 d6d7 a14a15 b14b15 c14c15 d14d15
448 r[3] = _mm512_unpackhi_epi64(u[1], u[3]);
449 // i j k l m n o p
450 r[4] = _mm512_unpacklo_epi64(u[4], u[6]);
451 r[5] = _mm512_unpackhi_epi64(u[4], u[6]);
452 r[6] = _mm512_unpacklo_epi64(u[5], u[7]);
453 r[7] = _mm512_unpackhi_epi64(u[5], u[7]);
454
455 __m512i const1 = _mm512_set_epi32(
456 0x00370035,
457 0x00330031,
458 0x00270025,
459 0x00230021,
460 0x00170015,
461 0x00130011,
462 0x00070005,
463 0x00030001,
464 0x00360034,
465 0x00320030,
466 0x00260024,
467 0x00220020,
468 0x00160014,
469 0x00120010,
470 0x00060004,
471 0x00020000);
472 __m512i const2 = _mm512_set_epi32(
473 0x003f003d,
474 0x003b0039,
475 0x002f002d,
476 0x002b0029,
477 0x001f001d,
478 0x001b0019,
479 0x000f000d,
480 0x000b0009,
481 0x003e003c,
482 0x003a0038,
483 0x002e002c,
484 0x002a0028,
485 0x001e001c,
486 0x001a0018,
487 0x000e000c,
488 0x000a0008);
489
490 // merge values from two regs
491 u[0] = _mm512_permutex2var_epi16(r[0], const1, r[4]); // 0-- 1--
492 u[4] = _mm512_permutex2var_epi16(r[0], const2, r[4]); // 8-- 9--
493 u[2] = _mm512_permutex2var_epi16(r[2], const1, r[6]); // 4-- 5--
494 u[6] = _mm512_permutex2var_epi16(r[2], const2, r[6]); // 12-- 13--
495 u[1] = _mm512_permutex2var_epi16(r[1], const1, r[5]); // 2-- 3--
496 u[5] = _mm512_permutex2var_epi16(r[1], const2, r[5]); // 10-- 11--
497 u[3] = _mm512_permutex2var_epi16(r[3], const1, r[7]); // 6-- 7--
498 u[7] = _mm512_permutex2var_epi16(r[3], const2, r[7]); // 14-- 15--
499}
500
501static inline void load_with_remainders_i16(
502 const uint16_t* src,
503 int64_t ld_src,
504 __m512i r[],
505 int mrem,
506 int nrem) {
507 __m512i t[16];
508 if (nrem < 16) {
509 __mmask32 mask_nrem_v = (((long long)1) << nrem) - 1;
510 for (int i = 0; i < mrem; ++i) {
511 // mask load
512 t[i] = _mm512_maskz_loadu_epi16(mask_nrem_v, src + i * ld_src);
513 }
514 } else {
515 for (int i = 0; i < mrem; ++i) {
516 // normal load
517 t[i] = _mm512_castsi256_si512(_mm256_loadu_si256(
518 reinterpret_cast<const __m256i*>(src + i * ld_src)));
519 }
520 }
521 r[0] = _mm512_inserti64x4(t[0], _mm512_castsi512_si256(t[4]), 0x01);
522 r[1] = _mm512_inserti64x4(t[1], _mm512_castsi512_si256(t[5]), 0x01);
523 r[2] = _mm512_inserti64x4(t[2], _mm512_castsi512_si256(t[6]), 0x01);
524 r[3] = _mm512_inserti64x4(t[3], _mm512_castsi512_si256(t[7]), 0x01);
525 r[4] = _mm512_inserti64x4(t[8], _mm512_castsi512_si256(t[12]), 0x01);
526 r[5] = _mm512_inserti64x4(t[9], _mm512_castsi512_si256(t[13]), 0x01);
527 r[6] = _mm512_inserti64x4(t[10], _mm512_castsi512_si256(t[14]), 0x01);
528 r[7] = _mm512_inserti64x4(t[11], _mm512_castsi512_si256(t[15]), 0x01);
529}
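
// Example of the remainder handling above: with nrem == 5 the load mask is
// 0b11111, so only the first 5 uint16 values of each of the mrem valid rows
// are read (the rest are zero-filled). Rows at or beyond mrem contribute
// garbage lanes to r[], but those lanes never reach memory because the
// matching stores are masked by mrem.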
530
531static inline void load_with_remainders_i8(
532 const uint8_t* src,
533 int64_t ld_src,
534 __m512i r[],
535 int mrem,
536 int nrem) {
537 __m512i t[16];
538 if (nrem < 32) {
539 __mmask64 mask_nrem_v = (((long long)1) << nrem) - 1;
540 for (int i = 0; i < mrem; ++i) {
541 // mask load
542 t[i] = _mm512_maskz_loadu_epi8(mask_nrem_v, src + i * ld_src);
543 }
544 } else {
545 for (int i = 0; i < mrem; ++i) {
546 // normal load
547 t[i] = _mm512_castsi256_si512(_mm256_loadu_si256(
548 reinterpret_cast<const __m256i*>(src + i * ld_src)));
549 }
550 }
551 r[0] = _mm512_inserti64x4(t[0], _mm512_castsi512_si256(t[4]), 0x01);
552 r[1] = _mm512_inserti64x4(t[1], _mm512_castsi512_si256(t[5]), 0x01);
553 r[2] = _mm512_inserti64x4(t[2], _mm512_castsi512_si256(t[6]), 0x01);
554 r[3] = _mm512_inserti64x4(t[3], _mm512_castsi512_si256(t[7]), 0x01);
555 r[4] = _mm512_inserti64x4(t[8], _mm512_castsi512_si256(t[12]), 0x01);
556 r[5] = _mm512_inserti64x4(t[9], _mm512_castsi512_si256(t[13]), 0x01);
557 r[6] = _mm512_inserti64x4(t[10], _mm512_castsi512_si256(t[14]), 0x01);
558 r[7] = _mm512_inserti64x4(t[11], _mm512_castsi512_si256(t[15]), 0x01);
559}
560
561static inline void store_with_remainders_i16(
562 uint16_t* dst,
563 int64_t ld_dst,
564 __m512i u[],
565 int mrem,
566 int nrem) {
567 if (mrem < 16) {
568 __mmask32 mask_mrem_v = (((long long)1) << mrem) - 1;
569 int i = 0;
570
571 for (; i < nrem / 2 * 2; i += 2) {
572 // mask store
573 int reg_idx = i / 2;
574 _mm512_mask_storeu_epi16(
575 dst + (i + 0) * ld_dst,
576 mask_mrem_v,
577 _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x0)));
578 _mm512_mask_storeu_epi16(
579 dst + (i + 1) * ld_dst,
580 mask_mrem_v,
581 _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x1)));
582 }
583 if (i < nrem) {
584 int reg_idx = i / 2;
585 _mm512_mask_storeu_epi16(
586 dst + (i + 0) * ld_dst,
587 mask_mrem_v,
588 _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x0)));
589 }
590 } else {
591 int i = 0;
592 for (; i < nrem / 2 * 2; i += 2) {
593 // normal store
594 int reg_idx = i / 2;
595 _mm256_storeu_si256(
596 reinterpret_cast<__m256i*>(dst + (i + 0) * ld_dst),
597 _mm512_extracti32x8_epi32(u[reg_idx], 0x0));
598 _mm256_storeu_si256(
599 reinterpret_cast<__m256i*>(dst + (i + 1) * ld_dst),
600 _mm512_extracti32x8_epi32(u[reg_idx], 0x1));
601 }
602 if (i < nrem) {
603 int reg_idx = i / 2;
604 _mm256_storeu_si256(
605 reinterpret_cast<__m256i*>(dst + (i + 0) * ld_dst),
606 _mm512_extracti32x8_epi32(u[reg_idx], 0x0));
607 }
608 }
609}
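
// Each u[] register passed in above carries two consecutive destination rows,
// one per 256-bit half (see core_transpose_16x16_block), which is why both
// branches step the destination row index by two per register and extract
// half 0x0 and then half 0x1.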
610
611static inline void store_with_remainders_i8(
612 uint8_t* dst,
613 int64_t ld_dst,
614 __m512i u[],
615 int mrem,
616 int nrem) {
617 if (mrem < 16) {
618 __mmask64 mask_mrem_v = (((long long)1) << mrem) - 1;
619 int i = 0;
620 for (; i < nrem / 4 * 4; i += 4) {
621 // mask store
      // Destination rows 0, 4, 8, 12 come from registers 0, 2, 4, 6 and
      // destination rows 16, 20, 24, 28 from registers 1, 3, 5, 7 (each
      // register holds four consecutive rows).
      // See the stores for the non-remainder case.
625 int reg_idx = i / 16 + 2 * ((i % 16) / 4);
626 _mm512_mask_storeu_epi8(
627 dst + (i + 0) * ld_dst,
628 mask_mrem_v,
629 _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x0)));
630 _mm512_mask_storeu_epi8(
631 dst + (i + 1) * ld_dst,
632 mask_mrem_v,
633 _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x1)));
634 _mm512_mask_storeu_epi8(
635 dst + (i + 2) * ld_dst,
636 mask_mrem_v,
637 _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x2)));
638 _mm512_mask_storeu_epi8(
639 dst + (i + 3) * ld_dst,
640 mask_mrem_v,
641 _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x3)));
642 }
643 int rem = nrem - i;
644 int reg_rem_idx = i / 16 + 2 * ((i % 16) / 4);
645 switch (rem) {
646 case 1:
647 _mm512_mask_storeu_epi8(
648 dst + (i + 0) * ld_dst,
649 mask_mrem_v,
650 _mm512_castsi128_si512(
651 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)));
652 break;
653 case 2:
654 _mm512_mask_storeu_epi8(
655 dst + (i + 0) * ld_dst,
656 mask_mrem_v,
657 _mm512_castsi128_si512(
658 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)));
659 _mm512_mask_storeu_epi8(
660 dst + (i + 1) * ld_dst,
661 mask_mrem_v,
662 _mm512_castsi128_si512(
663 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1)));
664 break;
665 case 3:
666 _mm512_mask_storeu_epi8(
667 dst + (i + 0) * ld_dst,
668 mask_mrem_v,
669 _mm512_castsi128_si512(
670 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)));
671 _mm512_mask_storeu_epi8(
672 dst + (i + 1) * ld_dst,
673 mask_mrem_v,
674 _mm512_castsi128_si512(
675 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1)));
676 _mm512_mask_storeu_epi8(
677 dst + (i + 2) * ld_dst,
678 mask_mrem_v,
679 _mm512_castsi128_si512(
680 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x2)));
681 break;
682 default:
683 break;
684 }
685
686 } else {
687 int i = 0;
688 for (; i < nrem / 4 * 4; i += 4) {
689 // normal store
690 int reg_idx = i / 16 + 2 * ((i % 16) / 4);
691 _mm_storeu_si128(
692 reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst),
693 _mm512_extracti32x4_epi32(u[reg_idx], 0x0));
694 _mm_storeu_si128(
695 reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst),
696 _mm512_extracti32x4_epi32(u[reg_idx], 0x1));
697 _mm_storeu_si128(
698 reinterpret_cast<__m128i*>(dst + (i + 2) * ld_dst),
699 _mm512_extracti32x4_epi32(u[reg_idx], 0x2));
700 _mm_storeu_si128(
701 reinterpret_cast<__m128i*>(dst + (i + 3) * ld_dst),
702 _mm512_extracti32x4_epi32(u[reg_idx], 0x3));
703 }
704 int rem = nrem - i;
705 int reg_rem_idx = i / 16 + 2 * ((i % 16) / 4);
706 switch (rem) {
707 case 1:
708 _mm_storeu_si128(
709 reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst),
710 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0));
711 break;
712 case 2:
713 _mm_storeu_si128(
714 reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst),
715 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0));
716 _mm_storeu_si128(
717 reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst),
718 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1));
719 break;
720 case 3:
721 _mm_storeu_si128(
722 reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst),
723 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0));
724 _mm_storeu_si128(
725 reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst),
726 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1));
727 _mm_storeu_si128(
728 reinterpret_cast<__m128i*>(dst + (i + 2) * ld_dst),
729 _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x2));
730 break;
731 default:
732 break;
733 }
734 }
735}
736
737static inline void transpose_contiguous_4x16_block(
738 const float* src,
739 float* dst,
740 int64_t ld_src,
741 int nrem = 16) {
742 __m512i r[4];
743 // load
744 if (nrem < 16) {
745 __mmask16 mask_mrem_v = (((long long)1) << nrem) - 1;
746 r[0] = _mm512_maskz_loadu_epi32(mask_mrem_v, src);
747 r[1] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src);
748 r[2] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 2 * ld_src);
749 r[3] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 3 * ld_src);
750
751 } else {
752 r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
753 r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
754 r[2] =
755 _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src));
756 r[3] =
757 _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src));
758 }
759 // transpose
760 // a0b0 a1b1 a4b4 a5b5 a8b8 a9b9 a12b12 a13b13
761 // a2b2 a3b3 a6b6 a7b7 a10b10 a11b11 a14b14 a15b15
762 // c0d0 c1d1 c4d4 c5d5 c8d8 c9d9 c12d12 c13d13
  // c2d2 c3d3 c6d6 c7d7 c10d10 c11d11 c14d14 c15d15
764 __m512i t0 = _mm512_unpacklo_epi32(r[0], r[1]);
765 __m512i t1 = _mm512_unpackhi_epi32(r[0], r[1]);
766 __m512i t2 = _mm512_unpacklo_epi32(r[2], r[3]);
767 __m512i t3 = _mm512_unpackhi_epi32(r[2], r[3]);
768
769 r[0] = _mm512_unpacklo_epi64(t0, t2);
770 r[1] = _mm512_unpackhi_epi64(t0, t2);
771 r[2] = _mm512_unpacklo_epi64(t1, t3);
772 r[3] = _mm512_unpackhi_epi64(t1, t3);
773
774 t0 = _mm512_shuffle_i32x4(r[0], r[1], 0x44);
775 t1 = _mm512_shuffle_i32x4(r[0], r[1], 0xee);
776 t2 = _mm512_shuffle_i32x4(r[2], r[3], 0x44);
777 t3 = _mm512_shuffle_i32x4(r[2], r[3], 0xee);
778
779 r[0] = _mm512_shuffle_i32x4(t0, t2, 0x88);
780 r[1] = _mm512_shuffle_i32x4(t0, t2, 0xdd);
781 r[2] = _mm512_shuffle_i32x4(t1, t3, 0x88);
782 r[3] = _mm512_shuffle_i32x4(t1, t3, 0xdd);
783 // store
784 int i = 0;
785 for (; (i + 1) * 16 <= nrem * 4; i++) {
786 // normal store
787 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 16), r[i]);
788 }
789 int erem = nrem * 4 - i * 16;
790 if (erem > 0) {
791 // mask store
792 __mmask16 mask_rem_v = (((long long)1) << erem) - 1;
793 _mm512_mask_storeu_epi32(dst + i * 16, mask_rem_v, r[i]);
794 }
795}
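
// Scalar reference for transpose_contiguous_4x16_block above (the uint16 and
// uint8 variants below follow the same pattern with their own element types).
// This is an illustrative sketch only; the helper name is ours and is not
// used elsewhere. "Contiguous" means the destination has no leading
// dimension: column j of the 4 source rows is packed into 4 consecutive
// destination elements.
[[maybe_unused]] static inline void transpose_contiguous_4xn_ref(
    const float* src,
    float* dst,
    int64_t ld_src,
    int nrem) {
  for (int j = 0; j < nrem; ++j) {
    for (int r = 0; r < 4; ++r) {
      dst[j * 4 + r] = src[r * ld_src + j];
    }
  }
}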
796
797static inline void transpose_contiguous_4x32_block(
798 const uint16_t* src,
799 uint16_t* dst,
800 int64_t ld_src,
801 int nrem = 32) {
802 __m512i r[4], d[4];
803 // load
804 if (nrem < 32) {
805 __mmask32 mask_mrem_v = (((long long)1) << nrem) - 1;
806 r[0] = _mm512_maskz_loadu_epi16(mask_mrem_v, src);
807 r[1] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src);
808 r[2] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 2 * ld_src);
809 r[3] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 3 * ld_src);
810 } else {
811 r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
812 r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
813 r[2] =
814 _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src));
815 r[3] =
816 _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src));
817 }
818 // transpose
819 d[0] = _mm512_unpacklo_epi16(r[0], r[1]);
820 d[1] = _mm512_unpackhi_epi16(r[0], r[1]);
821 d[2] = _mm512_unpacklo_epi16(r[2], r[3]);
822 d[3] = _mm512_unpackhi_epi16(r[2], r[3]);
823
824 r[0] = _mm512_unpacklo_epi32(d[0], d[2]);
825 r[1] = _mm512_unpackhi_epi32(d[0], d[2]);
826 r[2] = _mm512_unpacklo_epi32(d[1], d[3]);
827 r[3] = _mm512_unpackhi_epi32(d[1], d[3]);
828
829 d[0] = _mm512_shuffle_i32x4(r[0], r[1], 0x44);
830 d[1] = _mm512_shuffle_i32x4(r[0], r[1], 0xee);
831 d[2] = _mm512_shuffle_i32x4(r[2], r[3], 0x44);
832 d[3] = _mm512_shuffle_i32x4(r[2], r[3], 0xee);
833
834 r[0] = _mm512_shuffle_i32x4(d[0], d[2], 0x88);
835 r[1] = _mm512_shuffle_i32x4(d[0], d[2], 0xdd);
836 r[2] = _mm512_shuffle_i32x4(d[1], d[3], 0x88);
837 r[3] = _mm512_shuffle_i32x4(d[1], d[3], 0xdd);
838 // store
839 int i = 0;
840 for (; (i + 1) * 32 <= nrem * 4; i++) {
841 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 32), r[i]);
842 }
843 int erem = nrem * 4 - i * 32;
844 if (erem > 0) {
845 // mask store
846 __mmask32 mask_rem_v = (((long long)1) << erem) - 1;
847 _mm512_mask_storeu_epi16(dst + i * 32, mask_rem_v, r[i]);
848 }
849}
850
851static inline void transpose_contiguous_16x4_block(
852 const float* src,
853 float* dst,
854 int64_t ld_dst,
855 int mrem = 16) {
856 __m512i r[4], d[4];
857 int i = 0;
858 for (; (i + 1) * 16 <= mrem * 4; i++) {
859 // normal load
860 r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 16));
861 }
862 if (i * 16 < mrem * 4) {
863 __mmask16 mask_mrem_v = (((long long)1) << (mrem * 4 - i * 16)) - 1;
864 r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16);
865 }
866
867 // transpose
868 __m512i index1 = _mm512_set_epi32(
869 0x0f,
870 0x0b,
871 0x07,
872 0x03,
873 0x0e,
874 0x0a,
875 0x06,
876 0x02,
877 0x0d,
878 0x09,
879 0x05,
880 0x01,
881 0x0c,
882 0x08,
883 0x04,
884 0x00);
885 d[0] = _mm512_permutexvar_epi32(index1, r[0]);
886 d[1] = _mm512_permutexvar_epi32(index1, r[1]);
887 d[2] = _mm512_permutexvar_epi32(index1, r[2]);
888 d[3] = _mm512_permutexvar_epi32(index1, r[3]);
889
890 r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
891 r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
892 r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);
893 r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);
894
895 d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);
896 d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);
897 d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);
898 d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);
899
900 if (mrem < 16) {
901 // mask store
902 __mmask16 mask_rem_v = (((long long)1) << mrem) - 1;
903 _mm512_mask_storeu_epi32(dst + 0 * ld_dst, mask_rem_v, d[0]);
904 _mm512_mask_storeu_epi32(dst + 1 * ld_dst, mask_rem_v, d[1]);
905 _mm512_mask_storeu_epi32(dst + 2 * ld_dst, mask_rem_v, d[2]);
906 _mm512_mask_storeu_epi32(dst + 3 * ld_dst, mask_rem_v, d[3]);
907 } else {
    // normal store
909 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]);
910 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]);
911 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]);
912 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]);
913 }
914}
915
916static inline void transpose_contiguous_16x2_block(
917 const float* src,
918 float* dst,
919 int64_t ld_dst,
920 int mrem = 16) {
921 __m512i r[2], d[2];
922 int i = 0;
923 for (; (i + 1) * 16 <= mrem * 2; i++) {
924 // normal load
925 r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 16));
926 }
927 if (i * 16 < mrem * 2) {
928 __mmask16 mask_mrem_v = (((long long)1) << (mrem * 2 - i * 16)) - 1;
929 r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16);
930 }
931 // transpose
932 __m512i index1 = _mm512_set_epi32(
933 0x1e,
934 0x1c,
935 0x1a,
936 0x18,
937 0x16,
938 0x14,
939 0x12,
940 0x10,
941 0x0e,
942 0x0c,
943 0x0a,
944 0x08,
945 0x06,
946 0x04,
947 0x02,
948 0x00);
949 __m512i index2 = _mm512_set_epi32(
950 0x1f,
951 0x1d,
952 0x1b,
953 0x19,
954 0x17,
955 0x15,
956 0x13,
957 0x11,
958 0x0f,
959 0x0d,
960 0x0b,
961 0x09,
962 0x07,
963 0x05,
964 0x03,
965 0x01);
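  // For _mm512_permutex2var_epi32, indices 0x00-0x0f select from the first
  // source (r[0]) and 0x10-0x1f select from the second (r[1]); index1 gathers
  // the even positions (column 0 of the 16x2 tile) and index2 gathers the odd
  // positions (column 1).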
966
967 // a0--p0
968 // a1--p1
969 d[0] = _mm512_permutex2var_epi32(r[0], index1, r[1]);
970 d[1] = _mm512_permutex2var_epi32(r[0], index2, r[1]);
971
972 // store
973 if (mrem < 16) {
974 __mmask16 mask_rem_v = (((long long)1) << mrem) - 1;
975 // mask store
976 _mm512_mask_storeu_epi32(dst, mask_rem_v, d[0]);
977 _mm512_mask_storeu_epi32(dst + ld_dst, mask_rem_v, d[1]);
978 } else {
979 // normal store
980 _mm512_storeu_si512(dst, d[0]);
981 _mm512_storeu_si512(dst + ld_dst, d[1]);
982 }
983}
984
985static inline void transpose_contiguous_64x4_block(
986 const uint8_t* src,
987 uint8_t* dst,
988 int64_t ld_dst,
989 int mrem = 64) {
990 __m512i r[4], d[4];
991 // normal load
992 int i = 0;
993 for (; (i + 1) * 64 <= mrem * 4; i++) {
994 r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 64));
995 }
996 int erem = mrem * 4 - i * 64;
997 if (erem > 0) {
998 __mmask64 mask_mrem_v = (((long long)1) << erem) - 1;
999 r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64);
1000 }
1001
1002 // transpose
1003 __m512i index = _mm512_set_epi32(
1004 0x0f0b0703,
1005 0x0e0a0602,
1006 0x0d090501,
1007 0x0c080400,
1008 0x0f0b0703,
1009 0x0e0a0602,
1010 0x0d090501,
1011 0x0c080400,
1012 0x0f0b0703,
1013 0x0e0a0602,
1014 0x0d090501,
1015 0x0c080400,
1016 0x0f0b0703,
1017 0x0e0a0602,
1018 0x0d090501,
1019 0x0c080400);
1020
1021 d[0] = _mm512_shuffle_epi8(r[0], index);
1022 d[1] = _mm512_shuffle_epi8(r[1], index);
1023 d[2] = _mm512_shuffle_epi8(r[2], index);
1024 d[3] = _mm512_shuffle_epi8(r[3], index);
1025
1026 __m512i index2 =
1027 _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
1028 r[0] = _mm512_permutexvar_epi32(index2, d[0]);
1029 r[1] = _mm512_permutexvar_epi32(index2, d[1]);
1030 r[2] = _mm512_permutexvar_epi32(index2, d[2]);
1031 r[3] = _mm512_permutexvar_epi32(index2, d[3]);
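  // At this point each 128-bit lane of r[k] holds one column (lane j holds
  // column j) of source rows 16k .. 16k+15; the i32x4 shuffles below stitch
  // the four row groups together into full 64-byte output columns.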
1032
1033 __m512i t0 = _mm512_shuffle_i32x4(r[0], r[1], 0x44);
1034 __m512i t1 = _mm512_shuffle_i32x4(r[0], r[1], 0xee);
1035 __m512i t2 = _mm512_shuffle_i32x4(r[2], r[3], 0x44);
1036 __m512i t3 = _mm512_shuffle_i32x4(r[2], r[3], 0xee);
1037
1038 d[0] = _mm512_shuffle_i32x4(t0, t2, 0x88);
1039 d[1] = _mm512_shuffle_i32x4(t0, t2, 0xdd);
1040 d[2] = _mm512_shuffle_i32x4(t1, t3, 0x88);
1041 d[3] = _mm512_shuffle_i32x4(t1, t3, 0xdd);
1042
1043 // store
1044 if (mrem < 64) {
1045 __mmask64 mask_rem_v = (((long long)1) << mrem) - 1;
1046 // mask store
1047 _mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]);
1048 _mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]);
1049 _mm512_mask_storeu_epi8(dst + 2 * ld_dst, mask_rem_v, d[2]);
1050 _mm512_mask_storeu_epi8(dst + 3 * ld_dst, mask_rem_v, d[3]);
1051 } else {
1052 // normal store
1053 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]);
1054 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]);
1055 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]);
1056 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]);
1057 }
1058}
1059
1060static inline void transpose_contiguous_32x4_block(
1061 const uint16_t* src,
1062 uint16_t* dst,
1063 int64_t ld_dst,
1064 int mrem = 32) {
1065 __m512i r[4], d[4];
1066 int i = 0;
1067 for (; (i + 1) * 32 <= mrem * 4; i++) {
1068 // normal load
1069 r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 32));
1070 }
1071 if (i * 32 < mrem * 4) {
1072 __mmask32 mask_mrem_v = (((long long)1) << (mrem * 4 - i * 32)) - 1;
1073 r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32);
1074 }
1075 // transpose
1076 __m512i index = _mm512_set_epi32(
1077 0x001f001b,
1078 0x00170013,
1079 0x000f000b,
1080 0x00070003,
1081 0x001e001a,
1082 0x00160012,
1083 0x000e000a,
1084 0x00060002,
1085 0x001d0019,
1086 0x00150011,
1087 0x000d0009,
1088 0x00050001,
1089 0x001c0018,
1090 0x00140010,
1091 0x000c0008,
1092 0x00040000);
1093
1094 d[0] = _mm512_permutexvar_epi16(index, r[0]);
1095 d[1] = _mm512_permutexvar_epi16(index, r[1]);
1096 d[2] = _mm512_permutexvar_epi16(index, r[2]);
1097 d[3] = _mm512_permutexvar_epi16(index, r[3]);
1098
1099 r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
1100 r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
1101 r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);
1102 r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);
1103
1104 d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);
1105 d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);
1106 d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);
1107 d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);
1108
1109 if (mrem < 32) {
1110 // mask store
1111 __mmask32 mask_rem_v = (((long long)1) << mrem) - 1;
1112 _mm512_mask_storeu_epi16(dst + 0 * ld_dst, mask_rem_v, d[0]);
1113 _mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, d[1]);
1114 _mm512_mask_storeu_epi16(dst + 2 * ld_dst, mask_rem_v, d[2]);
1115 _mm512_mask_storeu_epi16(dst + 3 * ld_dst, mask_rem_v, d[3]);
1116 } else {
    // normal store
1118 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]);
1119 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]);
1120 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]);
1121 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]);
1122 }
1123}
1124
1125static inline void transpose_contiguous_2x16_block(
1126 const float* src,
1127 float* dst,
1128 int64_t ld_src,
1129 int nrem = 16) {
1130 __m512i r0, r1;
1131 // load
1132 if (nrem < 16) {
1133 __mmask16 mask_mrem_v = (((long long)1) << nrem) - 1;
1134 r0 = _mm512_maskz_loadu_epi32(mask_mrem_v, src);
1135 r1 = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src);
1136 } else {
1137 r0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
1138 r1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
1139 }
1140 // transpose
1141 __m512i index1 = _mm512_set_epi32(
1142 0x0017,
1143 0x0007,
1144 0x0016,
1145 0x0006,
1146 0x0015,
1147 0x0005,
1148 0x0014,
1149 0x0004,
1150 0x0013,
1151 0x0003,
1152 0x0012,
1153 0x0002,
1154 0x0011,
1155 0x0001,
1156 0x0010,
1157 0x0000);
1158 __m512i index2 = _mm512_set_epi32(
1159 0x001f,
1160 0x000f,
1161 0x001e,
1162 0x000e,
1163 0x001d,
1164 0x000d,
1165 0x001c,
1166 0x000c,
1167 0x001b,
1168 0x000b,
1169 0x001a,
1170 0x000a,
1171 0x0019,
1172 0x0009,
1173 0x0018,
1174 0x0008);
1175 // a0 b0 a1 b1 a2 b2 a3 b3 a4 b4 a5 b5 a6 b6 a7 b7
1176 // a8 b8 a9 b9 a10 b10 a11 b11 a12 b12 a13 b13 a14 b14 a15 b15
1177 __m512i u0 = _mm512_permutex2var_epi32(r0, index1, r1);
1178 __m512i u1 = _mm512_permutex2var_epi32(r0, index2, r1);
1179 // store
1180 if (nrem < 16) {
1181 // mask store
1182 if (nrem < 8) {
1183 __mmask16 mask_rem_v = (((long long)1) << (nrem * 2)) - 1;
1184 _mm512_mask_storeu_epi32(dst, mask_rem_v, u0);
1185 } else {
1186 __mmask16 mask_rem_v = (((long long)1) << ((nrem - 8) * 2)) - 1;
1187 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0);
1188 _mm512_mask_storeu_epi32(dst + 16, mask_rem_v, u1);
1189 }
1190 } else {
1191 // normal store
1192 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0);
1193 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 16), u1);
1194 }
1195}
1196
1197static inline void transpose_contiguous_64x2_block(
1198 const uint8_t* src,
1199 uint8_t* dst,
1200 int64_t ld_dst,
1201 int mrem = 64) {
1202 __m512i r[2], d[2];
1203 // normal load
1204 int i = 0;
1205 for (; (i + 1) * 64 <= mrem * 2; i++) {
1206 r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 64));
1207 }
1208 int erem = mrem * 2 - i * 64;
1209 if (erem > 0) {
1210 __mmask64 mask_mrem_v = (((long long)1) << erem) - 1;
1211 r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64);
1212 }
1213
1214 // transpose
1215 __m512i index1 = _mm512_set_epi32(
1216 0x0f0d0b09,
1217 0x07050301,
1218 0x0e0c0a08,
1219 0x06040200,
1220 0x0f0d0b09,
1221 0x07050301,
1222 0x0e0c0a08,
1223 0x06040200,
1224 0x0f0d0b09,
1225 0x07050301,
1226 0x0e0c0a08,
1227 0x06040200,
1228 0x0f0d0b09,
1229 0x07050301,
1230 0x0e0c0a08,
1231 0x06040200);
1232 r[0] = _mm512_shuffle_epi8(r[0], index1);
1233 r[1] = _mm512_shuffle_epi8(r[1], index1);
1234
1235 __m512i index2 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0);
1236 __m512i index3 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1);
1237
1238 d[0] = _mm512_permutex2var_epi64(r[0], index2, r[1]);
1239 d[1] = _mm512_permutex2var_epi64(r[0], index3, r[1]);
1240
1241 // store
1242 if (mrem < 64) {
1243 __mmask64 mask_rem_v = (((long long)1) << mrem) - 1;
1244 // mask store
1245 _mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]);
1246 _mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]);
1247 } else {
1248 // normal store
1249 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d[0]);
1250 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + ld_dst), d[1]);
1251 }
1252}
1253
1254static inline void transpose_contiguous_4x64_block(
1255 const uint8_t* src,
1256 uint8_t* dst,
1257 int64_t ld_src,
1258 int nrem = 64) {
1259 __m512i r[4], d[4];
1260 // load
1261 if (nrem < 64) {
1262 __mmask64 mask_mrem_v = (((long long)1) << nrem) - 1;
1263 r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src);
1264 r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src);
1265 r[2] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 2 * ld_src);
1266 r[3] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 3 * ld_src);
1267 } else {
    r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
    r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
    r[2] =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src));
    r[3] =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src));
1274 }
1275 // transpose
1276 d[0] = _mm512_unpacklo_epi32(r[0], r[1]);
1277 d[1] = _mm512_unpackhi_epi32(r[0], r[1]);
1278 d[2] = _mm512_unpacklo_epi32(r[2], r[3]);
1279 d[3] = _mm512_unpackhi_epi32(r[2], r[3]);
1280
1281 r[0] = _mm512_unpacklo_epi64(d[0], d[2]);
1282 r[1] = _mm512_unpackhi_epi64(d[0], d[2]);
1283 r[2] = _mm512_unpacklo_epi64(d[1], d[3]);
1284 r[3] = _mm512_unpackhi_epi64(d[1], d[3]);
1285
1286 d[0] = _mm512_shuffle_i32x4(r[0], r[1], 0x44);
1287 d[1] = _mm512_shuffle_i32x4(r[0], r[1], 0xee);
1288 d[2] = _mm512_shuffle_i32x4(r[2], r[3], 0x44);
1289 d[3] = _mm512_shuffle_i32x4(r[2], r[3], 0xee);
1290
1291 r[0] = _mm512_shuffle_i32x4(d[0], d[2], 0x88);
1292 r[1] = _mm512_shuffle_i32x4(d[0], d[2], 0xdd);
1293 r[2] = _mm512_shuffle_i32x4(d[1], d[3], 0x88);
1294 r[3] = _mm512_shuffle_i32x4(d[1], d[3], 0xdd);
1295
1296 __m512i index = _mm512_set_epi32(
1297 0x0f0b0703,
1298 0x0e0a0602,
1299 0x0d090501,
1300 0x0c080400,
1301 0x0f0b0703,
1302 0x0e0a0602,
1303 0x0d090501,
1304 0x0c080400,
1305 0x0f0b0703,
1306 0x0e0a0602,
1307 0x0d090501,
1308 0x0c080400,
1309 0x0f0b0703,
1310 0x0e0a0602,
1311 0x0d090501,
1312 0x0c080400);
1313
1314 d[0] = _mm512_shuffle_epi8(r[0], index);
1315 d[1] = _mm512_shuffle_epi8(r[1], index);
1316 d[2] = _mm512_shuffle_epi8(r[2], index);
1317 d[3] = _mm512_shuffle_epi8(r[3], index);
1318
1319 // store
1320 int i = 0;
1321 for (; (i + 1) * 64 <= nrem * 4; i++) {
1322 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 64), d[i]);
1323 }
1324 int erem = nrem * 4 - i * 64;
1325 if (erem > 0) {
1326 __mmask64 mask_rem_v = (((long long)1) << erem) - 1;
1327 _mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]);
1328 }
1329}
1330
1331static inline void transpose_contiguous_2x64_block(
1332 const uint8_t* src,
1333 uint8_t* dst,
1334 int64_t ld_src,
1335 int nrem = 64) {
1336 __m512i r[2];
1337 __m512i d[2];
1338 // load
1339 if (nrem < 64) {
1340 __mmask64 mask_mrem_v = (((long long)1) << nrem) - 1;
1341 r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src);
1342 r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src);
1343 } else {
1344 r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
1345 r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
1346 }
1347 // transpose
1348 // _mm512_mask_blend_epi8(0xaaaaaaaaaaaaaaaa, r0, r1);
1349 d[0] = _mm512_unpacklo_epi16(r[0], r[1]);
1350 d[1] = _mm512_unpackhi_epi16(r[0], r[1]);
1351 __m512i index1 = _mm512_set_epi32(
1352 0x0f0d0e0c,
1353 0x0b090a08,
1354 0x07050604,
1355 0x03010200,
1356 0x0f0d0e0c,
1357 0x0b090a08,
1358 0x07050604,
1359 0x03010200,
1360 0x0f0d0e0c,
1361 0x0b090a08,
1362 0x07050604,
1363 0x03010200,
1364 0x0f0d0e0c,
1365 0x0b090a08,
1366 0x07050604,
1367 0x03010200);
1368 r[0] = _mm512_shuffle_epi8(d[0], index1);
1369 r[1] = _mm512_shuffle_epi8(d[1], index1);
1370 __m512i index2 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
1371 __m512i index3 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
1372 // a0b0 a1b1 ... a31b31
1373 // a32b32 ... a63b63
1374 d[0] = _mm512_permutex2var_epi64(r[0], index2, r[1]);
1375 d[1] = _mm512_permutex2var_epi64(r[0], index3, r[1]);
1376
1377 int i = 0;
1378 for (; (i + 1) * 64 <= nrem * 2; i++) {
1379 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 64), d[i]);
1380 }
1381 int erem = nrem * 2 - i * 64;
1382 if (erem > 0) {
1383 __mmask64 mask_rem_v = (((long long)1) << erem) - 1;
1384 _mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]);
1385 }
1386}
1387
1388static inline void transpose_contiguous_2x32_block(
1389 const uint16_t* src,
1390 uint16_t* dst,
1391 int64_t ld_src,
1392 int nrem = 32) {
1393 __m512i r0, r1;
1394 __m512i d0, d1;
1395 // load
1396 if (nrem < 32) {
1397 __mmask32 mask_mrem_v = (((long long)1) << nrem) - 1;
1398 r0 = _mm512_maskz_loadu_epi16(mask_mrem_v, src);
1399 r1 = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src);
1400 } else {
1401 r0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
1402 r1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
1403 }
1404 // transpose
1405 d0 = _mm512_unpacklo_epi16(r0, r1);
1406 d1 = _mm512_unpackhi_epi16(r0, r1);
1407 r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
1408 r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
1409 d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
1410 d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
1411
  // store
  if (nrem < 16) {
    __mmask32 mask_rem_v = (((long long)1) << (nrem * 2)) - 1;
    _mm512_mask_storeu_epi16(dst, mask_rem_v, d0);
  } else if (nrem == 16) {
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
  } else if (nrem < 32) {
    // store the first 32 elements in full, then mask-store the tail from d1
    __mmask32 mask_rem_v = (((long long)1) << (nrem * 2 - 32)) - 1;
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
    _mm512_mask_storeu_epi16(dst + 32, mask_rem_v, d1);
  } else {
    // normal store
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1);
  }
1429}
1430
1431static inline void transpose_contiguous_32x2_block(
1432 const uint16_t* src,
1433 uint16_t* dst,
1434 int64_t ld_dst,
1435 int mrem = 32) {
1436 __m512i r[2], d[2];
1437 // load
1438 int i = 0;
1439 for (; (i + 1) * 32 <= mrem * 2; i++) {
1440 r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 32));
1441 }
1442 int erem = mrem * 2 - i * 32;
1443 if (erem > 0) {
1444 __mmask32 mask_mrem_v = (((long long)1) << erem) - 1;
1445 r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32);
1446 }
1447 // transpose
1448 __m512i index = _mm512_set_epi32(
1449 0x001f001d,
1450 0x001b0019,
1451 0x00170015,
1452 0x00130011,
1453 0x000f000d,
1454 0x000b0009,
1455 0x00070005,
1456 0x00030001,
1457 0x001e001c,
1458 0x001a0018,
1459 0x00160014,
1460 0x00120010,
1461 0x000e000c,
1462 0x000a0008,
1463 0x00060004,
1464 0x00020000);
1465 d[0] = _mm512_permutexvar_epi16(index, r[0]);
1466 d[1] = _mm512_permutexvar_epi16(index, r[1]);
1467 r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
1468 r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
1469
1470 // store
1471 if (mrem < 32) {
1472 __mmask32 mask_rem_v = (((long long)1) << mrem) - 1;
1473 // mask store
1474 _mm512_mask_storeu_epi16(dst, mask_rem_v, r[0]);
1475 _mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, r[1]);
1476 } else {
1477 // normal store
1478 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r[0]);
1479 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + ld_dst), r[1]);
1480 }
1481}
1482
1483template <bool MREM = false, bool NREM = false>
1484void transpose_16x16_block(
1485 const uint16_t* src,
1486 int64_t ld_src,
1487 uint16_t* dst,
1488 int64_t ld_dst,
1489 int mrem = 16,
1490 int nrem = 16) {
1491 __m512i r[8];
1492 if (MREM || NREM) {
1493 load_with_remainders_i16(src, ld_src, r, mrem, nrem);
1494 } else {
1495 __m256i t00 =
1496 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 0 * ld_src));
1497 __m256i t01 =
1498 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 1 * ld_src));
1499 __m256i t02 =
1500 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 2 * ld_src));
1501 __m256i t03 =
1502 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 3 * ld_src));
1503 __m256i t04 =
1504 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 4 * ld_src));
1505 __m256i t05 =
1506 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 5 * ld_src));
1507 __m256i t06 =
1508 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 6 * ld_src));
1509 __m256i t07 =
1510 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 7 * ld_src));
1511 __m256i t08 =
1512 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 8 * ld_src));
1513 __m256i t09 =
1514 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 9 * ld_src));
1515 __m256i t10 =
1516 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 10 * ld_src));
1517 __m256i t11 =
1518 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 11 * ld_src));
1519 __m256i t12 =
1520 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 12 * ld_src));
1521 __m256i t13 =
1522 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 13 * ld_src));
1523 __m256i t14 =
1524 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 14 * ld_src));
1525 __m256i t15 =
1526 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 15 * ld_src));
1527
1528 // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15
1529 // e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15
1530 r[0] = _mm512_inserti64x4(_mm512_castsi256_si512(t00), t04, 0x01);
1531 // b0-b15
1532 // f0-f15
1533 r[1] = _mm512_inserti64x4(_mm512_castsi256_si512(t01), t05, 0x01);
1534 // c0-c15
1535 // g0-g15
1536 r[2] = _mm512_inserti64x4(_mm512_castsi256_si512(t02), t06, 0x01);
1537 // d0-d15
    // h0-h15
1539 r[3] = _mm512_inserti64x4(_mm512_castsi256_si512(t03), t07, 0x01);
1540 // i0-i15
1541 // m0-m15
1542 r[4] = _mm512_inserti64x4(_mm512_castsi256_si512(t08), t12, 0x01);
1543 // j0-j15
1544 // n0-n15
1545 r[5] = _mm512_inserti64x4(_mm512_castsi256_si512(t09), t13, 0x01);
1546 // k0-k15
1547 // o0-o15
1548 r[6] = _mm512_inserti64x4(_mm512_castsi256_si512(t10), t14, 0x01);
1549 // l0-l15
1550 // p0-p15
1551 r[7] = _mm512_inserti64x4(_mm512_castsi256_si512(t11), t15, 0x01);
1552 }
1553 __m512i u[8];
1554 core_transpose_16x16_block(r, u);
1555 if (MREM || NREM) {
1556 store_with_remainders_i16(dst, ld_dst, u, mrem, nrem);
1557 } else {
1558 _mm256_storeu_si256(
1559 reinterpret_cast<__m256i*>(dst + 0 * ld_dst),
1560 _mm512_extracti32x8_epi32(u[0], 0x0));
1561 _mm256_storeu_si256(
1562 reinterpret_cast<__m256i*>(dst + 1 * ld_dst),
1563 _mm512_extracti32x8_epi32(u[0], 0x01));
1564 _mm256_storeu_si256(
1565 reinterpret_cast<__m256i*>(dst + 2 * ld_dst),
1566 _mm512_extracti32x8_epi32(u[1], 0x0));
1567 _mm256_storeu_si256(
1568 reinterpret_cast<__m256i*>(dst + 3 * ld_dst),
1569 _mm512_extracti32x8_epi32(u[1], 0x01));
1570 _mm256_storeu_si256(
1571 reinterpret_cast<__m256i*>(dst + 4 * ld_dst),
1572 _mm512_extracti32x8_epi32(u[2], 0x0));
1573 _mm256_storeu_si256(
1574 reinterpret_cast<__m256i*>(dst + 5 * ld_dst),
1575 _mm512_extracti32x8_epi32(u[2], 0x01));
1576 _mm256_storeu_si256(
1577 reinterpret_cast<__m256i*>(dst + 6 * ld_dst),
1578 _mm512_extracti32x8_epi32(u[3], 0x0));
1579 _mm256_storeu_si256(
1580 reinterpret_cast<__m256i*>(dst + 7 * ld_dst),
1581 _mm512_extracti32x8_epi32(u[3], 0x01));
1582 _mm256_storeu_si256(
1583 reinterpret_cast<__m256i*>(dst + 8 * ld_dst),
1584 _mm512_extracti32x8_epi32(u[4], 0x0));
1585 _mm256_storeu_si256(
1586 reinterpret_cast<__m256i*>(dst + 9 * ld_dst),
1587 _mm512_extracti32x8_epi32(u[4], 0x01));
1588 _mm256_storeu_si256(
1589 reinterpret_cast<__m256i*>(dst + 10 * ld_dst),
1590 _mm512_extracti32x8_epi32(u[5], 0x0));
1591 _mm256_storeu_si256(
1592 reinterpret_cast<__m256i*>(dst + 11 * ld_dst),
1593 _mm512_extracti32x8_epi32(u[5], 0x01));
1594 _mm256_storeu_si256(
1595 reinterpret_cast<__m256i*>(dst + 12 * ld_dst),
1596 _mm512_extracti32x8_epi32(u[6], 0x0));
1597 _mm256_storeu_si256(
1598 reinterpret_cast<__m256i*>(dst + 13 * ld_dst),
1599 _mm512_extracti32x8_epi32(u[6], 0x01));
1600 _mm256_storeu_si256(
1601 reinterpret_cast<__m256i*>(dst + 14 * ld_dst),
1602 _mm512_extracti32x8_epi32(u[7], 0x0));
1603 _mm256_storeu_si256(
1604 reinterpret_cast<__m256i*>(dst + 15 * ld_dst),
1605 _mm512_extracti32x8_epi32(u[7], 0x01));
1606 }
1607}
1608
1609template <bool MREM = false, bool NREM = false>
1610void transpose_16x32_block(
1611 const uint8_t* src,
1612 int64_t ld_src,
1613 uint8_t* dst,
1614 int64_t ld_dst,
1615 int mrem = 16,
1616 int nrem = 32) {
  // Treat the numbers in a row as 4-byte integers.
  // Thus 03_04 is the 4-byte element in row 03 and column 04.
1619 //
1620 // 00_00 00_01 00_02 00_03 00_04 00_05 00_06 00_07
1621 // 01_00 01_01 01_02 01_03 01_04 01_05 01_06 01_07
1622 // 02_00 02_01 02_02 02_03 02_04 02_05 02_06 02_07
1623 // 03_00 03_01 03_02 03_03 03_04 03_05 03_06 03_07
1624 // 04_00 04_01 04_02 04_03 04_04 04_05 04_06 04_07
1625 // 05_00 05_01 05_02 05_03 05_04 05_05 05_06 05_07
1626 // 06_00 06_01 06_02 06_03 06_04 06_05 06_06 06_07
1627 // 07_00 07_01 07_02 07_03 07_04 07_05 07_06 07_07
1628 // 08_00 08_01 08_02 08_03 08_04 08_05 08_06 08_07
1629 // 09_00 09_01 09_02 09_03 09_04 09_05 09_06 09_07
1630 // 10_00 10_01 10_02 10_03 10_04 10_05 10_06 10_07
1631 // 11_00 11_01 11_02 11_03 11_04 11_05 11_06 11_07
1632 // 12_00 12_01 12_02 12_03 12_04 12_05 12_06 12_07
1633 // 13_00 13_01 13_02 13_03 13_04 13_05 13_06 13_07
1634 // 14_00 14_01 14_02 14_03 14_04 14_05 14_06 14_07
1635 // 15_00 15_01 15_02 15_03 15_04 15_05 15_06 15_07
1636
1637 __m512i r[8];
1638 if (MREM || NREM) {
1639 load_with_remainders_i8(src, ld_src, r, mrem, nrem);
1640 } else {
1641 __m256i t00 =
1642 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 0 * ld_src));
1643 __m256i t04 =
1644 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 4 * ld_src));
1645 // 00_00 00_01 00_02 00_03 00_04 00_05 00_06 00_07 04_00 04_01 04_02 04_03
1646 // 04_04 04_05 04_06 04_07
1647 r[0] = _mm512_inserti64x4(_mm512_castsi256_si512(t00), t04, 0x01);
1648
1649 __m256i t01 =
1650 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 1 * ld_src));
1651 __m256i t05 =
1652 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 5 * ld_src));
1653 // 01_00 01_01 01_02 01_03 01_04 01_05 01_06 01_07 05_00 05_01 05_02 05_03
1654 // 05_04 05_05 05_06 05_07
1655 r[1] = _mm512_inserti64x4(_mm512_castsi256_si512(t01), t05, 0x01);
1656
1657 __m256i t02 =
1658 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 2 * ld_src));
1659 __m256i t06 =
1660 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 6 * ld_src));
1661 // 02_00 02_01 02_02 02_03 02_04 02_05 02_06 02_07 06_00 06_01 06_02 06_03
1662 // 06_04 06_05 06_06 06_07
1663 r[2] = _mm512_inserti64x4(_mm512_castsi256_si512(t02), t06, 0x01);
1664
1665 __m256i t03 =
1666 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 3 * ld_src));
1667 __m256i t07 =
1668 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 7 * ld_src));
1669 // 03_00 03_01 03_02 03_03 03_04 03_05 03_06 03_07 07_00 07_01 07_02 07_03
1670 // 07_04 07_05 07_06 07_07
1671 r[3] = _mm512_inserti64x4(_mm512_castsi256_si512(t03), t07, 0x01);
1672
1673 __m256i t08 =
1674 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 8 * ld_src));
1675 __m256i t12 =
1676 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 12 * ld_src));
1677 // 08_00 08_01 08_02 08_03 08_04 08_05 08_06 08_07 12_00 12_01 12_02 12_03
1678 // 12_04 12_05 12_06 12_07
1679 r[4] = _mm512_inserti64x4(_mm512_castsi256_si512(t08), t12, 0x01);
1680
1681 __m256i t09 =
1682 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 9 * ld_src));
1683 __m256i t13 =
1684 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 13 * ld_src));
1685 // 09_00 09_01 09_02 09_03 09_04 09_05 09_06 09_07 13_00 13_01 13_02 13_03
1686 // 13_04 13_05 13_06 13_07
1687 r[5] = _mm512_inserti64x4(_mm512_castsi256_si512(t09), t13, 0x01);
1688
1689 __m256i t10 =
1690 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 10 * ld_src));
1691 __m256i t14 =
1692 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 14 * ld_src));
1693 // 10_00 10_01 10_02 10_03 10_04 10_05 10_06 10_07 14_00 14_01 14_02 14_03
1694 // 14_04 14_05 14_06 14_07
1695 r[6] = _mm512_inserti64x4(_mm512_castsi256_si512(t10), t14, 0x01);
1696
1697 __m256i t11 =
1698 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 11 * ld_src));
1699 __m256i t15 =
1700 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 15 * ld_src));
1701 // 11_00 11_01 11_02 11_03 11_04 11_05 11_06 11_07 15_00 15_01 15_02 15_03
1702 // 15_04 15_05 15_06 15_07
1703 r[7] = _mm512_inserti64x4(_mm512_castsi256_si512(t11), t15, 0x01);
1704 }
1705
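  // The core transpose rearranges the eight loaded registers r into u so that
  // each u[k] holds four 16-byte output rows; the full-block path below then
  // writes each 128-bit lane to its own destination row (u[0] -> dst rows
  // 0..3, u[1] -> rows 16..19, u[2] -> rows 4..7, and so on).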
1706 __m512i u[8];
1707 core_transpose_16x32_block_i8(r, u);
1708
1709 if (MREM || NREM) {
1710 store_with_remainders_i8(dst, ld_dst, u, mrem, nrem);
1711 } else {
1712 _mm_storeu_si128(
1713 reinterpret_cast<__m128i*>(dst + 0 * ld_dst),
1714 _mm512_extracti32x4_epi32(u[0], 0x0));
1715 _mm_storeu_si128(
1716 reinterpret_cast<__m128i*>(dst + 1 * ld_dst),
1717 _mm512_extracti32x4_epi32(u[0], 0x1));
1718 _mm_storeu_si128(
1719 reinterpret_cast<__m128i*>(dst + 2 * ld_dst),
1720 _mm512_extracti32x4_epi32(u[0], 0x2));
1721 _mm_storeu_si128(
1722 reinterpret_cast<__m128i*>(dst + 3 * ld_dst),
1723 _mm512_extracti32x4_epi32(u[0], 0x3));
1724 _mm_storeu_si128(
1725 reinterpret_cast<__m128i*>(dst + 16 * ld_dst),
1726 _mm512_extracti32x4_epi32(u[1], 0x0));
1727 _mm_storeu_si128(
1728 reinterpret_cast<__m128i*>(dst + 17 * ld_dst),
1729 _mm512_extracti32x4_epi32(u[1], 0x1));
1730 _mm_storeu_si128(
1731 reinterpret_cast<__m128i*>(dst + 18 * ld_dst),
1732 _mm512_extracti32x4_epi32(u[1], 0x2));
1733 _mm_storeu_si128(
1734 reinterpret_cast<__m128i*>(dst + 19 * ld_dst),
1735 _mm512_extracti32x4_epi32(u[1], 0x3));
1736 _mm_storeu_si128(
1737 reinterpret_cast<__m128i*>(dst + 4 * ld_dst),
1738 _mm512_extracti32x4_epi32(u[2], 0x0));
1739 _mm_storeu_si128(
1740 reinterpret_cast<__m128i*>(dst + 5 * ld_dst),
1741 _mm512_extracti32x4_epi32(u[2], 0x1));
1742 _mm_storeu_si128(
1743 reinterpret_cast<__m128i*>(dst + 6 * ld_dst),
1744 _mm512_extracti32x4_epi32(u[2], 0x2));
1745 _mm_storeu_si128(
1746 reinterpret_cast<__m128i*>(dst + 7 * ld_dst),
1747 _mm512_extracti32x4_epi32(u[2], 0x3));
1748 _mm_storeu_si128(
1749 reinterpret_cast<__m128i*>(dst + 20 * ld_dst),
1750 _mm512_extracti32x4_epi32(u[3], 0x0));
1751 _mm_storeu_si128(
1752 reinterpret_cast<__m128i*>(dst + 21 * ld_dst),
1753 _mm512_extracti32x4_epi32(u[3], 0x1));
1754 _mm_storeu_si128(
1755 reinterpret_cast<__m128i*>(dst + 22 * ld_dst),
1756 _mm512_extracti32x4_epi32(u[3], 0x2));
1757 _mm_storeu_si128(
1758 reinterpret_cast<__m128i*>(dst + 23 * ld_dst),
1759 _mm512_extracti32x4_epi32(u[3], 0x3));
1760 _mm_storeu_si128(
1761 reinterpret_cast<__m128i*>(dst + 8 * ld_dst),
1762 _mm512_extracti32x4_epi32(u[4], 0x0));
1763 _mm_storeu_si128(
1764 reinterpret_cast<__m128i*>(dst + 9 * ld_dst),
1765 _mm512_extracti32x4_epi32(u[4], 0x1));
1766 _mm_storeu_si128(
1767 reinterpret_cast<__m128i*>(dst + 10 * ld_dst),
1768 _mm512_extracti32x4_epi32(u[4], 0x2));
1769 _mm_storeu_si128(
1770 reinterpret_cast<__m128i*>(dst + 11 * ld_dst),
1771 _mm512_extracti32x4_epi32(u[4], 0x3));
1772 _mm_storeu_si128(
1773 reinterpret_cast<__m128i*>(dst + 24 * ld_dst),
1774 _mm512_extracti32x4_epi32(u[5], 0x0));
1775 _mm_storeu_si128(
1776 reinterpret_cast<__m128i*>(dst + 25 * ld_dst),
1777 _mm512_extracti32x4_epi32(u[5], 0x1));
1778 _mm_storeu_si128(
1779 reinterpret_cast<__m128i*>(dst + 26 * ld_dst),
1780 _mm512_extracti32x4_epi32(u[5], 0x2));
1781 _mm_storeu_si128(
1782 reinterpret_cast<__m128i*>(dst + 27 * ld_dst),
1783 _mm512_extracti32x4_epi32(u[5], 0x3));
1784 _mm_storeu_si128(
1785 reinterpret_cast<__m128i*>(dst + 12 * ld_dst),
1786 _mm512_extracti32x4_epi32(u[6], 0x0));
1787 _mm_storeu_si128(
1788 reinterpret_cast<__m128i*>(dst + 13 * ld_dst),
1789 _mm512_extracti32x4_epi32(u[6], 0x1));
1790 _mm_storeu_si128(
1791 reinterpret_cast<__m128i*>(dst + 14 * ld_dst),
1792 _mm512_extracti32x4_epi32(u[6], 0x2));
1793 _mm_storeu_si128(
1794 reinterpret_cast<__m128i*>(dst + 15 * ld_dst),
1795 _mm512_extracti32x4_epi32(u[6], 0x3));
1796 _mm_storeu_si128(
1797 reinterpret_cast<__m128i*>(dst + 28 * ld_dst),
1798 _mm512_extracti32x4_epi32(u[7], 0x0));
1799 _mm_storeu_si128(
1800 reinterpret_cast<__m128i*>(dst + 29 * ld_dst),
1801 _mm512_extracti32x4_epi32(u[7], 0x1));
1802 _mm_storeu_si128(
1803 reinterpret_cast<__m128i*>(dst + 30 * ld_dst),
1804 _mm512_extracti32x4_epi32(u[7], 0x2));
1805 _mm_storeu_si128(
1806 reinterpret_cast<__m128i*>(dst + 31 * ld_dst),
1807 _mm512_extracti32x4_epi32(u[7], 0x3));
1808 }
1809}
1810
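// N == 2 / N == 4 columns with N == ld_src ("contiguous thin" input): the
// source is one contiguous buffer, so rows are transposed 16 at a time for
// float (32 for uint16_t, 64 for uint8_t below), with a single remainder call
// for the last partial group of rows. With N == 2, for example, the
// interleaved input x0 y0 x1 y1 ... becomes a dst row of all x values
// followed by a dst row of all y values.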
1811template <>
1812void transpose_avx512_contiguous_thin(
1813 int64_t M,
1814 int64_t N,
1815 const float* src,
1816 int64_t ld_src,
1817 float* dst,
1818 int64_t ld_dst) {
1819 if (N == 2) {
1820 int64_t i = 0;
1821 for (; i < M / 16 * 16; i += 16) {
1822 transpose_contiguous_16x2_block(src + i * ld_src, dst + i, ld_dst);
1823 }
1824 int mrem = M - i;
1825 if (mrem > 0) {
1826 transpose_contiguous_16x2_block(src + i * ld_src, dst + i, ld_dst, mrem);
1827 }
1828 } else if (N == 4) {
1829 int64_t i = 0;
1830 for (; i < M / 16 * 16; i += 16) {
1831 transpose_contiguous_16x4_block(src + i * ld_src, dst + i, ld_dst);
1832 }
1833 int mrem = M - i;
1834 if (mrem > 0) {
1835 transpose_contiguous_16x4_block(src + i * ld_src, dst + i, ld_dst, mrem);
1836 }
1837 }
1838}
1839
1840template <>
1841void transpose_avx512_contiguous_thin(
1842 int64_t M,
1843 int64_t N,
1844 const uint16_t* src,
1845 int64_t ld_src,
1846 uint16_t* dst,
1847 int64_t ld_dst) {
1848 if (N == 2) {
1849 int64_t i = 0;
1850 for (; i < M / 32 * 32; i += 32) {
1851 transpose_contiguous_32x2_block(src + i * ld_src, dst + i, ld_dst);
1852 }
1853 int mrem = M - i;
1854 if (mrem > 0) {
1855 transpose_contiguous_32x2_block(src + i * ld_src, dst + i, ld_dst, mrem);
1856 }
1857 } else if (N == 4) {
1858 int64_t i = 0;
1859 for (; i < M / 32 * 32; i += 32) {
1860 transpose_contiguous_32x4_block(src + i * ld_src, dst + i, ld_dst);
1861 }
1862 int mrem = M - i;
1863 if (mrem > 0) {
1864 transpose_contiguous_32x4_block(src + i * ld_src, dst + i, ld_dst, mrem);
1865 }
1866 }
1867}
1868
1869template <>
1870void transpose_avx512_contiguous_thin(
1871 int64_t M,
1872 int64_t N,
1873 const uint8_t* src,
1874 int64_t ld_src,
1875 uint8_t* dst,
1876 int64_t ld_dst) {
1877 if (N == 2) {
1878 int64_t i = 0;
1879 for (; i < M / 64 * 64; i += 64) {
1880 transpose_contiguous_64x2_block(src + i * ld_src, dst + i, ld_dst);
1881 }
1882 int mrem = M - i;
1883 if (mrem > 0) {
1884 transpose_contiguous_64x2_block(src + i * ld_src, dst + i, ld_dst, mrem);
1885 }
1886 } else if (N == 4) {
1887 int64_t i = 0;
1888 for (; i < M / 64 * 64; i += 64) {
1889 transpose_contiguous_64x4_block(src + i * ld_src, dst + i, ld_dst);
1890 }
1891 int mrem = M - i;
1892 if (mrem > 0) {
1893 transpose_contiguous_64x4_block(src + i * ld_src, dst + i, ld_dst, mrem);
1894 }
1895 }
1896}
1897
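// M == 2 / M == 4 rows with M == ld_dst ("contiguous wide" input): the
// transposed output (N rows of M elements) is one contiguous buffer, so the
// source columns are processed 16 at a time for float (32 for uint16_t, 64
// for uint8_t below), with a single remainder call for the last partial group
// of columns.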
1898template <>
1899void transpose_avx512_contiguous_wide(
1900 int64_t M,
1901 int64_t N,
1902 const float* src,
1903 int64_t ld_src,
1904 float* dst,
1905 int64_t ld_dst) {
1906 if (M == 2) {
1907 int64_t i = 0;
1908 for (; i < N / 16 * 16; i += 16) {
1909 transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src);
1910 }
1911 int nrem = N - i;
1912 if (nrem > 0) {
1913 transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src, nrem);
1914 }
1915 } else if (M == 4) {
1916 int64_t i = 0;
1917 for (; i < N / 16 * 16; i += 16) {
1918 transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src);
1919 }
1920 int nrem = N - i;
1921 if (nrem > 0) {
1922 transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src, nrem);
1923 }
1924 }
1925}
1926
1927template <>
1928void transpose_avx512_contiguous_wide(
1929 int64_t M,
1930 int64_t N,
1931 const uint16_t* src,
1932 int64_t ld_src,
1933 uint16_t* dst,
1934 int64_t ld_dst) {
1935 if (M == 2) {
1936 int64_t i = 0;
1937 for (; i < N / 32 * 32; i += 32) {
1938 transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src);
1939 }
1940 int nrem = N - i;
1941 if (nrem > 0) {
1942 transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src, nrem);
1943 }
1944 } else if (M == 4) {
1945 int64_t i = 0;
1946 for (; i < N / 32 * 32; i += 32) {
1947 transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src);
1948 }
1949 int nrem = N - i;
1950 if (nrem > 0) {
1951 transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src, nrem);
1952 }
1953 }
1954}
1955
1956template <>
1957void transpose_avx512_contiguous_wide(
1958 int64_t M,
1959 int64_t N,
1960 const uint8_t* src,
1961 int64_t ld_src,
1962 uint8_t* dst,
1963 int64_t ld_dst) {
1964 if (M == 2) {
1965 int64_t i = 0;
1966 for (; i < N / 64 * 64; i += 64) {
1967 transpose_contiguous_2x64_block(src + i, dst + i * ld_dst, ld_src);
1968 }
1969 int nrem = N - i;
1970 if (nrem > 0) {
1971 transpose_contiguous_2x64_block(src + i, dst + i * ld_dst, ld_src, nrem);
1972 }
1973 } else if (M == 4) {
1974 int64_t i = 0;
1975 for (; i < N / 64 * 64; i += 64) {
1976 transpose_contiguous_4x64_block(src + i, dst + i * ld_dst, ld_src);
1977 }
1978 int nrem = N - i;
1979 if (nrem > 0) {
1980 transpose_contiguous_4x64_block(src + i, dst + i * ld_dst, ld_src, nrem);
1981 }
1982 }
1983}
1984
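// Out-of-place transpose of an M x N row-major float matrix src (row stride
// ld_src) into the N x M matrix dst (row stride ld_dst), i.e.
// dst[j * ld_dst + i] = src[i * ld_src + j]. Full 16x16 tiles use the AVX512
// kernel; edge tiles pick SSE / AVX2 / masked AVX512 kernels based on the
// instruction-count heuristics commented below.
//
// Illustrative usage (a sketch only; the buffers and sizes are made up and
// this snippet is not part of this file):
//
//   std::vector<float> a(M * N), at(N * M);
//   transpose_avx512(M, N, a.data(), /*ld_src=*/N, at.data(), /*ld_dst=*/M);
//   // Now at[j * M + i] == a[i * N + j] for 0 <= i < M, 0 <= j < N.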
1985template <>
1986void transpose_avx512(
1987 int64_t M,
1988 int64_t N,
1989 const float* src,
1990 int64_t ld_src,
1991 float* dst,
1992 int64_t ld_dst) {
1993 if (M == ld_dst && (M == 2 || M == 4)) {
1994 transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
1995 } else if (N == ld_src && (N == 2 || N == 4)) {
1996 transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
1997 } else {
1998 int64_t ib = 0, jb = 0;
1999 if (N % 16 > 0 && N % 16 < 4) {
      // If the remainder has n = N % 16 < 4 columns, we use the SSE kernel
      // for the remainder because it requires 4 * (2 * 4 + 2 * n) = 32 + 8n
      // instructions instead of the 4 * 16 + 2 * n = 64 + 2n instructions
      // needed in the masked AVX512 kernel.
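      // For example, with n = 3 the SSE path costs 32 + 8 * 3 = 56
      // instructions versus 64 + 2 * 3 = 70 for the masked AVX512 kernel.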
2004 for (ib = 0; ib + 16 <= M; ib += 16) {
2005 for (jb = 0; jb + 16 <= N; jb += 16) {
2006 transpose_kernel_16x16_avx512(
2007 &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
2008 }
2009 for (int64_t i = ib; i < ib + 16; i += 4) {
2010 transpose_kernel_mxn_sse<4>(
2011 N - jb,
2012 &src[i * ld_src + jb],
2013 ld_src,
2014 &dst[i + jb * ld_dst],
2015 ld_dst);
2016 }
2017 }
2018 } else if (N % 16 == 4) {
2019 // If the remainder has 4 columns, we use the SSE kernel for the remainder
2020 // because it requires 4 * 16 = 64 instructions instead of 4 * 16 + 2 * 4
2021 // = 72 instructions needed in the masked AVX512 kernel.
2022 for (ib = 0; ib + 16 <= M; ib += 16) {
2023 for (jb = 0; jb + 16 <= N; jb += 16) {
2024 transpose_kernel_16x16_avx512(
2025 &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
2026 }
2027 for (int64_t i = ib; i < ib + 16; i += 4) {
2028 transpose_kernel_4x4_sse(
2029 &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
2030 }
2031 }
2032 } else if (N % 16 == 8) {
      // If the remainder has 8 columns, we use the AVX2 kernel for the remainder
2034 // because it requires 2 * 40 = 80 instructions instead of 4 * 16 + 2 * 8
2035 // = 80 instructions + looping overhead in the masked AVX512 kernel.
2036 for (ib = 0; ib + 16 <= M; ib += 16) {
2037 for (jb = 0; jb + 16 <= N; jb += 16) {
2038 transpose_kernel_16x16_avx512(
2039 &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
2040 }
2041 for (int64_t i = ib; i < ib + 16; i += 8) {
2042 transpose_kernel_8x8_avx2(
2043 &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
2044 }
2045 }
2046 } else {
2047 for (ib = 0; ib + 16 <= M; ib += 16) {
2048 for (jb = 0; jb + 16 <= N; jb += 16) {
2049 transpose_kernel_16x16_avx512(
2050 &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
2051 }
2052 if (jb < N) {
2053 transpose_kernel_mxn_avx512<16>(
2054 N - jb,
2055 &src[ib * ld_src + jb],
2056 ld_src,
2057 &dst[ib + jb * ld_dst],
2058 ld_dst);
2059 }
2060 }
2061 }
2062
    // Specialization for small M - ib cases so that the compiler can inline
    // transpose_kernel_mxn_avx512 and unroll the loops whose iteration count
    // depends on M - ib.
2066 // Specialization for m helps more than for n in transpose_kernel_mxn_avx512
2067 // because we have more loops in that function whose iteration count depends
2068 // on m.
2069 switch (M - ib) {
2070 case 1:
2071 for (int64_t j = 0; j < N; ++j) {
2072 dst[ib + j * ld_dst] = src[ib * ld_src + j];
2073 }
2074 break;
2075 case 2:
2076 for (jb = 0; jb + 4 <= N; jb += 4) {
2077 transpose_kernel_mxn_sse<2>(
2078 4,
2079 &src[ib * ld_src + jb],
2080 ld_src,
2081 &dst[ib + jb * ld_dst],
2082 ld_dst);
2083 }
2084 if (jb < N) {
2085 transpose_kernel_mxn_sse<2>(
2086 N - jb,
2087 &src[ib * ld_src + jb],
2088 ld_src,
2089 &dst[ib + jb * ld_dst],
2090 ld_dst);
2091 }
2092 break;
2093 case 3:
2094 for (jb = 0; jb + 4 <= N; jb += 4) {
2095 transpose_kernel_mxn_sse<3>(
2096 4,
2097 &src[ib * ld_src + jb],
2098 ld_src,
2099 &dst[ib + jb * ld_dst],
2100 ld_dst);
2101 }
2102 if (jb < N) {
2103 transpose_kernel_mxn_sse<3>(
2104 N - jb,
2105 &src[ib * ld_src + jb],
2106 ld_src,
2107 &dst[ib + jb * ld_dst],
2108 ld_dst);
2109 }
2110 break;
2111 case 4:
2112 for (jb = 0; jb + 4 <= N; jb += 4) {
2113 transpose_kernel_4x4_sse(
2114 &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
2115 }
2116 if (jb < N) {
2117 transpose_kernel_mxn_sse<4>(
2118 N - jb,
2119 &src[ib * ld_src + jb],
2120 ld_src,
2121 &dst[ib + jb * ld_dst],
2122 ld_dst);
2123 }
2124 break;
2125 case 5:
2126 for (jb = 0; jb + 8 <= N; jb += 8) {
2127 transpose_kernel_mxn_avx2<5>(
2128 8,
2129 &src[ib * ld_src + jb],
2130 ld_src,
2131 &dst[ib + jb * ld_dst],
2132 ld_dst);
2133 }
2134 if (jb < N) {
2135 transpose_kernel_mxn_avx2<5>(
2136 N - jb,
2137 &src[ib * ld_src + jb],
2138 ld_src,
2139 &dst[ib + jb * ld_dst],
2140 ld_dst);
2141 }
2142 break;
2143 case 6:
2144 for (jb = 0; jb + 8 <= N; jb += 8) {
2145 transpose_kernel_mxn_avx2<6>(
2146 8,
2147 &src[ib * ld_src + jb],
2148 ld_src,
2149 &dst[ib + jb * ld_dst],
2150 ld_dst);
2151 }
2152 if (jb < N) {
2153 transpose_kernel_mxn_avx2<6>(
2154 N - jb,
2155 &src[ib * ld_src + jb],
2156 ld_src,
2157 &dst[ib + jb * ld_dst],
2158 ld_dst);
2159 }
2160 break;
2161 case 7:
2162 for (jb = 0; jb + 16 <= N; jb += 16) {
2163 transpose_kernel_mxn_avx512<7>(
2164 16,
2165 &src[ib * ld_src + jb],
2166 ld_src,
2167 &dst[ib + jb * ld_dst],
2168 ld_dst);
2169 }
2170 if (jb < N) {
2171 transpose_kernel_mxn_avx512<7>(
2172 N - jb,
2173 &src[ib * ld_src + jb],
2174 ld_src,
2175 &dst[ib + jb * ld_dst],
2176 ld_dst);
2177 }
2178 break;
2179 case 8:
2180 for (jb = 0; jb + 8 <= N; jb += 8) {
2181 transpose_kernel_8x8_avx2(
2182 &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
2183 }
2184 if (jb < N) {
2185 transpose_kernel_mxn_avx2<8>(
2186 N - jb,
2187 &src[ib * ld_src + jb],
2188 ld_src,
2189 &dst[ib + jb * ld_dst],
2190 ld_dst);
2191 }
2192 break;
2193 case 9:
2194 for (jb = 0; jb + 16 <= N; jb += 16) {
2195 transpose_kernel_mxn_avx512<9>(
2196 16,
2197 &src[ib * ld_src + jb],
2198 ld_src,
2199 &dst[ib + jb * ld_dst],
2200 ld_dst);
2201 }
2202 if (jb < N) {
2203 transpose_kernel_mxn_avx512<9>(
2204 N - jb,
2205 &src[ib * ld_src + jb],
2206 ld_src,
2207 &dst[ib + jb * ld_dst],
2208 ld_dst);
2209 }
2210 break;
2211 case 10:
2212 for (jb = 0; jb + 16 <= N; jb += 16) {
2213 transpose_kernel_mxn_avx512<10>(
2214 16,
2215 &src[ib * ld_src + jb],
2216 ld_src,
2217 &dst[ib + jb * ld_dst],
2218 ld_dst);
2219 }
2220 if (jb < N) {
2221 transpose_kernel_mxn_avx512<10>(
2222 N - jb,
2223 &src[ib * ld_src + jb],
2224 ld_src,
2225 &dst[ib + jb * ld_dst],
2226 ld_dst);
2227 }
2228 break;
2229 case 11:
2230 for (jb = 0; jb + 16 <= N; jb += 16) {
2231 transpose_kernel_mxn_avx512<11>(
2232 16,
2233 &src[ib * ld_src + jb],
2234 ld_src,
2235 &dst[ib + jb * ld_dst],
2236 ld_dst);
2237 }
2238 if (jb < N) {
2239 transpose_kernel_mxn_avx512<11>(
2240 N - jb,
2241 &src[ib * ld_src + jb],
2242 ld_src,
2243 &dst[ib + jb * ld_dst],
2244 ld_dst);
2245 }
2246 break;
2247 case 12:
2248 for (jb = 0; jb + 16 <= N; jb += 16) {
2249 transpose_kernel_mxn_avx512<12>(
2250 16,
2251 &src[ib * ld_src + jb],
2252 ld_src,
2253 &dst[ib + jb * ld_dst],
2254 ld_dst);
2255 }
2256 if (jb < N) {
2257 transpose_kernel_mxn_avx512<12>(
2258 N - jb,
2259 &src[ib * ld_src + jb],
2260 ld_src,
2261 &dst[ib + jb * ld_dst],
2262 ld_dst);
2263 }
2264 break;
2265 case 13:
2266 for (jb = 0; jb + 16 <= N; jb += 16) {
2267 transpose_kernel_mxn_avx512<13>(
2268 16,
2269 &src[ib * ld_src + jb],
2270 ld_src,
2271 &dst[ib + jb * ld_dst],
2272 ld_dst);
2273 }
2274 if (jb < N) {
2275 transpose_kernel_mxn_avx512<13>(
2276 N - jb,
2277 &src[ib * ld_src + jb],
2278 ld_src,
2279 &dst[ib + jb * ld_dst],
2280 ld_dst);
2281 }
2282 break;
2283 case 14:
2284 for (jb = 0; jb + 16 <= N; jb += 16) {
2285 transpose_kernel_mxn_avx512<14>(
2286 16,
2287 &src[ib * ld_src + jb],
2288 ld_src,
2289 &dst[ib + jb * ld_dst],
2290 ld_dst);
2291 }
2292 if (jb < N) {
2293 transpose_kernel_mxn_avx512<14>(
2294 N - jb,
2295 &src[ib * ld_src + jb],
2296 ld_src,
2297 &dst[ib + jb * ld_dst],
2298 ld_dst);
2299 }
2300 break;
2301 case 15:
2302 for (jb = 0; jb + 16 <= N; jb += 16) {
2303 transpose_kernel_mxn_avx512<15>(
2304 16,
2305 &src[ib * ld_src + jb],
2306 ld_src,
2307 &dst[ib + jb * ld_dst],
2308 ld_dst);
2309 }
2310 if (jb < N) {
2311 transpose_kernel_mxn_avx512<15>(
2312 N - jb,
2313 &src[ib * ld_src + jb],
2314 ld_src,
2315 &dst[ib + jb * ld_dst],
2316 ld_dst);
2317 }
2318 break;
2319 }
2320 }
2321}
2322
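// uint16_t variant: tiles the matrix into 16x16 blocks; the MREM / NREM
// instantiations of transpose_16x16_block handle the partial tiles on the
// right and bottom edges.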
2323template <>
2324void transpose_avx512(
2325 int64_t M,
2326 int64_t N,
2327 const uint16_t* src,
2328 int64_t ld_src,
2329 uint16_t* dst,
2330 int64_t ld_dst) {
2331 if (M == ld_dst && (M == 2 || M == 4)) {
2332 transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
2333 } else if (N == ld_src && (N == 2 || N == 4)) {
2334 transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
2335 } else {
2336 int64_t i = 0;
2337 for (; i < M / 16 * 16; i += 16) {
2338 int64_t j = 0;
2339 for (; j < N / 16 * 16; j += 16) {
2340 transpose_16x16_block<false, false>(
2341 src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst);
2342 }
2343 // handle j rem
2344 int nrem = N - j;
2345 if (nrem > 0) {
2346 transpose_16x16_block<false, true>(
2347 src + i * ld_src + j,
2348 ld_src,
2349 dst + j * ld_dst + i,
2350 ld_dst,
2351 16,
2352 nrem);
2353 }
2354 }
2355 // handle i rem
2356 int mrem = M - i;
2357 if (mrem > 0) {
      int64_t j = 0;
2359 for (; j < N / 16 * 16; j += 16) {
2360 transpose_16x16_block<true, false>(
2361 src + i * ld_src + j,
2362 ld_src,
2363 dst + j * ld_dst + i,
2364 ld_dst,
2365 mrem,
2366 16);
2367 }
2368 // handle j rem
2369 int nrem = N - j;
2370 transpose_16x16_block<true, true>(
2371 src + i * ld_src + j,
2372 ld_src,
2373 dst + j * ld_dst + i,
2374 ld_dst,
2375 mrem,
2376 nrem);
2377 }
2378 }
2379}
2380
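// uint8_t variant: tiles the matrix into 16x32 blocks (written out as 32x16);
// the MREM / NREM instantiations of transpose_16x32_block handle the partial
// tiles on the right and bottom edges.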
2381template <>
2382void transpose_avx512(
2383 int64_t M,
2384 int64_t N,
2385 const uint8_t* src,
2386 int64_t ld_src,
2387 uint8_t* dst,
2388 int64_t ld_dst) {
2389 if (M == ld_dst && (M == 2 || M == 4)) {
2390 transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
2391 } else if (N == ld_src && (N == 2 || N == 4)) {
2392 transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
2393 } else {
2394 int64_t i = 0;
2395 for (; i < M / 16 * 16; i += 16) {
2396 int64_t j = 0;
2397 for (; j < N / 32 * 32; j += 32) {
2398 transpose_16x32_block<false, false>(
2399 src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst);
2400 }
2401 // handle j rem
2402 int nrem = N - j;
2403 if (nrem > 0) {
2404 transpose_16x32_block<false, true>(
2405 src + i * ld_src + j,
2406 ld_src,
2407 dst + j * ld_dst + i,
2408 ld_dst,
2409 16,
2410 nrem);
2411 }
2412 }
2413
2414 // handle i rem
2415 int mrem = M - i;
2416 if (mrem > 0) {
2417 int64_t j = 0;
2418 for (; j < N / 32 * 32; j += 32) {
2419 transpose_16x32_block<true, false>(
2420 src + i * ld_src + j,
2421 ld_src,
2422 dst + j * ld_dst + i,
2423 ld_dst,
2424 mrem,
2425 32);
2426 }
2427 // handle j rem
2428 int nrem = N - j;
2429 transpose_16x32_block<true, true>(
2430 src + i * ld_src + j,
2431 ld_src,
2432 dst + j * ld_dst + i,
2433 ld_dst,
2434 mrem,
2435 nrem);
2436 }
2437 }
2438}
2439
2440} // namespace internal
2441
2442} // namespace fbgemm
2443