1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | |
8 | #if defined(__x86_64__) || defined(__i386__) || \ |
9 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
10 | #include <immintrin.h> |
11 | #endif |
12 | #include "./TransposeUtils.h" |
13 | #include "./TransposeUtilsAvx2.h" |
14 | namespace fbgemm { |
15 | |
16 | namespace { |
17 | |
18 | // 16 * 6 = 96 instructions |
19 | inline void transpose_kernel_16x16_avx512( |
20 | const float* src, |
21 | int64_t ld_src, |
22 | float* dst, |
23 | int64_t ld_dst) { |
24 | // load from src to registers |
25 | // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 |
26 | // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 |
27 | // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 |
28 | // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 |
29 | // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15 |
30 | // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15 |
31 | // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15 |
32 | // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15 |
33 | // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15 |
34 | // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15 |
35 | // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 |
36 | // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15 |
37 | // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15 |
38 | // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15 |
39 | // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15 |
40 | // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15 |
41 | __m512 a = _mm512_loadu_ps(&src[0 * ld_src]); |
42 | __m512 b = _mm512_loadu_ps(&src[1 * ld_src]); |
43 | __m512 c = _mm512_loadu_ps(&src[2 * ld_src]); |
44 | __m512 d = _mm512_loadu_ps(&src[3 * ld_src]); |
45 | __m512 e = _mm512_loadu_ps(&src[4 * ld_src]); |
46 | __m512 f = _mm512_loadu_ps(&src[5 * ld_src]); |
47 | __m512 g = _mm512_loadu_ps(&src[6 * ld_src]); |
48 | __m512 h = _mm512_loadu_ps(&src[7 * ld_src]); |
49 | __m512 i = _mm512_loadu_ps(&src[8 * ld_src]); |
50 | __m512 j = _mm512_loadu_ps(&src[9 * ld_src]); |
51 | __m512 k = _mm512_loadu_ps(&src[10 * ld_src]); |
52 | __m512 l = _mm512_loadu_ps(&src[11 * ld_src]); |
53 | __m512 m = _mm512_loadu_ps(&src[12 * ld_src]); |
54 | __m512 n = _mm512_loadu_ps(&src[13 * ld_src]); |
55 | __m512 o = _mm512_loadu_ps(&src[14 * ld_src]); |
56 | __m512 p = _mm512_loadu_ps(&src[15 * ld_src]); |
57 | |
58 | __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq; |
59 | // unpacking and interleaving 32-bit elements |
60 | // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13 |
61 | // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15 |
62 | // c0 d0 c1 d1 ... |
63 | // c2 d2 c3 d3 ... |
64 | // e0 f0 e1 f1 ... |
65 | // e2 f2 e3 f3 ... |
66 | // g0 h0 g1 h1 ... |
67 | // g2 h2 g3 h3 ... |
68 | // i0 ... |
69 | // i2 ... |
70 | // k0 ... |
71 | // k2 ... |
72 | // m0 ... |
73 | // m2 ... |
74 | // o0 ... |
  // o2 ...
76 | ta = _mm512_unpacklo_ps(a, b); |
77 | tb = _mm512_unpackhi_ps(a, b); |
78 | tc = _mm512_unpacklo_ps(c, d); |
79 | td = _mm512_unpackhi_ps(c, d); |
80 | te = _mm512_unpacklo_ps(e, f); |
81 | tf = _mm512_unpackhi_ps(e, f); |
82 | tg = _mm512_unpacklo_ps(g, h); |
83 | th = _mm512_unpackhi_ps(g, h); |
84 | ti = _mm512_unpacklo_ps(i, j); |
85 | tj = _mm512_unpackhi_ps(i, j); |
86 | tk = _mm512_unpacklo_ps(k, l); |
87 | tl = _mm512_unpackhi_ps(k, l); |
88 | tm = _mm512_unpacklo_ps(m, n); |
89 | tn = _mm512_unpackhi_ps(m, n); |
90 | to = _mm512_unpacklo_ps(o, p); |
91 | tq = _mm512_unpackhi_ps(o, p); |
92 | |
93 | // unpacking and interleaving 64-bit elements |
94 | // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12 |
95 | // a1 b1 c1 d1 ... |
96 | // a2 b2 c2 d2 ... |
97 | // a3 b3 c3 d3 ... |
98 | // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12 |
99 | // e1 f1 g1 h1 ... |
100 | // e2 f2 g2 h2 ... |
101 | // e3 f3 g3 h3 ... |
102 | // i0 j0 k0 l0 ... |
103 | // i1 j1 k1 l1 ... |
104 | // i2 j2 k2 l2 ... |
105 | // i3 j3 k3 l3 ... |
106 | // m0 n0 o0 p0 ... |
107 | // m1 n1 o1 p1 ... |
108 | // m2 n2 o2 p2 ... |
109 | // m3 n3 o3 p3 ... |
110 | a = _mm512_castpd_ps( |
111 | _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc))); |
112 | b = _mm512_castpd_ps( |
113 | _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc))); |
114 | c = _mm512_castpd_ps( |
115 | _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td))); |
116 | d = _mm512_castpd_ps( |
117 | _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td))); |
118 | e = _mm512_castpd_ps( |
119 | _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg))); |
120 | f = _mm512_castpd_ps( |
121 | _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg))); |
122 | g = _mm512_castpd_ps( |
123 | _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th))); |
124 | h = _mm512_castpd_ps( |
125 | _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th))); |
126 | i = _mm512_castpd_ps( |
127 | _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk))); |
128 | j = _mm512_castpd_ps( |
129 | _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk))); |
130 | k = _mm512_castpd_ps( |
131 | _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl))); |
132 | l = _mm512_castpd_ps( |
133 | _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl))); |
134 | m = _mm512_castpd_ps( |
135 | _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to))); |
136 | n = _mm512_castpd_ps( |
137 | _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to))); |
138 | o = _mm512_castpd_ps( |
139 | _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq))); |
140 | p = _mm512_castpd_ps( |
141 | _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq))); |
142 | |
143 | // shuffle 128-bits (composed of 4 32-bit elements) |
144 | // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8 |
145 | // a1 b1 c1 d1 ... |
146 | // a2 b2 c2 d2 ... |
147 | // a3 b3 c3 d3 ... |
148 | // a4 b4 c4 d4 ... |
149 | // a5 b5 c5 d5 ... |
150 | // a6 b6 c6 d6 ... |
151 | // a7 b7 c7 d7 ... |
152 | // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8 |
153 | // i1 j1 k1 l1 ... |
154 | // i2 j2 k2 l2 ... |
155 | // i3 j3 k3 l3 ... |
156 | // i4 j4 k4 l4 ... |
157 | // i5 j5 k5 l5 ... |
158 | // i6 j6 k6 l6 ... |
159 | // i7 j7 k7 l7 ... |
160 | ta = _mm512_shuffle_f32x4(a, e, 0x88); |
161 | tb = _mm512_shuffle_f32x4(b, f, 0x88); |
162 | tc = _mm512_shuffle_f32x4(c, g, 0x88); |
163 | td = _mm512_shuffle_f32x4(d, h, 0x88); |
164 | te = _mm512_shuffle_f32x4(a, e, 0xdd); |
165 | tf = _mm512_shuffle_f32x4(b, f, 0xdd); |
166 | tg = _mm512_shuffle_f32x4(c, g, 0xdd); |
167 | th = _mm512_shuffle_f32x4(d, h, 0xdd); |
168 | ti = _mm512_shuffle_f32x4(i, m, 0x88); |
169 | tj = _mm512_shuffle_f32x4(j, n, 0x88); |
170 | tk = _mm512_shuffle_f32x4(k, o, 0x88); |
171 | tl = _mm512_shuffle_f32x4(l, p, 0x88); |
172 | tm = _mm512_shuffle_f32x4(i, m, 0xdd); |
173 | tn = _mm512_shuffle_f32x4(j, n, 0xdd); |
174 | to = _mm512_shuffle_f32x4(k, o, 0xdd); |
175 | tq = _mm512_shuffle_f32x4(l, p, 0xdd); |
176 | |
177 | // shuffle 128-bits (composed of 4 32-bit elements) |
178 | // a0 b0 c0 d0 ... o0 |
179 | // a1 b1 c1 d1 ... o1 |
180 | // a2 b2 c2 d2 ... o2 |
181 | // a3 b3 c3 d3 ... o3 |
182 | // a4 ... |
183 | // a5 ... |
184 | // a6 ... |
185 | // a7 ... |
186 | // a8 ... |
187 | // a9 ... |
188 | // a10 ... |
189 | // a11 ... |
190 | // a12 ... |
191 | // a13 ... |
192 | // a14 ... |
193 | // a15 b15 c15 d15 ... o15 |
194 | a = _mm512_shuffle_f32x4(ta, ti, 0x88); |
195 | b = _mm512_shuffle_f32x4(tb, tj, 0x88); |
196 | c = _mm512_shuffle_f32x4(tc, tk, 0x88); |
197 | d = _mm512_shuffle_f32x4(td, tl, 0x88); |
198 | e = _mm512_shuffle_f32x4(te, tm, 0x88); |
199 | f = _mm512_shuffle_f32x4(tf, tn, 0x88); |
200 | g = _mm512_shuffle_f32x4(tg, to, 0x88); |
201 | h = _mm512_shuffle_f32x4(th, tq, 0x88); |
202 | i = _mm512_shuffle_f32x4(ta, ti, 0xdd); |
203 | j = _mm512_shuffle_f32x4(tb, tj, 0xdd); |
204 | k = _mm512_shuffle_f32x4(tc, tk, 0xdd); |
205 | l = _mm512_shuffle_f32x4(td, tl, 0xdd); |
206 | m = _mm512_shuffle_f32x4(te, tm, 0xdd); |
207 | n = _mm512_shuffle_f32x4(tf, tn, 0xdd); |
208 | o = _mm512_shuffle_f32x4(tg, to, 0xdd); |
209 | p = _mm512_shuffle_f32x4(th, tq, 0xdd); |
210 | |
211 | // store from registers to dst |
212 | _mm512_storeu_ps(&dst[0 * ld_dst], a); |
213 | _mm512_storeu_ps(&dst[1 * ld_dst], b); |
214 | _mm512_storeu_ps(&dst[2 * ld_dst], c); |
215 | _mm512_storeu_ps(&dst[3 * ld_dst], d); |
216 | _mm512_storeu_ps(&dst[4 * ld_dst], e); |
217 | _mm512_storeu_ps(&dst[5 * ld_dst], f); |
218 | _mm512_storeu_ps(&dst[6 * ld_dst], g); |
219 | _mm512_storeu_ps(&dst[7 * ld_dst], h); |
220 | _mm512_storeu_ps(&dst[8 * ld_dst], i); |
221 | _mm512_storeu_ps(&dst[9 * ld_dst], j); |
222 | _mm512_storeu_ps(&dst[10 * ld_dst], k); |
223 | _mm512_storeu_ps(&dst[11 * ld_dst], l); |
224 | _mm512_storeu_ps(&dst[12 * ld_dst], m); |
225 | _mm512_storeu_ps(&dst[13 * ld_dst], n); |
226 | _mm512_storeu_ps(&dst[14 * ld_dst], o); |
227 | _mm512_storeu_ps(&dst[15 * ld_dst], p); |
228 | } |
229 | |
230 | // kernel for transposing mxn where m, n <= 16 |
231 | // M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions |
232 | template <int M> |
233 | void transpose_kernel_mxn_avx512( |
234 | int N, |
235 | const float* src, |
236 | int64_t ld_src, |
237 | float* dst, |
238 | int64_t ld_dst) { |
239 | // load from src to registers |
240 | __mmask16 src_mask = (1 << N) - 1; |
241 | __m512 input[16]; |
242 | int i; |
243 | for (i = 0; i < M; ++i) { |
244 | input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); |
245 | } |
246 | for (; i < 16; ++i) { |
    // Not strictly needed, but zero-initialize to avoid an
    // uninitialized-variable warning. The overhead should be negligible since
    // the xor can execute in parallel with other instructions.
250 | input[i] = _mm512_setzero_ps(); |
251 | } |
252 | |
253 | // unpacking and interleaving 32-bit elements |
254 | __m512 temp[16]; |
255 | for (i = 0; i < (M + 1) / 2; ++i) { |
256 | temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]); |
257 | temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]); |
258 | } |
259 | for (i = i * 2; i < 16; ++i) { |
260 | temp[i] = _mm512_setzero_ps(); |
261 | } |
262 | |
263 | // unpacking and interleaving 64-bit elements |
264 | for (i = 0; i < (M + 3) / 4; ++i) { |
265 | input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd( |
266 | _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); |
267 | input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd( |
268 | _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); |
269 | input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd( |
270 | _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); |
271 | input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd( |
272 | _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); |
273 | } |
274 | |
275 | // shuffle 128-bits (composed of 4 32-bit elements) |
276 | for (i = 0; i < (M + 7) / 8; ++i) { |
277 | temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88); |
278 | temp[8 * i + 1] = |
279 | _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88); |
280 | temp[8 * i + 2] = |
281 | _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88); |
282 | temp[8 * i + 3] = |
283 | _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88); |
284 | temp[8 * i + 4] = |
285 | _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd); |
286 | temp[8 * i + 5] = |
287 | _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd); |
288 | temp[8 * i + 6] = |
289 | _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd); |
290 | temp[8 * i + 7] = |
291 | _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd); |
292 | } |
293 | |
294 | // store from registers to dst |
295 | __mmask16 dst_mask = (1 << M) - 1; |
296 | for (i = 0; i < N; ++i) { |
297 | if (i < 8) { |
298 | input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88); |
299 | } else { |
300 | input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd); |
301 | } |
302 | _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); |
303 | } |
304 | } |
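
// Illustrative sketch (not part of this file): a caller would typically tile
// an M x N float matrix into 16x16 blocks handled by
// transpose_kernel_16x16_avx512 and dispatch the column/row remainders to the
// masked transpose_kernel_mxn_avx512, roughly:
//
//   for (int64_t ib = 0; ib + 16 <= M; ib += 16) {
//     int64_t jb = 0;
//     for (; jb + 16 <= N; jb += 16) {
//       transpose_kernel_16x16_avx512(
//           &src[ib * ld_src + jb], ld_src, &dst[jb * ld_dst + ib], ld_dst);
//     }
//     if (jb < N) {
//       transpose_kernel_mxn_avx512<16>(
//           N - jb, &src[ib * ld_src + jb], ld_src, &dst[jb * ld_dst + ib],
//           ld_dst);
//     }
//   }
//   // The trailing M % 16 rows would be dispatched on the row count, e.g.
//   // via a switch over transpose_kernel_mxn_avx512<1>..<15>.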
305 | |
306 | } // namespace |
307 | |
308 | namespace internal { |
309 | |
310 | template <typename T> |
311 | void transpose_avx512_contiguous_thin( |
312 | const int64_t M, |
313 | const int64_t N, |
314 | const T* src, |
315 | int64_t ld_src, |
316 | T* dst, |
317 | int64_t ld_dst); |
318 | |
319 | template <typename T> |
320 | void transpose_avx512_contiguous_wide( |
321 | const int64_t M, |
322 | const int64_t N, |
323 | const T* src, |
324 | int64_t ld_src, |
325 | T* dst, |
326 | int64_t ld_dst); |
327 | |
// Permute elements within each 128-bit lane
329 | // e.g., if a 128-bit lane has the following elements: |
330 | // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 |
331 | // |
332 | // After this function call, it becomes |
333 | // 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 |
// The same permutation is applied to the other 3 lanes.
335 | static inline __m512i permute_row(__m512i row) { |
336 | // clang-format off |
337 | __m256i shuffle_256v0 = _mm256_set_epi8( |
338 | 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, |
339 | 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); |
340 | // clang-format on |
341 | __m512i shuffle_512v = _mm512_castsi256_si512(shuffle_256v0); |
342 | row = _mm512_shuffle_epi8( |
343 | row, _mm512_inserti64x4(shuffle_512v, shuffle_256v0, 1)); |
344 | return row; |
345 | } |
346 | |
347 | static inline void core_transpose_16x32_block_i8(__m512i r[], __m512i u[]) { |
348 | // Result after this operation; Read in conjunction with comments in |
349 | // transpose_16x32_block |
350 | // 00_00 00_01 01_00 01_01 00_04 00_05 01_04 01_05 04_00 04_01 05_00 05_01 |
351 | // 04_04 04_05 05_04 05_05 |
352 | u[0] = _mm512_unpacklo_epi64(r[0], r[1]); |
353 | // 00_02 00_03 01_02 01_03 00_06 00_07 01_06 01_07 04_02 04_03 05_02 05_03 |
354 | // 04_06 04_07 05_06 05_07 |
355 | u[1] = _mm512_unpackhi_epi64(r[0], r[1]); |
356 | // 02_00 02_01 03_00 03_01 02_04 02_05 03_04 03_05 06_00 06_01 07_00 07_01 |
357 | // 06_04 06_05 07_04 07_05 |
358 | u[2] = _mm512_unpacklo_epi64(r[2], r[3]); |
359 | // 02_02 02_03 03_02 03_03 02_06 02_07 03_06 03_07 06_02 06_03 07_02 07_03 |
360 | // 06_06 06_07 07_06 07_07 |
361 | u[3] = _mm512_unpackhi_epi64(r[2], r[3]); |
362 | // 08_00 08_01 09_00 09_01 08_04 08_05 09_04 09_05 12_00 12_01 13_00 13_01 |
363 | // 12_04 12_05 13_04 13_05 |
364 | u[4] = _mm512_unpacklo_epi64(r[4], r[5]); |
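  // u[5]..u[7]: same unpack pattern for the high halves of rows 8/9 and 12/13
  // and for both halves of rows 10/11 and 14/15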
365 | u[5] = _mm512_unpackhi_epi64(r[4], r[5]); |
366 | u[6] = _mm512_unpacklo_epi64(r[6], r[7]); |
367 | u[7] = _mm512_unpackhi_epi64(r[6], r[7]); |
368 | |
  // A two-source shuffle intrinsic doesn't exist for epi32, so cast to ps
370 | // 00_00 01_00 02_00 03_00 00_04 01_04 02_04 03_04 04_00 05_00 06_00 07_00 |
371 | // 04_04 05_04 06_04 07_04 |
372 | r[0] = _mm512_castps_si512(_mm512_shuffle_ps( |
373 | _mm512_castsi512_ps(u[0]), _mm512_castsi512_ps(u[2]), 0x88)); |
374 | // 00_01 01_01 02_01 03_01 00_05 01_05 02_05 03_05 04_01 05_01 06_01 07_01 |
375 | // 04_05 05_05 06_05 07_05 |
376 | r[1] = _mm512_castps_si512(_mm512_shuffle_ps( |
377 | _mm512_castsi512_ps(u[0]), _mm512_castsi512_ps(u[2]), 0xDD)); |
378 | r[2] = _mm512_castps_si512(_mm512_shuffle_ps( |
379 | _mm512_castsi512_ps(u[1]), _mm512_castsi512_ps(u[3]), 0x88)); |
380 | r[3] = _mm512_castps_si512(_mm512_shuffle_ps( |
381 | _mm512_castsi512_ps(u[1]), _mm512_castsi512_ps(u[3]), 0xDD)); |
382 | // 08_00 09_00 10_00 11_00 08_04 09_04 10_04 11_04 12_00 13_00 14_00 15_00 |
383 | // 12_04 13_04 14_04 15_04 |
384 | r[4] = _mm512_castps_si512(_mm512_shuffle_ps( |
385 | _mm512_castsi512_ps(u[4]), _mm512_castsi512_ps(u[6]), 0x88)); |
386 | r[5] = _mm512_castps_si512(_mm512_shuffle_ps( |
387 | _mm512_castsi512_ps(u[4]), _mm512_castsi512_ps(u[6]), 0xDD)); |
388 | r[6] = _mm512_castps_si512(_mm512_shuffle_ps( |
389 | _mm512_castsi512_ps(u[5]), _mm512_castsi512_ps(u[7]), 0x88)); |
390 | r[7] = _mm512_castps_si512(_mm512_shuffle_ps( |
391 | _mm512_castsi512_ps(u[5]), _mm512_castsi512_ps(u[7]), 0xDD)); |
392 | |
393 | // permute among 128-bit lanes |
394 | r[0] = permute_row(r[0]); |
395 | r[1] = permute_row(r[1]); |
396 | r[2] = permute_row(r[2]); |
397 | r[3] = permute_row(r[3]); |
398 | r[4] = permute_row(r[4]); |
399 | r[5] = permute_row(r[5]); |
400 | r[6] = permute_row(r[6]); |
401 | r[7] = permute_row(r[7]); |
402 | |
403 | __m512i const1 = _mm512_set_epi32( |
404 | 27, 19, 11, 3, 26, 18, 10, 2, 25, 17, 9, 1, 24, 16, 8, 0); |
405 | __m512i const2 = _mm512_set_epi32( |
406 | 31, 23, 15, 7, 30, 22, 14, 6, 29, 21, 13, 5, 28, 20, 12, 4); |
407 | |
408 | // merge 128-bit values from two regs |
409 | u[0] = _mm512_permutex2var_epi32(r[0], const1, r[4]); |
410 | u[1] = _mm512_permutex2var_epi32(r[0], const2, r[4]); |
411 | u[2] = _mm512_permutex2var_epi32(r[1], const1, r[5]); |
412 | u[3] = _mm512_permutex2var_epi32(r[1], const2, r[5]); |
413 | u[4] = _mm512_permutex2var_epi32(r[2], const1, r[6]); |
414 | u[5] = _mm512_permutex2var_epi32(r[2], const2, r[6]); |
415 | u[6] = _mm512_permutex2var_epi32(r[3], const1, r[7]); |
416 | u[7] = _mm512_permutex2var_epi32(r[3], const2, r[7]); |
417 | } |
418 | |
419 | static inline void core_transpose_16x16_block(__m512i r[], __m512i u[]) { |
420 | // a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 f8f9 |
421 | // e10e11 f10f11 |
422 | u[0] = _mm512_unpacklo_epi32(r[0], r[1]); |
423 | // a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 f4f5 e6e7 f6f7 |
424 | // e12e13 f12f13 e14e15 f14f15 |
425 | u[1] = _mm512_unpackhi_epi32(r[0], r[1]); |
426 | // c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 |
427 | // g10g11 h10h11 |
428 | u[2] = _mm512_unpacklo_epi32(r[2], r[3]); |
  // c4c5 d4d5 c6c7 d6d7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7
430 | // g12g13 h12h13 g14g15 h14h15 |
431 | u[3] = _mm512_unpackhi_epi32(r[2], r[3]); |
432 | // i j m n |
433 | u[4] = _mm512_unpacklo_epi32(r[4], r[5]); |
434 | u[5] = _mm512_unpackhi_epi32(r[4], r[5]); |
435 | // k l o p |
436 | u[6] = _mm512_unpacklo_epi32(r[6], r[7]); |
437 | u[7] = _mm512_unpackhi_epi32(r[6], r[7]); |
438 | |
439 | // a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9 |
440 | // h8h9 |
441 | r[0] = _mm512_unpacklo_epi64(u[0], u[2]); |
442 | // a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 g2g3 h2h3 e10e11 |
443 | // f10f11 g10g11 h10h11 |
444 | r[1] = _mm512_unpackhi_epi64(u[0], u[2]); |
  // a4a5 b4b5 c4c5 d4d5 a12a13 b12b13 c12c13 d12d13
446 | r[2] = _mm512_unpacklo_epi64(u[1], u[3]); |
  // a6a7 b6b7 c6c7 d6d7 a14a15 b14b15 c14c15 d14d15
448 | r[3] = _mm512_unpackhi_epi64(u[1], u[3]); |
449 | // i j k l m n o p |
450 | r[4] = _mm512_unpacklo_epi64(u[4], u[6]); |
451 | r[5] = _mm512_unpackhi_epi64(u[4], u[6]); |
452 | r[6] = _mm512_unpacklo_epi64(u[5], u[7]); |
453 | r[7] = _mm512_unpackhi_epi64(u[5], u[7]); |
454 | |
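  // const1/const2 are index vectors for _mm512_permutex2var_epi16: each
  // 32-bit literal packs two 16-bit lane indices, and indices >= 0x20 select
  // from the second source operand.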
455 | __m512i const1 = _mm512_set_epi32( |
456 | 0x00370035, |
457 | 0x00330031, |
458 | 0x00270025, |
459 | 0x00230021, |
460 | 0x00170015, |
461 | 0x00130011, |
462 | 0x00070005, |
463 | 0x00030001, |
464 | 0x00360034, |
465 | 0x00320030, |
466 | 0x00260024, |
467 | 0x00220020, |
468 | 0x00160014, |
469 | 0x00120010, |
470 | 0x00060004, |
471 | 0x00020000); |
472 | __m512i const2 = _mm512_set_epi32( |
473 | 0x003f003d, |
474 | 0x003b0039, |
475 | 0x002f002d, |
476 | 0x002b0029, |
477 | 0x001f001d, |
478 | 0x001b0019, |
479 | 0x000f000d, |
480 | 0x000b0009, |
481 | 0x003e003c, |
482 | 0x003a0038, |
483 | 0x002e002c, |
484 | 0x002a0028, |
485 | 0x001e001c, |
486 | 0x001a0018, |
487 | 0x000e000c, |
488 | 0x000a0008); |
489 | |
490 | // merge values from two regs |
491 | u[0] = _mm512_permutex2var_epi16(r[0], const1, r[4]); // 0-- 1-- |
492 | u[4] = _mm512_permutex2var_epi16(r[0], const2, r[4]); // 8-- 9-- |
493 | u[2] = _mm512_permutex2var_epi16(r[2], const1, r[6]); // 4-- 5-- |
494 | u[6] = _mm512_permutex2var_epi16(r[2], const2, r[6]); // 12-- 13-- |
495 | u[1] = _mm512_permutex2var_epi16(r[1], const1, r[5]); // 2-- 3-- |
496 | u[5] = _mm512_permutex2var_epi16(r[1], const2, r[5]); // 10-- 11-- |
497 | u[3] = _mm512_permutex2var_epi16(r[3], const1, r[7]); // 6-- 7-- |
498 | u[7] = _mm512_permutex2var_epi16(r[3], const2, r[7]); // 14-- 15-- |
499 | } |
500 | |
501 | static inline void load_with_remainders_i16( |
502 | const uint16_t* src, |
503 | int64_t ld_src, |
504 | __m512i r[], |
505 | int mrem, |
506 | int nrem) { |
507 | __m512i t[16]; |
508 | if (nrem < 16) { |
509 | __mmask32 mask_nrem_v = (((long long)1) << nrem) - 1; |
510 | for (int i = 0; i < mrem; ++i) { |
511 | // mask load |
512 | t[i] = _mm512_maskz_loadu_epi16(mask_nrem_v, src + i * ld_src); |
513 | } |
514 | } else { |
515 | for (int i = 0; i < mrem; ++i) { |
516 | // normal load |
517 | t[i] = _mm512_castsi256_si512(_mm256_loadu_si256( |
518 | reinterpret_cast<const __m256i*>(src + i * ld_src))); |
519 | } |
520 | } |
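  // Pack two rows per zmm register to match the layout expected by
  // core_transpose_16x16_block: row i in the low 256 bits and row i + 4
  // (rows 8 + i and 12 + i for the upper block) in the high 256 bits.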
521 | r[0] = _mm512_inserti64x4(t[0], _mm512_castsi512_si256(t[4]), 0x01); |
522 | r[1] = _mm512_inserti64x4(t[1], _mm512_castsi512_si256(t[5]), 0x01); |
523 | r[2] = _mm512_inserti64x4(t[2], _mm512_castsi512_si256(t[6]), 0x01); |
524 | r[3] = _mm512_inserti64x4(t[3], _mm512_castsi512_si256(t[7]), 0x01); |
525 | r[4] = _mm512_inserti64x4(t[8], _mm512_castsi512_si256(t[12]), 0x01); |
526 | r[5] = _mm512_inserti64x4(t[9], _mm512_castsi512_si256(t[13]), 0x01); |
527 | r[6] = _mm512_inserti64x4(t[10], _mm512_castsi512_si256(t[14]), 0x01); |
528 | r[7] = _mm512_inserti64x4(t[11], _mm512_castsi512_si256(t[15]), 0x01); |
529 | } |
530 | |
531 | static inline void load_with_remainders_i8( |
532 | const uint8_t* src, |
533 | int64_t ld_src, |
534 | __m512i r[], |
535 | int mrem, |
536 | int nrem) { |
537 | __m512i t[16]; |
538 | if (nrem < 32) { |
539 | __mmask64 mask_nrem_v = (((long long)1) << nrem) - 1; |
540 | for (int i = 0; i < mrem; ++i) { |
541 | // mask load |
542 | t[i] = _mm512_maskz_loadu_epi8(mask_nrem_v, src + i * ld_src); |
543 | } |
544 | } else { |
545 | for (int i = 0; i < mrem; ++i) { |
546 | // normal load |
547 | t[i] = _mm512_castsi256_si512(_mm256_loadu_si256( |
548 | reinterpret_cast<const __m256i*>(src + i * ld_src))); |
549 | } |
550 | } |
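  // Same row pairing as load_with_remainders_i16: rows i and i + 4
  // (8 + i and 12 + i for the upper block) share one zmm register.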
551 | r[0] = _mm512_inserti64x4(t[0], _mm512_castsi512_si256(t[4]), 0x01); |
552 | r[1] = _mm512_inserti64x4(t[1], _mm512_castsi512_si256(t[5]), 0x01); |
553 | r[2] = _mm512_inserti64x4(t[2], _mm512_castsi512_si256(t[6]), 0x01); |
554 | r[3] = _mm512_inserti64x4(t[3], _mm512_castsi512_si256(t[7]), 0x01); |
555 | r[4] = _mm512_inserti64x4(t[8], _mm512_castsi512_si256(t[12]), 0x01); |
556 | r[5] = _mm512_inserti64x4(t[9], _mm512_castsi512_si256(t[13]), 0x01); |
557 | r[6] = _mm512_inserti64x4(t[10], _mm512_castsi512_si256(t[14]), 0x01); |
558 | r[7] = _mm512_inserti64x4(t[11], _mm512_castsi512_si256(t[15]), 0x01); |
559 | } |
560 | |
561 | static inline void store_with_remainders_i16( |
562 | uint16_t* dst, |
563 | int64_t ld_dst, |
564 | __m512i u[], |
565 | int mrem, |
566 | int nrem) { |
567 | if (mrem < 16) { |
568 | __mmask32 mask_mrem_v = (((long long)1) << mrem) - 1; |
569 | int i = 0; |
570 | |
571 | for (; i < nrem / 2 * 2; i += 2) { |
572 | // mask store |
573 | int reg_idx = i / 2; |
574 | _mm512_mask_storeu_epi16( |
575 | dst + (i + 0) * ld_dst, |
576 | mask_mrem_v, |
577 | _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x0))); |
578 | _mm512_mask_storeu_epi16( |
579 | dst + (i + 1) * ld_dst, |
580 | mask_mrem_v, |
581 | _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x1))); |
582 | } |
583 | if (i < nrem) { |
584 | int reg_idx = i / 2; |
585 | _mm512_mask_storeu_epi16( |
586 | dst + (i + 0) * ld_dst, |
587 | mask_mrem_v, |
588 | _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x0))); |
589 | } |
590 | } else { |
591 | int i = 0; |
592 | for (; i < nrem / 2 * 2; i += 2) { |
593 | // normal store |
594 | int reg_idx = i / 2; |
595 | _mm256_storeu_si256( |
596 | reinterpret_cast<__m256i*>(dst + (i + 0) * ld_dst), |
597 | _mm512_extracti32x8_epi32(u[reg_idx], 0x0)); |
598 | _mm256_storeu_si256( |
599 | reinterpret_cast<__m256i*>(dst + (i + 1) * ld_dst), |
600 | _mm512_extracti32x8_epi32(u[reg_idx], 0x1)); |
601 | } |
602 | if (i < nrem) { |
603 | int reg_idx = i / 2; |
604 | _mm256_storeu_si256( |
605 | reinterpret_cast<__m256i*>(dst + (i + 0) * ld_dst), |
606 | _mm512_extracti32x8_epi32(u[reg_idx], 0x0)); |
607 | } |
608 | } |
609 | } |
610 | |
611 | static inline void store_with_remainders_i8( |
612 | uint8_t* dst, |
613 | int64_t ld_dst, |
614 | __m512i u[], |
615 | int mrem, |
616 | int nrem) { |
617 | if (mrem < 16) { |
618 | __mmask64 mask_mrem_v = (((long long)1) << mrem) - 1; |
619 | int i = 0; |
620 | for (; i < nrem / 4 * 4; i += 4) { |
621 | // mask store |
      // we need 0, 4, 8, 12 => 0, 2, 4, 6
623 | // and 16, 20, 24, 28 => 1, 3, 5, 7 |
624 | // See stores for non-rem case |
625 | int reg_idx = i / 16 + 2 * ((i % 16) / 4); |
626 | _mm512_mask_storeu_epi8( |
627 | dst + (i + 0) * ld_dst, |
628 | mask_mrem_v, |
629 | _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x0))); |
630 | _mm512_mask_storeu_epi8( |
631 | dst + (i + 1) * ld_dst, |
632 | mask_mrem_v, |
633 | _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x1))); |
634 | _mm512_mask_storeu_epi8( |
635 | dst + (i + 2) * ld_dst, |
636 | mask_mrem_v, |
637 | _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x2))); |
638 | _mm512_mask_storeu_epi8( |
639 | dst + (i + 3) * ld_dst, |
640 | mask_mrem_v, |
641 | _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x3))); |
642 | } |
643 | int rem = nrem - i; |
644 | int reg_rem_idx = i / 16 + 2 * ((i % 16) / 4); |
645 | switch (rem) { |
646 | case 1: |
647 | _mm512_mask_storeu_epi8( |
648 | dst + (i + 0) * ld_dst, |
649 | mask_mrem_v, |
650 | _mm512_castsi128_si512( |
651 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0))); |
652 | break; |
653 | case 2: |
654 | _mm512_mask_storeu_epi8( |
655 | dst + (i + 0) * ld_dst, |
656 | mask_mrem_v, |
657 | _mm512_castsi128_si512( |
658 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0))); |
659 | _mm512_mask_storeu_epi8( |
660 | dst + (i + 1) * ld_dst, |
661 | mask_mrem_v, |
662 | _mm512_castsi128_si512( |
663 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1))); |
664 | break; |
665 | case 3: |
666 | _mm512_mask_storeu_epi8( |
667 | dst + (i + 0) * ld_dst, |
668 | mask_mrem_v, |
669 | _mm512_castsi128_si512( |
670 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0))); |
671 | _mm512_mask_storeu_epi8( |
672 | dst + (i + 1) * ld_dst, |
673 | mask_mrem_v, |
674 | _mm512_castsi128_si512( |
675 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1))); |
676 | _mm512_mask_storeu_epi8( |
677 | dst + (i + 2) * ld_dst, |
678 | mask_mrem_v, |
679 | _mm512_castsi128_si512( |
680 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x2))); |
681 | break; |
682 | default: |
683 | break; |
684 | } |
685 | |
686 | } else { |
687 | int i = 0; |
688 | for (; i < nrem / 4 * 4; i += 4) { |
689 | // normal store |
690 | int reg_idx = i / 16 + 2 * ((i % 16) / 4); |
691 | _mm_storeu_si128( |
692 | reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst), |
693 | _mm512_extracti32x4_epi32(u[reg_idx], 0x0)); |
694 | _mm_storeu_si128( |
695 | reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst), |
696 | _mm512_extracti32x4_epi32(u[reg_idx], 0x1)); |
697 | _mm_storeu_si128( |
698 | reinterpret_cast<__m128i*>(dst + (i + 2) * ld_dst), |
699 | _mm512_extracti32x4_epi32(u[reg_idx], 0x2)); |
700 | _mm_storeu_si128( |
701 | reinterpret_cast<__m128i*>(dst + (i + 3) * ld_dst), |
702 | _mm512_extracti32x4_epi32(u[reg_idx], 0x3)); |
703 | } |
704 | int rem = nrem - i; |
705 | int reg_rem_idx = i / 16 + 2 * ((i % 16) / 4); |
706 | switch (rem) { |
707 | case 1: |
708 | _mm_storeu_si128( |
709 | reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst), |
710 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)); |
711 | break; |
712 | case 2: |
713 | _mm_storeu_si128( |
714 | reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst), |
715 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)); |
716 | _mm_storeu_si128( |
717 | reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst), |
718 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1)); |
719 | break; |
720 | case 3: |
721 | _mm_storeu_si128( |
722 | reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst), |
723 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)); |
724 | _mm_storeu_si128( |
725 | reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst), |
726 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1)); |
727 | _mm_storeu_si128( |
728 | reinterpret_cast<__m128i*>(dst + (i + 2) * ld_dst), |
729 | _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x2)); |
730 | break; |
731 | default: |
732 | break; |
733 | } |
734 | } |
735 | } |
736 | |
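// Transpose a 4 x nrem block of floats: the four source rows are ld_src
// apart, and the result is written to dst as nrem rows of 4 contiguous
// elements (effective destination leading dimension of 4).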
737 | static inline void transpose_contiguous_4x16_block( |
738 | const float* src, |
739 | float* dst, |
740 | int64_t ld_src, |
741 | int nrem = 16) { |
742 | __m512i r[4]; |
743 | // load |
744 | if (nrem < 16) { |
745 | __mmask16 mask_mrem_v = (((long long)1) << nrem) - 1; |
746 | r[0] = _mm512_maskz_loadu_epi32(mask_mrem_v, src); |
747 | r[1] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src); |
748 | r[2] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 2 * ld_src); |
749 | r[3] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 3 * ld_src); |
750 | |
751 | } else { |
752 | r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src)); |
753 | r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src)); |
754 | r[2] = |
755 | _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src)); |
756 | r[3] = |
757 | _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src)); |
758 | } |
759 | // transpose |
760 | // a0b0 a1b1 a4b4 a5b5 a8b8 a9b9 a12b12 a13b13 |
761 | // a2b2 a3b3 a6b6 a7b7 a10b10 a11b11 a14b14 a15b15 |
762 | // c0d0 c1d1 c4d4 c5d5 c8d8 c9d9 c12d12 c13d13 |
  // c2d2 c3d3 c6d6 c7d7 c10d10 c11d11 c14d14 c15d15
764 | __m512i t0 = _mm512_unpacklo_epi32(r[0], r[1]); |
765 | __m512i t1 = _mm512_unpackhi_epi32(r[0], r[1]); |
766 | __m512i t2 = _mm512_unpacklo_epi32(r[2], r[3]); |
767 | __m512i t3 = _mm512_unpackhi_epi32(r[2], r[3]); |
768 | |
769 | r[0] = _mm512_unpacklo_epi64(t0, t2); |
770 | r[1] = _mm512_unpackhi_epi64(t0, t2); |
771 | r[2] = _mm512_unpacklo_epi64(t1, t3); |
772 | r[3] = _mm512_unpackhi_epi64(t1, t3); |
773 | |
774 | t0 = _mm512_shuffle_i32x4(r[0], r[1], 0x44); |
775 | t1 = _mm512_shuffle_i32x4(r[0], r[1], 0xee); |
776 | t2 = _mm512_shuffle_i32x4(r[2], r[3], 0x44); |
777 | t3 = _mm512_shuffle_i32x4(r[2], r[3], 0xee); |
778 | |
779 | r[0] = _mm512_shuffle_i32x4(t0, t2, 0x88); |
780 | r[1] = _mm512_shuffle_i32x4(t0, t2, 0xdd); |
781 | r[2] = _mm512_shuffle_i32x4(t1, t3, 0x88); |
782 | r[3] = _mm512_shuffle_i32x4(t1, t3, 0xdd); |
783 | // store |
784 | int i = 0; |
785 | for (; (i + 1) * 16 <= nrem * 4; i++) { |
786 | // normal store |
787 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 16), r[i]); |
788 | } |
789 | int erem = nrem * 4 - i * 16; |
790 | if (erem > 0) { |
791 | // mask store |
792 | __mmask16 mask_rem_v = (((long long)1) << erem) - 1; |
793 | _mm512_mask_storeu_epi32(dst + i * 16, mask_rem_v, r[i]); |
794 | } |
795 | } |
796 | |
797 | static inline void transpose_contiguous_4x32_block( |
798 | const uint16_t* src, |
799 | uint16_t* dst, |
800 | int64_t ld_src, |
801 | int nrem = 32) { |
802 | __m512i r[4], d[4]; |
803 | // load |
804 | if (nrem < 32) { |
805 | __mmask32 mask_mrem_v = (((long long)1) << nrem) - 1; |
806 | r[0] = _mm512_maskz_loadu_epi16(mask_mrem_v, src); |
807 | r[1] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src); |
808 | r[2] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 2 * ld_src); |
809 | r[3] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 3 * ld_src); |
810 | } else { |
811 | r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src)); |
812 | r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src)); |
813 | r[2] = |
814 | _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src)); |
815 | r[3] = |
816 | _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src)); |
817 | } |
818 | // transpose |
819 | d[0] = _mm512_unpacklo_epi16(r[0], r[1]); |
820 | d[1] = _mm512_unpackhi_epi16(r[0], r[1]); |
821 | d[2] = _mm512_unpacklo_epi16(r[2], r[3]); |
822 | d[3] = _mm512_unpackhi_epi16(r[2], r[3]); |
823 | |
824 | r[0] = _mm512_unpacklo_epi32(d[0], d[2]); |
825 | r[1] = _mm512_unpackhi_epi32(d[0], d[2]); |
826 | r[2] = _mm512_unpacklo_epi32(d[1], d[3]); |
827 | r[3] = _mm512_unpackhi_epi32(d[1], d[3]); |
828 | |
829 | d[0] = _mm512_shuffle_i32x4(r[0], r[1], 0x44); |
830 | d[1] = _mm512_shuffle_i32x4(r[0], r[1], 0xee); |
831 | d[2] = _mm512_shuffle_i32x4(r[2], r[3], 0x44); |
832 | d[3] = _mm512_shuffle_i32x4(r[2], r[3], 0xee); |
833 | |
834 | r[0] = _mm512_shuffle_i32x4(d[0], d[2], 0x88); |
835 | r[1] = _mm512_shuffle_i32x4(d[0], d[2], 0xdd); |
836 | r[2] = _mm512_shuffle_i32x4(d[1], d[3], 0x88); |
837 | r[3] = _mm512_shuffle_i32x4(d[1], d[3], 0xdd); |
838 | // store |
839 | int i = 0; |
840 | for (; (i + 1) * 32 <= nrem * 4; i++) { |
841 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 32), r[i]); |
842 | } |
843 | int erem = nrem * 4 - i * 32; |
844 | if (erem > 0) { |
845 | // mask store |
846 | __mmask32 mask_rem_v = (((long long)1) << erem) - 1; |
847 | _mm512_mask_storeu_epi16(dst + i * 32, mask_rem_v, r[i]); |
848 | } |
849 | } |
850 | |
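// Transpose an mrem x 4 block of floats stored contiguously at src (leading
// dimension 4) into a 4 x mrem block at dst with leading dimension ld_dst.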
851 | static inline void transpose_contiguous_16x4_block( |
852 | const float* src, |
853 | float* dst, |
854 | int64_t ld_dst, |
855 | int mrem = 16) { |
856 | __m512i r[4], d[4]; |
857 | int i = 0; |
858 | for (; (i + 1) * 16 <= mrem * 4; i++) { |
859 | // normal load |
860 | r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 16)); |
861 | } |
862 | if (i * 16 < mrem * 4) { |
863 | __mmask16 mask_mrem_v = (((long long)1) << (mrem * 4 - i * 16)) - 1; |
864 | r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16); |
865 | } |
866 | |
867 | // transpose |
868 | __m512i index1 = _mm512_set_epi32( |
869 | 0x0f, |
870 | 0x0b, |
871 | 0x07, |
872 | 0x03, |
873 | 0x0e, |
874 | 0x0a, |
875 | 0x06, |
876 | 0x02, |
877 | 0x0d, |
878 | 0x09, |
879 | 0x05, |
880 | 0x01, |
881 | 0x0c, |
882 | 0x08, |
883 | 0x04, |
884 | 0x00); |
885 | d[0] = _mm512_permutexvar_epi32(index1, r[0]); |
886 | d[1] = _mm512_permutexvar_epi32(index1, r[1]); |
887 | d[2] = _mm512_permutexvar_epi32(index1, r[2]); |
888 | d[3] = _mm512_permutexvar_epi32(index1, r[3]); |
889 | |
890 | r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); |
891 | r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); |
892 | r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); |
893 | r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); |
894 | |
895 | d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); |
896 | d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); |
897 | d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); |
898 | d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); |
899 | |
900 | if (mrem < 16) { |
901 | // mask store |
902 | __mmask16 mask_rem_v = (((long long)1) << mrem) - 1; |
903 | _mm512_mask_storeu_epi32(dst + 0 * ld_dst, mask_rem_v, d[0]); |
904 | _mm512_mask_storeu_epi32(dst + 1 * ld_dst, mask_rem_v, d[1]); |
905 | _mm512_mask_storeu_epi32(dst + 2 * ld_dst, mask_rem_v, d[2]); |
906 | _mm512_mask_storeu_epi32(dst + 3 * ld_dst, mask_rem_v, d[3]); |
907 | } else { |
    // normal store
909 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]); |
910 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]); |
911 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]); |
912 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]); |
913 | } |
914 | } |
915 | |
916 | static inline void transpose_contiguous_16x2_block( |
917 | const float* src, |
918 | float* dst, |
919 | int64_t ld_dst, |
920 | int mrem = 16) { |
921 | __m512i r[2], d[2]; |
922 | int i = 0; |
923 | for (; (i + 1) * 16 <= mrem * 2; i++) { |
924 | // normal load |
925 | r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 16)); |
926 | } |
927 | if (i * 16 < mrem * 2) { |
928 | __mmask16 mask_mrem_v = (((long long)1) << (mrem * 2 - i * 16)) - 1; |
929 | r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16); |
930 | } |
931 | // transpose |
932 | __m512i index1 = _mm512_set_epi32( |
933 | 0x1e, |
934 | 0x1c, |
935 | 0x1a, |
936 | 0x18, |
937 | 0x16, |
938 | 0x14, |
939 | 0x12, |
940 | 0x10, |
941 | 0x0e, |
942 | 0x0c, |
943 | 0x0a, |
944 | 0x08, |
945 | 0x06, |
946 | 0x04, |
947 | 0x02, |
948 | 0x00); |
949 | __m512i index2 = _mm512_set_epi32( |
950 | 0x1f, |
951 | 0x1d, |
952 | 0x1b, |
953 | 0x19, |
954 | 0x17, |
955 | 0x15, |
956 | 0x13, |
957 | 0x11, |
958 | 0x0f, |
959 | 0x0d, |
960 | 0x0b, |
961 | 0x09, |
962 | 0x07, |
963 | 0x05, |
964 | 0x03, |
965 | 0x01); |
966 | |
967 | // a0--p0 |
968 | // a1--p1 |
969 | d[0] = _mm512_permutex2var_epi32(r[0], index1, r[1]); |
970 | d[1] = _mm512_permutex2var_epi32(r[0], index2, r[1]); |
971 | |
972 | // store |
973 | if (mrem < 16) { |
974 | __mmask16 mask_rem_v = (((long long)1) << mrem) - 1; |
975 | // mask store |
976 | _mm512_mask_storeu_epi32(dst, mask_rem_v, d[0]); |
977 | _mm512_mask_storeu_epi32(dst + ld_dst, mask_rem_v, d[1]); |
978 | } else { |
979 | // normal store |
980 | _mm512_storeu_si512(dst, d[0]); |
981 | _mm512_storeu_si512(dst + ld_dst, d[1]); |
982 | } |
983 | } |
984 | |
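// Transpose an mrem x 4 block of uint8_t stored contiguously at src (leading
// dimension 4) into a 4 x mrem block at dst with leading dimension ld_dst.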
985 | static inline void transpose_contiguous_64x4_block( |
986 | const uint8_t* src, |
987 | uint8_t* dst, |
988 | int64_t ld_dst, |
989 | int mrem = 64) { |
990 | __m512i r[4], d[4]; |
991 | // normal load |
992 | int i = 0; |
993 | for (; (i + 1) * 64 <= mrem * 4; i++) { |
994 | r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 64)); |
995 | } |
996 | int erem = mrem * 4 - i * 64; |
997 | if (erem > 0) { |
998 | __mmask64 mask_mrem_v = (((long long)1) << erem) - 1; |
999 | r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64); |
1000 | } |
1001 | |
1002 | // transpose |
1003 | __m512i index = _mm512_set_epi32( |
1004 | 0x0f0b0703, |
1005 | 0x0e0a0602, |
1006 | 0x0d090501, |
1007 | 0x0c080400, |
1008 | 0x0f0b0703, |
1009 | 0x0e0a0602, |
1010 | 0x0d090501, |
1011 | 0x0c080400, |
1012 | 0x0f0b0703, |
1013 | 0x0e0a0602, |
1014 | 0x0d090501, |
1015 | 0x0c080400, |
1016 | 0x0f0b0703, |
1017 | 0x0e0a0602, |
1018 | 0x0d090501, |
1019 | 0x0c080400); |
1020 | |
1021 | d[0] = _mm512_shuffle_epi8(r[0], index); |
1022 | d[1] = _mm512_shuffle_epi8(r[1], index); |
1023 | d[2] = _mm512_shuffle_epi8(r[2], index); |
1024 | d[3] = _mm512_shuffle_epi8(r[3], index); |
1025 | |
1026 | __m512i index2 = |
1027 | _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); |
1028 | r[0] = _mm512_permutexvar_epi32(index2, d[0]); |
1029 | r[1] = _mm512_permutexvar_epi32(index2, d[1]); |
1030 | r[2] = _mm512_permutexvar_epi32(index2, d[2]); |
1031 | r[3] = _mm512_permutexvar_epi32(index2, d[3]); |
1032 | |
1033 | __m512i t0 = _mm512_shuffle_i32x4(r[0], r[1], 0x44); |
1034 | __m512i t1 = _mm512_shuffle_i32x4(r[0], r[1], 0xee); |
1035 | __m512i t2 = _mm512_shuffle_i32x4(r[2], r[3], 0x44); |
1036 | __m512i t3 = _mm512_shuffle_i32x4(r[2], r[3], 0xee); |
1037 | |
1038 | d[0] = _mm512_shuffle_i32x4(t0, t2, 0x88); |
1039 | d[1] = _mm512_shuffle_i32x4(t0, t2, 0xdd); |
1040 | d[2] = _mm512_shuffle_i32x4(t1, t3, 0x88); |
1041 | d[3] = _mm512_shuffle_i32x4(t1, t3, 0xdd); |
1042 | |
1043 | // store |
1044 | if (mrem < 64) { |
1045 | __mmask64 mask_rem_v = (((long long)1) << mrem) - 1; |
1046 | // mask store |
1047 | _mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]); |
1048 | _mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]); |
1049 | _mm512_mask_storeu_epi8(dst + 2 * ld_dst, mask_rem_v, d[2]); |
1050 | _mm512_mask_storeu_epi8(dst + 3 * ld_dst, mask_rem_v, d[3]); |
1051 | } else { |
1052 | // normal store |
1053 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]); |
1054 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]); |
1055 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]); |
1056 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]); |
1057 | } |
1058 | } |
1059 | |
1060 | static inline void transpose_contiguous_32x4_block( |
1061 | const uint16_t* src, |
1062 | uint16_t* dst, |
1063 | int64_t ld_dst, |
1064 | int mrem = 32) { |
1065 | __m512i r[4], d[4]; |
1066 | int i = 0; |
1067 | for (; (i + 1) * 32 <= mrem * 4; i++) { |
1068 | // normal load |
1069 | r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 32)); |
1070 | } |
1071 | if (i * 32 < mrem * 4) { |
1072 | __mmask32 mask_mrem_v = (((long long)1) << (mrem * 4 - i * 32)) - 1; |
1073 | r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32); |
1074 | } |
1075 | // transpose |
1076 | __m512i index = _mm512_set_epi32( |
1077 | 0x001f001b, |
1078 | 0x00170013, |
1079 | 0x000f000b, |
1080 | 0x00070003, |
1081 | 0x001e001a, |
1082 | 0x00160012, |
1083 | 0x000e000a, |
1084 | 0x00060002, |
1085 | 0x001d0019, |
1086 | 0x00150011, |
1087 | 0x000d0009, |
1088 | 0x00050001, |
1089 | 0x001c0018, |
1090 | 0x00140010, |
1091 | 0x000c0008, |
1092 | 0x00040000); |
1093 | |
1094 | d[0] = _mm512_permutexvar_epi16(index, r[0]); |
1095 | d[1] = _mm512_permutexvar_epi16(index, r[1]); |
1096 | d[2] = _mm512_permutexvar_epi16(index, r[2]); |
1097 | d[3] = _mm512_permutexvar_epi16(index, r[3]); |
1098 | |
1099 | r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); |
1100 | r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); |
1101 | r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); |
1102 | r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); |
1103 | |
1104 | d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); |
1105 | d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); |
1106 | d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); |
1107 | d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); |
1108 | |
1109 | if (mrem < 32) { |
1110 | // mask store |
1111 | __mmask32 mask_rem_v = (((long long)1) << mrem) - 1; |
1112 | _mm512_mask_storeu_epi16(dst + 0 * ld_dst, mask_rem_v, d[0]); |
1113 | _mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, d[1]); |
1114 | _mm512_mask_storeu_epi16(dst + 2 * ld_dst, mask_rem_v, d[2]); |
1115 | _mm512_mask_storeu_epi16(dst + 3 * ld_dst, mask_rem_v, d[3]); |
1116 | } else { |
    // normal store
1118 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]); |
1119 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]); |
1120 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]); |
1121 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]); |
1122 | } |
1123 | } |
1124 | |
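// Transpose a 2 x nrem block of floats (the two source rows are ld_src apart)
// into nrem rows of 2 elements written contiguously at dst.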
1125 | static inline void transpose_contiguous_2x16_block( |
1126 | const float* src, |
1127 | float* dst, |
1128 | int64_t ld_src, |
1129 | int nrem = 16) { |
1130 | __m512i r0, r1; |
1131 | // load |
1132 | if (nrem < 16) { |
1133 | __mmask16 mask_mrem_v = (((long long)1) << nrem) - 1; |
1134 | r0 = _mm512_maskz_loadu_epi32(mask_mrem_v, src); |
1135 | r1 = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src); |
1136 | } else { |
1137 | r0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src)); |
1138 | r1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src)); |
1139 | } |
1140 | // transpose |
1141 | __m512i index1 = _mm512_set_epi32( |
1142 | 0x0017, |
1143 | 0x0007, |
1144 | 0x0016, |
1145 | 0x0006, |
1146 | 0x0015, |
1147 | 0x0005, |
1148 | 0x0014, |
1149 | 0x0004, |
1150 | 0x0013, |
1151 | 0x0003, |
1152 | 0x0012, |
1153 | 0x0002, |
1154 | 0x0011, |
1155 | 0x0001, |
1156 | 0x0010, |
1157 | 0x0000); |
1158 | __m512i index2 = _mm512_set_epi32( |
1159 | 0x001f, |
1160 | 0x000f, |
1161 | 0x001e, |
1162 | 0x000e, |
1163 | 0x001d, |
1164 | 0x000d, |
1165 | 0x001c, |
1166 | 0x000c, |
1167 | 0x001b, |
1168 | 0x000b, |
1169 | 0x001a, |
1170 | 0x000a, |
1171 | 0x0019, |
1172 | 0x0009, |
1173 | 0x0018, |
1174 | 0x0008); |
1175 | // a0 b0 a1 b1 a2 b2 a3 b3 a4 b4 a5 b5 a6 b6 a7 b7 |
1176 | // a8 b8 a9 b9 a10 b10 a11 b11 a12 b12 a13 b13 a14 b14 a15 b15 |
1177 | __m512i u0 = _mm512_permutex2var_epi32(r0, index1, r1); |
1178 | __m512i u1 = _mm512_permutex2var_epi32(r0, index2, r1); |
1179 | // store |
1180 | if (nrem < 16) { |
1181 | // mask store |
1182 | if (nrem < 8) { |
1183 | __mmask16 mask_rem_v = (((long long)1) << (nrem * 2)) - 1; |
1184 | _mm512_mask_storeu_epi32(dst, mask_rem_v, u0); |
1185 | } else { |
1186 | __mmask16 mask_rem_v = (((long long)1) << ((nrem - 8) * 2)) - 1; |
1187 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0); |
1188 | _mm512_mask_storeu_epi32(dst + 16, mask_rem_v, u1); |
1189 | } |
1190 | } else { |
1191 | // normal store |
1192 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0); |
1193 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 16), u1); |
1194 | } |
1195 | } |
1196 | |
1197 | static inline void transpose_contiguous_64x2_block( |
1198 | const uint8_t* src, |
1199 | uint8_t* dst, |
1200 | int64_t ld_dst, |
1201 | int mrem = 64) { |
1202 | __m512i r[2], d[2]; |
1203 | // normal load |
1204 | int i = 0; |
1205 | for (; (i + 1) * 64 <= mrem * 2; i++) { |
1206 | r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 64)); |
1207 | } |
1208 | int erem = mrem * 2 - i * 64; |
1209 | if (erem > 0) { |
1210 | __mmask64 mask_mrem_v = (((long long)1) << erem) - 1; |
1211 | r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64); |
1212 | } |
1213 | |
1214 | // transpose |
1215 | __m512i index1 = _mm512_set_epi32( |
1216 | 0x0f0d0b09, |
1217 | 0x07050301, |
1218 | 0x0e0c0a08, |
1219 | 0x06040200, |
1220 | 0x0f0d0b09, |
1221 | 0x07050301, |
1222 | 0x0e0c0a08, |
1223 | 0x06040200, |
1224 | 0x0f0d0b09, |
1225 | 0x07050301, |
1226 | 0x0e0c0a08, |
1227 | 0x06040200, |
1228 | 0x0f0d0b09, |
1229 | 0x07050301, |
1230 | 0x0e0c0a08, |
1231 | 0x06040200); |
1232 | r[0] = _mm512_shuffle_epi8(r[0], index1); |
1233 | r[1] = _mm512_shuffle_epi8(r[1], index1); |
1234 | |
1235 | __m512i index2 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); |
1236 | __m512i index3 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); |
1237 | |
1238 | d[0] = _mm512_permutex2var_epi64(r[0], index2, r[1]); |
1239 | d[1] = _mm512_permutex2var_epi64(r[0], index3, r[1]); |
1240 | |
1241 | // store |
1242 | if (mrem < 64) { |
1243 | __mmask64 mask_rem_v = (((long long)1) << mrem) - 1; |
1244 | // mask store |
1245 | _mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]); |
1246 | _mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]); |
1247 | } else { |
1248 | // normal store |
1249 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d[0]); |
1250 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + ld_dst), d[1]); |
1251 | } |
1252 | } |
1253 | |
1254 | static inline void transpose_contiguous_4x64_block( |
1255 | const uint8_t* src, |
1256 | uint8_t* dst, |
1257 | int64_t ld_src, |
1258 | int nrem = 64) { |
1259 | __m512i r[4], d[4]; |
1260 | // load |
1261 | if (nrem < 64) { |
1262 | __mmask64 mask_mrem_v = (((long long)1) << nrem) - 1; |
1263 | r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src); |
1264 | r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src); |
1265 | r[2] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 2 * ld_src); |
1266 | r[3] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 3 * ld_src); |
1267 | } else { |
    r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
    r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
    r[2] =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src));
    r[3] =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src));
1274 | } |
1275 | // transpose |
1276 | d[0] = _mm512_unpacklo_epi32(r[0], r[1]); |
1277 | d[1] = _mm512_unpackhi_epi32(r[0], r[1]); |
1278 | d[2] = _mm512_unpacklo_epi32(r[2], r[3]); |
1279 | d[3] = _mm512_unpackhi_epi32(r[2], r[3]); |
1280 | |
1281 | r[0] = _mm512_unpacklo_epi64(d[0], d[2]); |
1282 | r[1] = _mm512_unpackhi_epi64(d[0], d[2]); |
1283 | r[2] = _mm512_unpacklo_epi64(d[1], d[3]); |
1284 | r[3] = _mm512_unpackhi_epi64(d[1], d[3]); |
1285 | |
1286 | d[0] = _mm512_shuffle_i32x4(r[0], r[1], 0x44); |
1287 | d[1] = _mm512_shuffle_i32x4(r[0], r[1], 0xee); |
1288 | d[2] = _mm512_shuffle_i32x4(r[2], r[3], 0x44); |
1289 | d[3] = _mm512_shuffle_i32x4(r[2], r[3], 0xee); |
1290 | |
1291 | r[0] = _mm512_shuffle_i32x4(d[0], d[2], 0x88); |
1292 | r[1] = _mm512_shuffle_i32x4(d[0], d[2], 0xdd); |
1293 | r[2] = _mm512_shuffle_i32x4(d[1], d[3], 0x88); |
1294 | r[3] = _mm512_shuffle_i32x4(d[1], d[3], 0xdd); |
1295 | |
1296 | __m512i index = _mm512_set_epi32( |
1297 | 0x0f0b0703, |
1298 | 0x0e0a0602, |
1299 | 0x0d090501, |
1300 | 0x0c080400, |
1301 | 0x0f0b0703, |
1302 | 0x0e0a0602, |
1303 | 0x0d090501, |
1304 | 0x0c080400, |
1305 | 0x0f0b0703, |
1306 | 0x0e0a0602, |
1307 | 0x0d090501, |
1308 | 0x0c080400, |
1309 | 0x0f0b0703, |
1310 | 0x0e0a0602, |
1311 | 0x0d090501, |
1312 | 0x0c080400); |
1313 | |
1314 | d[0] = _mm512_shuffle_epi8(r[0], index); |
1315 | d[1] = _mm512_shuffle_epi8(r[1], index); |
1316 | d[2] = _mm512_shuffle_epi8(r[2], index); |
1317 | d[3] = _mm512_shuffle_epi8(r[3], index); |
1318 | |
1319 | // store |
1320 | int i = 0; |
1321 | for (; (i + 1) * 64 <= nrem * 4; i++) { |
1322 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 64), d[i]); |
1323 | } |
1324 | int erem = nrem * 4 - i * 64; |
1325 | if (erem > 0) { |
1326 | __mmask64 mask_rem_v = (((long long)1) << erem) - 1; |
1327 | _mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]); |
1328 | } |
1329 | } |
1330 | |
1331 | static inline void transpose_contiguous_2x64_block( |
1332 | const uint8_t* src, |
1333 | uint8_t* dst, |
1334 | int64_t ld_src, |
1335 | int nrem = 64) { |
1336 | __m512i r[2]; |
1337 | __m512i d[2]; |
1338 | // load |
1339 | if (nrem < 64) { |
1340 | __mmask64 mask_mrem_v = (((long long)1) << nrem) - 1; |
1341 | r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src); |
1342 | r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src); |
1343 | } else { |
1344 | r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src)); |
1345 | r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src)); |
1346 | } |
1347 | // transpose |
1348 | // _mm512_mask_blend_epi8(0xaaaaaaaaaaaaaaaa, r0, r1); |
1349 | d[0] = _mm512_unpacklo_epi16(r[0], r[1]); |
1350 | d[1] = _mm512_unpackhi_epi16(r[0], r[1]); |
1351 | __m512i index1 = _mm512_set_epi32( |
1352 | 0x0f0d0e0c, |
1353 | 0x0b090a08, |
1354 | 0x07050604, |
1355 | 0x03010200, |
1356 | 0x0f0d0e0c, |
1357 | 0x0b090a08, |
1358 | 0x07050604, |
1359 | 0x03010200, |
1360 | 0x0f0d0e0c, |
1361 | 0x0b090a08, |
1362 | 0x07050604, |
1363 | 0x03010200, |
1364 | 0x0f0d0e0c, |
1365 | 0x0b090a08, |
1366 | 0x07050604, |
1367 | 0x03010200); |
1368 | r[0] = _mm512_shuffle_epi8(d[0], index1); |
1369 | r[1] = _mm512_shuffle_epi8(d[1], index1); |
1370 | __m512i index2 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); |
1371 | __m512i index3 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); |
1372 | // a0b0 a1b1 ... a31b31 |
1373 | // a32b32 ... a63b63 |
1374 | d[0] = _mm512_permutex2var_epi64(r[0], index2, r[1]); |
1375 | d[1] = _mm512_permutex2var_epi64(r[0], index3, r[1]); |
1376 | |
1377 | int i = 0; |
1378 | for (; (i + 1) * 64 <= nrem * 2; i++) { |
1379 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 64), d[i]); |
1380 | } |
1381 | int erem = nrem * 2 - i * 64; |
1382 | if (erem > 0) { |
1383 | __mmask64 mask_rem_v = (((long long)1) << erem) - 1; |
1384 | _mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]); |
1385 | } |
1386 | } |
1387 | |
1388 | static inline void transpose_contiguous_2x32_block( |
1389 | const uint16_t* src, |
1390 | uint16_t* dst, |
1391 | int64_t ld_src, |
1392 | int nrem = 32) { |
1393 | __m512i r0, r1; |
1394 | __m512i d0, d1; |
1395 | // load |
1396 | if (nrem < 32) { |
1397 | __mmask32 mask_mrem_v = (((long long)1) << nrem) - 1; |
1398 | r0 = _mm512_maskz_loadu_epi16(mask_mrem_v, src); |
1399 | r1 = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src); |
1400 | } else { |
1401 | r0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src)); |
1402 | r1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src)); |
1403 | } |
1404 | // transpose |
1405 | d0 = _mm512_unpacklo_epi16(r0, r1); |
1406 | d1 = _mm512_unpackhi_epi16(r0, r1); |
1407 | r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); |
1408 | r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); |
1409 | d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); |
1410 | d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); |
1411 | |
1412 | // store |
1413 | if (nrem < 16) { |
1414 | __mmask32 mask_rem_v = (((long long)1) << (nrem * 2)) - 1; |
1415 | _mm512_mask_storeu_epi16(dst, mask_rem_v, d0); |
1416 | } else if (nrem == 16) { |
1417 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); |
1418 | } else if (nrem < 32) { |
1419 | __mmask32 mask_rem_v = (((long long)1) << (nrem * 2 - 32)) - 1; |
    // d0 (columns 0-15) is stored in full; the remaining columns come from d1
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
    _mm512_mask_storeu_epi16(dst + 32, mask_rem_v, d1);
1424 | } else { |
1425 | // normal store |
1426 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); |
1427 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1); |
1428 | } |
1429 | } |
1430 | |
1431 | static inline void transpose_contiguous_32x2_block( |
1432 | const uint16_t* src, |
1433 | uint16_t* dst, |
1434 | int64_t ld_dst, |
1435 | int mrem = 32) { |
1436 | __m512i r[2], d[2]; |
1437 | // load |
1438 | int i = 0; |
1439 | for (; (i + 1) * 32 <= mrem * 2; i++) { |
1440 | r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 32)); |
1441 | } |
1442 | int erem = mrem * 2 - i * 32; |
1443 | if (erem > 0) { |
1444 | __mmask32 mask_mrem_v = (((long long)1) << erem) - 1; |
1445 | r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32); |
1446 | } |
1447 | // transpose |
1448 | __m512i index = _mm512_set_epi32( |
1449 | 0x001f001d, |
1450 | 0x001b0019, |
1451 | 0x00170015, |
1452 | 0x00130011, |
1453 | 0x000f000d, |
1454 | 0x000b0009, |
1455 | 0x00070005, |
1456 | 0x00030001, |
1457 | 0x001e001c, |
1458 | 0x001a0018, |
1459 | 0x00160014, |
1460 | 0x00120010, |
1461 | 0x000e000c, |
1462 | 0x000a0008, |
1463 | 0x00060004, |
1464 | 0x00020000); |
1465 | d[0] = _mm512_permutexvar_epi16(index, r[0]); |
1466 | d[1] = _mm512_permutexvar_epi16(index, r[1]); |
1467 | r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); |
1468 | r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); |
1469 | |
1470 | // store |
1471 | if (mrem < 32) { |
1472 | __mmask32 mask_rem_v = (((long long)1) << mrem) - 1; |
1473 | // mask store |
1474 | _mm512_mask_storeu_epi16(dst, mask_rem_v, r[0]); |
1475 | _mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, r[1]); |
1476 | } else { |
1477 | // normal store |
1478 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r[0]); |
1479 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + ld_dst), r[1]); |
1480 | } |
1481 | } |
1482 | |
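// Transpose a 16 x 16 block of uint16_t. When MREM / NREM are set, an
// mrem x nrem partial block is handled with masked loads and stores.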
1483 | template <bool MREM = false, bool NREM = false> |
1484 | void transpose_16x16_block( |
1485 | const uint16_t* src, |
1486 | int64_t ld_src, |
1487 | uint16_t* dst, |
1488 | int64_t ld_dst, |
1489 | int mrem = 16, |
1490 | int nrem = 16) { |
1491 | __m512i r[8]; |
1492 | if (MREM || NREM) { |
1493 | load_with_remainders_i16(src, ld_src, r, mrem, nrem); |
1494 | } else { |
1495 | __m256i t00 = |
1496 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 0 * ld_src)); |
1497 | __m256i t01 = |
1498 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 1 * ld_src)); |
1499 | __m256i t02 = |
1500 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 2 * ld_src)); |
1501 | __m256i t03 = |
1502 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 3 * ld_src)); |
1503 | __m256i t04 = |
1504 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 4 * ld_src)); |
1505 | __m256i t05 = |
1506 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 5 * ld_src)); |
1507 | __m256i t06 = |
1508 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 6 * ld_src)); |
1509 | __m256i t07 = |
1510 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 7 * ld_src)); |
1511 | __m256i t08 = |
1512 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 8 * ld_src)); |
1513 | __m256i t09 = |
1514 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 9 * ld_src)); |
1515 | __m256i t10 = |
1516 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 10 * ld_src)); |
1517 | __m256i t11 = |
1518 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 11 * ld_src)); |
1519 | __m256i t12 = |
1520 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 12 * ld_src)); |
1521 | __m256i t13 = |
1522 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 13 * ld_src)); |
1523 | __m256i t14 = |
1524 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 14 * ld_src)); |
1525 | __m256i t15 = |
1526 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 15 * ld_src)); |
1527 | |
1528 | // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15 |
1529 | // e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15 |
1530 | r[0] = _mm512_inserti64x4(_mm512_castsi256_si512(t00), t04, 0x01); |
1531 | // b0-b15 |
1532 | // f0-f15 |
1533 | r[1] = _mm512_inserti64x4(_mm512_castsi256_si512(t01), t05, 0x01); |
1534 | // c0-c15 |
1535 | // g0-g15 |
1536 | r[2] = _mm512_inserti64x4(_mm512_castsi256_si512(t02), t06, 0x01); |
1537 | // d0-d15 |
    // h0-h15
1539 | r[3] = _mm512_inserti64x4(_mm512_castsi256_si512(t03), t07, 0x01); |
1540 | // i0-i15 |
1541 | // m0-m15 |
1542 | r[4] = _mm512_inserti64x4(_mm512_castsi256_si512(t08), t12, 0x01); |
1543 | // j0-j15 |
1544 | // n0-n15 |
1545 | r[5] = _mm512_inserti64x4(_mm512_castsi256_si512(t09), t13, 0x01); |
1546 | // k0-k15 |
1547 | // o0-o15 |
1548 | r[6] = _mm512_inserti64x4(_mm512_castsi256_si512(t10), t14, 0x01); |
1549 | // l0-l15 |
1550 | // p0-p15 |
1551 | r[7] = _mm512_inserti64x4(_mm512_castsi256_si512(t11), t15, 0x01); |
1552 | } |
1553 | __m512i u[8]; |
1554 | core_transpose_16x16_block(r, u); |
1555 | if (MREM || NREM) { |
1556 | store_with_remainders_i16(dst, ld_dst, u, mrem, nrem); |
1557 | } else { |
1558 | _mm256_storeu_si256( |
1559 | reinterpret_cast<__m256i*>(dst + 0 * ld_dst), |
1560 | _mm512_extracti32x8_epi32(u[0], 0x0)); |
1561 | _mm256_storeu_si256( |
1562 | reinterpret_cast<__m256i*>(dst + 1 * ld_dst), |
1563 | _mm512_extracti32x8_epi32(u[0], 0x01)); |
1564 | _mm256_storeu_si256( |
1565 | reinterpret_cast<__m256i*>(dst + 2 * ld_dst), |
1566 | _mm512_extracti32x8_epi32(u[1], 0x0)); |
1567 | _mm256_storeu_si256( |
1568 | reinterpret_cast<__m256i*>(dst + 3 * ld_dst), |
1569 | _mm512_extracti32x8_epi32(u[1], 0x01)); |
1570 | _mm256_storeu_si256( |
1571 | reinterpret_cast<__m256i*>(dst + 4 * ld_dst), |
1572 | _mm512_extracti32x8_epi32(u[2], 0x0)); |
1573 | _mm256_storeu_si256( |
1574 | reinterpret_cast<__m256i*>(dst + 5 * ld_dst), |
1575 | _mm512_extracti32x8_epi32(u[2], 0x01)); |
1576 | _mm256_storeu_si256( |
1577 | reinterpret_cast<__m256i*>(dst + 6 * ld_dst), |
1578 | _mm512_extracti32x8_epi32(u[3], 0x0)); |
1579 | _mm256_storeu_si256( |
1580 | reinterpret_cast<__m256i*>(dst + 7 * ld_dst), |
1581 | _mm512_extracti32x8_epi32(u[3], 0x01)); |
1582 | _mm256_storeu_si256( |
1583 | reinterpret_cast<__m256i*>(dst + 8 * ld_dst), |
1584 | _mm512_extracti32x8_epi32(u[4], 0x0)); |
1585 | _mm256_storeu_si256( |
1586 | reinterpret_cast<__m256i*>(dst + 9 * ld_dst), |
1587 | _mm512_extracti32x8_epi32(u[4], 0x01)); |
1588 | _mm256_storeu_si256( |
1589 | reinterpret_cast<__m256i*>(dst + 10 * ld_dst), |
1590 | _mm512_extracti32x8_epi32(u[5], 0x0)); |
1591 | _mm256_storeu_si256( |
1592 | reinterpret_cast<__m256i*>(dst + 11 * ld_dst), |
1593 | _mm512_extracti32x8_epi32(u[5], 0x01)); |
1594 | _mm256_storeu_si256( |
1595 | reinterpret_cast<__m256i*>(dst + 12 * ld_dst), |
1596 | _mm512_extracti32x8_epi32(u[6], 0x0)); |
1597 | _mm256_storeu_si256( |
1598 | reinterpret_cast<__m256i*>(dst + 13 * ld_dst), |
1599 | _mm512_extracti32x8_epi32(u[6], 0x01)); |
1600 | _mm256_storeu_si256( |
1601 | reinterpret_cast<__m256i*>(dst + 14 * ld_dst), |
1602 | _mm512_extracti32x8_epi32(u[7], 0x0)); |
1603 | _mm256_storeu_si256( |
1604 | reinterpret_cast<__m256i*>(dst + 15 * ld_dst), |
1605 | _mm512_extracti32x8_epi32(u[7], 0x01)); |
1606 | } |
1607 | } |
1608 | |
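// Transpose one 16x32 uint8_t tile from src (leading dimension ld_src) into
// dst (leading dimension ld_dst). With MREM/NREM set, mrem/nrem give the
// number of valid rows/columns and the masked load/store helpers handle the
// partial tile.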
1609 | template <bool MREM = false, bool NREM = false> |
1610 | void transpose_16x32_block( |
1611 | const uint8_t* src, |
1612 | int64_t ld_src, |
1613 | uint8_t* dst, |
1614 | int64_t ld_dst, |
1615 | int mrem = 16, |
1616 | int nrem = 32) { |
  // Treat the numbers in a row as 4-byte integers.
  // Thus 03_04 is the 4-byte element in row 03 and column 04.
1619 | // |
1620 | // 00_00 00_01 00_02 00_03 00_04 00_05 00_06 00_07 |
1621 | // 01_00 01_01 01_02 01_03 01_04 01_05 01_06 01_07 |
1622 | // 02_00 02_01 02_02 02_03 02_04 02_05 02_06 02_07 |
1623 | // 03_00 03_01 03_02 03_03 03_04 03_05 03_06 03_07 |
1624 | // 04_00 04_01 04_02 04_03 04_04 04_05 04_06 04_07 |
1625 | // 05_00 05_01 05_02 05_03 05_04 05_05 05_06 05_07 |
1626 | // 06_00 06_01 06_02 06_03 06_04 06_05 06_06 06_07 |
1627 | // 07_00 07_01 07_02 07_03 07_04 07_05 07_06 07_07 |
1628 | // 08_00 08_01 08_02 08_03 08_04 08_05 08_06 08_07 |
1629 | // 09_00 09_01 09_02 09_03 09_04 09_05 09_06 09_07 |
1630 | // 10_00 10_01 10_02 10_03 10_04 10_05 10_06 10_07 |
1631 | // 11_00 11_01 11_02 11_03 11_04 11_05 11_06 11_07 |
1632 | // 12_00 12_01 12_02 12_03 12_04 12_05 12_06 12_07 |
1633 | // 13_00 13_01 13_02 13_03 13_04 13_05 13_06 13_07 |
1634 | // 14_00 14_01 14_02 14_03 14_04 14_05 14_06 14_07 |
1635 | // 15_00 15_01 15_02 15_03 15_04 15_05 15_06 15_07 |
1636 | |
1637 | __m512i r[8]; |
1638 | if (MREM || NREM) { |
1639 | load_with_remainders_i8(src, ld_src, r, mrem, nrem); |
1640 | } else { |
1641 | __m256i t00 = |
1642 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 0 * ld_src)); |
1643 | __m256i t04 = |
1644 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 4 * ld_src)); |
1645 | // 00_00 00_01 00_02 00_03 00_04 00_05 00_06 00_07 04_00 04_01 04_02 04_03 |
1646 | // 04_04 04_05 04_06 04_07 |
1647 | r[0] = _mm512_inserti64x4(_mm512_castsi256_si512(t00), t04, 0x01); |
1648 | |
1649 | __m256i t01 = |
1650 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 1 * ld_src)); |
1651 | __m256i t05 = |
1652 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 5 * ld_src)); |
1653 | // 01_00 01_01 01_02 01_03 01_04 01_05 01_06 01_07 05_00 05_01 05_02 05_03 |
1654 | // 05_04 05_05 05_06 05_07 |
1655 | r[1] = _mm512_inserti64x4(_mm512_castsi256_si512(t01), t05, 0x01); |
1656 | |
1657 | __m256i t02 = |
1658 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 2 * ld_src)); |
1659 | __m256i t06 = |
1660 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 6 * ld_src)); |
1661 | // 02_00 02_01 02_02 02_03 02_04 02_05 02_06 02_07 06_00 06_01 06_02 06_03 |
1662 | // 06_04 06_05 06_06 06_07 |
1663 | r[2] = _mm512_inserti64x4(_mm512_castsi256_si512(t02), t06, 0x01); |
1664 | |
1665 | __m256i t03 = |
1666 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 3 * ld_src)); |
1667 | __m256i t07 = |
1668 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 7 * ld_src)); |
1669 | // 03_00 03_01 03_02 03_03 03_04 03_05 03_06 03_07 07_00 07_01 07_02 07_03 |
1670 | // 07_04 07_05 07_06 07_07 |
1671 | r[3] = _mm512_inserti64x4(_mm512_castsi256_si512(t03), t07, 0x01); |
1672 | |
1673 | __m256i t08 = |
1674 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 8 * ld_src)); |
1675 | __m256i t12 = |
1676 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 12 * ld_src)); |
1677 | // 08_00 08_01 08_02 08_03 08_04 08_05 08_06 08_07 12_00 12_01 12_02 12_03 |
1678 | // 12_04 12_05 12_06 12_07 |
1679 | r[4] = _mm512_inserti64x4(_mm512_castsi256_si512(t08), t12, 0x01); |
1680 | |
1681 | __m256i t09 = |
1682 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 9 * ld_src)); |
1683 | __m256i t13 = |
1684 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 13 * ld_src)); |
1685 | // 09_00 09_01 09_02 09_03 09_04 09_05 09_06 09_07 13_00 13_01 13_02 13_03 |
1686 | // 13_04 13_05 13_06 13_07 |
1687 | r[5] = _mm512_inserti64x4(_mm512_castsi256_si512(t09), t13, 0x01); |
1688 | |
1689 | __m256i t10 = |
1690 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 10 * ld_src)); |
1691 | __m256i t14 = |
1692 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 14 * ld_src)); |
1693 | // 10_00 10_01 10_02 10_03 10_04 10_05 10_06 10_07 14_00 14_01 14_02 14_03 |
1694 | // 14_04 14_05 14_06 14_07 |
1695 | r[6] = _mm512_inserti64x4(_mm512_castsi256_si512(t10), t14, 0x01); |
1696 | |
1697 | __m256i t11 = |
1698 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 11 * ld_src)); |
1699 | __m256i t15 = |
1700 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 15 * ld_src)); |
1701 | // 11_00 11_01 11_02 11_03 11_04 11_05 11_06 11_07 15_00 15_01 15_02 15_03 |
1702 | // 15_04 15_05 15_06 15_07 |
1703 | r[7] = _mm512_inserti64x4(_mm512_castsi256_si512(t11), t15, 0x01); |
1704 | } |
1705 | |
1706 | __m512i u[8]; |
1707 | core_transpose_16x32_block_i8(r, u); |
1708 | |
1709 | if (MREM || NREM) { |
1710 | store_with_remainders_i8(dst, ld_dst, u, mrem, nrem); |
1711 | } else { |
1712 | _mm_storeu_si128( |
1713 | reinterpret_cast<__m128i*>(dst + 0 * ld_dst), |
1714 | _mm512_extracti32x4_epi32(u[0], 0x0)); |
1715 | _mm_storeu_si128( |
1716 | reinterpret_cast<__m128i*>(dst + 1 * ld_dst), |
1717 | _mm512_extracti32x4_epi32(u[0], 0x1)); |
1718 | _mm_storeu_si128( |
1719 | reinterpret_cast<__m128i*>(dst + 2 * ld_dst), |
1720 | _mm512_extracti32x4_epi32(u[0], 0x2)); |
1721 | _mm_storeu_si128( |
1722 | reinterpret_cast<__m128i*>(dst + 3 * ld_dst), |
1723 | _mm512_extracti32x4_epi32(u[0], 0x3)); |
1724 | _mm_storeu_si128( |
1725 | reinterpret_cast<__m128i*>(dst + 16 * ld_dst), |
1726 | _mm512_extracti32x4_epi32(u[1], 0x0)); |
1727 | _mm_storeu_si128( |
1728 | reinterpret_cast<__m128i*>(dst + 17 * ld_dst), |
1729 | _mm512_extracti32x4_epi32(u[1], 0x1)); |
1730 | _mm_storeu_si128( |
1731 | reinterpret_cast<__m128i*>(dst + 18 * ld_dst), |
1732 | _mm512_extracti32x4_epi32(u[1], 0x2)); |
1733 | _mm_storeu_si128( |
1734 | reinterpret_cast<__m128i*>(dst + 19 * ld_dst), |
1735 | _mm512_extracti32x4_epi32(u[1], 0x3)); |
1736 | _mm_storeu_si128( |
1737 | reinterpret_cast<__m128i*>(dst + 4 * ld_dst), |
1738 | _mm512_extracti32x4_epi32(u[2], 0x0)); |
1739 | _mm_storeu_si128( |
1740 | reinterpret_cast<__m128i*>(dst + 5 * ld_dst), |
1741 | _mm512_extracti32x4_epi32(u[2], 0x1)); |
1742 | _mm_storeu_si128( |
1743 | reinterpret_cast<__m128i*>(dst + 6 * ld_dst), |
1744 | _mm512_extracti32x4_epi32(u[2], 0x2)); |
1745 | _mm_storeu_si128( |
1746 | reinterpret_cast<__m128i*>(dst + 7 * ld_dst), |
1747 | _mm512_extracti32x4_epi32(u[2], 0x3)); |
1748 | _mm_storeu_si128( |
1749 | reinterpret_cast<__m128i*>(dst + 20 * ld_dst), |
1750 | _mm512_extracti32x4_epi32(u[3], 0x0)); |
1751 | _mm_storeu_si128( |
1752 | reinterpret_cast<__m128i*>(dst + 21 * ld_dst), |
1753 | _mm512_extracti32x4_epi32(u[3], 0x1)); |
1754 | _mm_storeu_si128( |
1755 | reinterpret_cast<__m128i*>(dst + 22 * ld_dst), |
1756 | _mm512_extracti32x4_epi32(u[3], 0x2)); |
1757 | _mm_storeu_si128( |
1758 | reinterpret_cast<__m128i*>(dst + 23 * ld_dst), |
1759 | _mm512_extracti32x4_epi32(u[3], 0x3)); |
1760 | _mm_storeu_si128( |
1761 | reinterpret_cast<__m128i*>(dst + 8 * ld_dst), |
1762 | _mm512_extracti32x4_epi32(u[4], 0x0)); |
1763 | _mm_storeu_si128( |
1764 | reinterpret_cast<__m128i*>(dst + 9 * ld_dst), |
1765 | _mm512_extracti32x4_epi32(u[4], 0x1)); |
1766 | _mm_storeu_si128( |
1767 | reinterpret_cast<__m128i*>(dst + 10 * ld_dst), |
1768 | _mm512_extracti32x4_epi32(u[4], 0x2)); |
1769 | _mm_storeu_si128( |
1770 | reinterpret_cast<__m128i*>(dst + 11 * ld_dst), |
1771 | _mm512_extracti32x4_epi32(u[4], 0x3)); |
1772 | _mm_storeu_si128( |
1773 | reinterpret_cast<__m128i*>(dst + 24 * ld_dst), |
1774 | _mm512_extracti32x4_epi32(u[5], 0x0)); |
1775 | _mm_storeu_si128( |
1776 | reinterpret_cast<__m128i*>(dst + 25 * ld_dst), |
1777 | _mm512_extracti32x4_epi32(u[5], 0x1)); |
1778 | _mm_storeu_si128( |
1779 | reinterpret_cast<__m128i*>(dst + 26 * ld_dst), |
1780 | _mm512_extracti32x4_epi32(u[5], 0x2)); |
1781 | _mm_storeu_si128( |
1782 | reinterpret_cast<__m128i*>(dst + 27 * ld_dst), |
1783 | _mm512_extracti32x4_epi32(u[5], 0x3)); |
1784 | _mm_storeu_si128( |
1785 | reinterpret_cast<__m128i*>(dst + 12 * ld_dst), |
1786 | _mm512_extracti32x4_epi32(u[6], 0x0)); |
1787 | _mm_storeu_si128( |
1788 | reinterpret_cast<__m128i*>(dst + 13 * ld_dst), |
1789 | _mm512_extracti32x4_epi32(u[6], 0x1)); |
1790 | _mm_storeu_si128( |
1791 | reinterpret_cast<__m128i*>(dst + 14 * ld_dst), |
1792 | _mm512_extracti32x4_epi32(u[6], 0x2)); |
1793 | _mm_storeu_si128( |
1794 | reinterpret_cast<__m128i*>(dst + 15 * ld_dst), |
1795 | _mm512_extracti32x4_epi32(u[6], 0x3)); |
1796 | _mm_storeu_si128( |
1797 | reinterpret_cast<__m128i*>(dst + 28 * ld_dst), |
1798 | _mm512_extracti32x4_epi32(u[7], 0x0)); |
1799 | _mm_storeu_si128( |
1800 | reinterpret_cast<__m128i*>(dst + 29 * ld_dst), |
1801 | _mm512_extracti32x4_epi32(u[7], 0x1)); |
1802 | _mm_storeu_si128( |
1803 | reinterpret_cast<__m128i*>(dst + 30 * ld_dst), |
1804 | _mm512_extracti32x4_epi32(u[7], 0x2)); |
1805 | _mm_storeu_si128( |
1806 | reinterpret_cast<__m128i*>(dst + 31 * ld_dst), |
1807 | _mm512_extracti32x4_epi32(u[7], 0x3)); |
1808 | } |
1809 | } |
1810 | |
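// Fast path used when the float source is fully contiguous (N == ld_src) and
// N is 2 or 4: walk 16 source rows per iteration and fall back to a masked
// block for the remaining M % 16 rows.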
1811 | template <> |
1812 | void transpose_avx512_contiguous_thin( |
1813 | int64_t M, |
1814 | int64_t N, |
1815 | const float* src, |
1816 | int64_t ld_src, |
1817 | float* dst, |
1818 | int64_t ld_dst) { |
1819 | if (N == 2) { |
1820 | int64_t i = 0; |
1821 | for (; i < M / 16 * 16; i += 16) { |
1822 | transpose_contiguous_16x2_block(src + i * ld_src, dst + i, ld_dst); |
1823 | } |
1824 | int mrem = M - i; |
1825 | if (mrem > 0) { |
1826 | transpose_contiguous_16x2_block(src + i * ld_src, dst + i, ld_dst, mrem); |
1827 | } |
1828 | } else if (N == 4) { |
1829 | int64_t i = 0; |
1830 | for (; i < M / 16 * 16; i += 16) { |
1831 | transpose_contiguous_16x4_block(src + i * ld_src, dst + i, ld_dst); |
1832 | } |
1833 | int mrem = M - i; |
1834 | if (mrem > 0) { |
1835 | transpose_contiguous_16x4_block(src + i * ld_src, dst + i, ld_dst, mrem); |
1836 | } |
1837 | } |
1838 | } |
1839 | |
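// uint16_t variant of the contiguous thin fast path: 32 rows per iteration.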
1840 | template <> |
1841 | void transpose_avx512_contiguous_thin( |
1842 | int64_t M, |
1843 | int64_t N, |
1844 | const uint16_t* src, |
1845 | int64_t ld_src, |
1846 | uint16_t* dst, |
1847 | int64_t ld_dst) { |
1848 | if (N == 2) { |
1849 | int64_t i = 0; |
1850 | for (; i < M / 32 * 32; i += 32) { |
1851 | transpose_contiguous_32x2_block(src + i * ld_src, dst + i, ld_dst); |
1852 | } |
1853 | int mrem = M - i; |
1854 | if (mrem > 0) { |
1855 | transpose_contiguous_32x2_block(src + i * ld_src, dst + i, ld_dst, mrem); |
1856 | } |
1857 | } else if (N == 4) { |
1858 | int64_t i = 0; |
1859 | for (; i < M / 32 * 32; i += 32) { |
1860 | transpose_contiguous_32x4_block(src + i * ld_src, dst + i, ld_dst); |
1861 | } |
1862 | int mrem = M - i; |
1863 | if (mrem > 0) { |
1864 | transpose_contiguous_32x4_block(src + i * ld_src, dst + i, ld_dst, mrem); |
1865 | } |
1866 | } |
1867 | } |
1868 | |
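// uint8_t variant of the contiguous thin fast path: 64 rows per iteration.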
1869 | template <> |
1870 | void transpose_avx512_contiguous_thin( |
1871 | int64_t M, |
1872 | int64_t N, |
1873 | const uint8_t* src, |
1874 | int64_t ld_src, |
1875 | uint8_t* dst, |
1876 | int64_t ld_dst) { |
1877 | if (N == 2) { |
1878 | int64_t i = 0; |
1879 | for (; i < M / 64 * 64; i += 64) { |
1880 | transpose_contiguous_64x2_block(src + i * ld_src, dst + i, ld_dst); |
1881 | } |
1882 | int mrem = M - i; |
1883 | if (mrem > 0) { |
1884 | transpose_contiguous_64x2_block(src + i * ld_src, dst + i, ld_dst, mrem); |
1885 | } |
1886 | } else if (N == 4) { |
1887 | int64_t i = 0; |
1888 | for (; i < M / 64 * 64; i += 64) { |
1889 | transpose_contiguous_64x4_block(src + i * ld_src, dst + i, ld_dst); |
1890 | } |
1891 | int mrem = M - i; |
1892 | if (mrem > 0) { |
1893 | transpose_contiguous_64x4_block(src + i * ld_src, dst + i, ld_dst, mrem); |
1894 | } |
1895 | } |
1896 | } |
1897 | |
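// Fast path used when the float destination is fully contiguous (M == ld_dst)
// and M is 2 or 4: walk 16 source columns per iteration and handle the N % 16
// tail with a masked block.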
1898 | template <> |
1899 | void transpose_avx512_contiguous_wide( |
1900 | int64_t M, |
1901 | int64_t N, |
1902 | const float* src, |
1903 | int64_t ld_src, |
1904 | float* dst, |
1905 | int64_t ld_dst) { |
1906 | if (M == 2) { |
1907 | int64_t i = 0; |
1908 | for (; i < N / 16 * 16; i += 16) { |
1909 | transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src); |
1910 | } |
1911 | int nrem = N - i; |
1912 | if (nrem > 0) { |
1913 | transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src, nrem); |
1914 | } |
1915 | } else if (M == 4) { |
1916 | int64_t i = 0; |
1917 | for (; i < N / 16 * 16; i += 16) { |
1918 | transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src); |
1919 | } |
1920 | int nrem = N - i; |
1921 | if (nrem > 0) { |
1922 | transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src, nrem); |
1923 | } |
1924 | } |
1925 | } |
1926 | |
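// uint16_t variant of the contiguous wide fast path: 32 columns per iteration.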
1927 | template <> |
1928 | void transpose_avx512_contiguous_wide( |
1929 | int64_t M, |
1930 | int64_t N, |
1931 | const uint16_t* src, |
1932 | int64_t ld_src, |
1933 | uint16_t* dst, |
1934 | int64_t ld_dst) { |
1935 | if (M == 2) { |
1936 | int64_t i = 0; |
1937 | for (; i < N / 32 * 32; i += 32) { |
1938 | transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src); |
1939 | } |
1940 | int nrem = N - i; |
1941 | if (nrem > 0) { |
1942 | transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src, nrem); |
1943 | } |
1944 | } else if (M == 4) { |
1945 | int64_t i = 0; |
1946 | for (; i < N / 32 * 32; i += 32) { |
1947 | transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src); |
1948 | } |
1949 | int nrem = N - i; |
1950 | if (nrem > 0) { |
1951 | transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src, nrem); |
1952 | } |
1953 | } |
1954 | } |
1955 | |
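// uint8_t variant of the contiguous wide fast path: 64 columns per iteration.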
1956 | template <> |
1957 | void transpose_avx512_contiguous_wide( |
1958 | int64_t M, |
1959 | int64_t N, |
1960 | const uint8_t* src, |
1961 | int64_t ld_src, |
1962 | uint8_t* dst, |
1963 | int64_t ld_dst) { |
1964 | if (M == 2) { |
1965 | int64_t i = 0; |
1966 | for (; i < N / 64 * 64; i += 64) { |
1967 | transpose_contiguous_2x64_block(src + i, dst + i * ld_dst, ld_src); |
1968 | } |
1969 | int nrem = N - i; |
1970 | if (nrem > 0) { |
1971 | transpose_contiguous_2x64_block(src + i, dst + i * ld_dst, ld_src, nrem); |
1972 | } |
1973 | } else if (M == 4) { |
1974 | int64_t i = 0; |
1975 | for (; i < N / 64 * 64; i += 64) { |
1976 | transpose_contiguous_4x64_block(src + i, dst + i * ld_dst, ld_src); |
1977 | } |
1978 | int nrem = N - i; |
1979 | if (nrem > 0) { |
1980 | transpose_contiguous_4x64_block(src + i, dst + i * ld_dst, ld_src, nrem); |
1981 | } |
1982 | } |
1983 | } |
1984 | |
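// Float transpose entry point: dispatch to the contiguous wide/thin fast paths
// when applicable, otherwise tile the matrix with the 16x16 AVX-512 kernel and
// pick an SSE/AVX2/masked-AVX512 kernel for the remainder based on the
// instruction-count estimates in the comments below.
//
// Illustrative usage (a sketch only; assumes the transpose_avx512 declaration
// in TransposeUtils.h and AVX-512 support at runtime):
//
//   std::vector<float> src(M * N), dst(N * M);
//   fbgemm::internal::transpose_avx512(M, N, src.data(), /*ld_src=*/N,
//                                      dst.data(), /*ld_dst=*/M);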
1985 | template <> |
1986 | void transpose_avx512( |
1987 | int64_t M, |
1988 | int64_t N, |
1989 | const float* src, |
1990 | int64_t ld_src, |
1991 | float* dst, |
1992 | int64_t ld_dst) { |
1993 | if (M == ld_dst && (M == 2 || M == 4)) { |
1994 | transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst); |
1995 | } else if (N == ld_src && (N == 2 || N == 4)) { |
1996 | transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst); |
1997 | } else { |
1998 | int64_t ib = 0, jb = 0; |
1999 | if (N % 16 > 0 && N % 16 < 4) { |
2000 | // If the remainder has n < 4 columns, we use the SSE kernel for the |
      // remainder because it requires 4 * (2 * 4 + 2 * n) = 32 + 8n
      // instructions instead of 4 * 16 + 2 * n = 64 + 2n instructions needed in
2003 | // the masked AVX512 kernel. |
2004 | for (ib = 0; ib + 16 <= M; ib += 16) { |
2005 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2006 | transpose_kernel_16x16_avx512( |
2007 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
2008 | } |
2009 | for (int64_t i = ib; i < ib + 16; i += 4) { |
2010 | transpose_kernel_mxn_sse<4>( |
2011 | N - jb, |
2012 | &src[i * ld_src + jb], |
2013 | ld_src, |
2014 | &dst[i + jb * ld_dst], |
2015 | ld_dst); |
2016 | } |
2017 | } |
2018 | } else if (N % 16 == 4) { |
2019 | // If the remainder has 4 columns, we use the SSE kernel for the remainder |
2020 | // because it requires 4 * 16 = 64 instructions instead of 4 * 16 + 2 * 4 |
2021 | // = 72 instructions needed in the masked AVX512 kernel. |
2022 | for (ib = 0; ib + 16 <= M; ib += 16) { |
2023 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2024 | transpose_kernel_16x16_avx512( |
2025 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
2026 | } |
2027 | for (int64_t i = ib; i < ib + 16; i += 4) { |
2028 | transpose_kernel_4x4_sse( |
2029 | &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst); |
2030 | } |
2031 | } |
2032 | } else if (N % 16 == 8) { |
      // If the remainder has 8 columns, we use the AVX kernel for the remainder
2034 | // because it requires 2 * 40 = 80 instructions instead of 4 * 16 + 2 * 8 |
2035 | // = 80 instructions + looping overhead in the masked AVX512 kernel. |
2036 | for (ib = 0; ib + 16 <= M; ib += 16) { |
2037 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2038 | transpose_kernel_16x16_avx512( |
2039 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
2040 | } |
2041 | for (int64_t i = ib; i < ib + 16; i += 8) { |
2042 | transpose_kernel_8x8_avx2( |
2043 | &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst); |
2044 | } |
2045 | } |
2046 | } else { |
2047 | for (ib = 0; ib + 16 <= M; ib += 16) { |
2048 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2049 | transpose_kernel_16x16_avx512( |
2050 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
2051 | } |
2052 | if (jb < N) { |
2053 | transpose_kernel_mxn_avx512<16>( |
2054 | N - jb, |
2055 | &src[ib * ld_src + jb], |
2056 | ld_src, |
2057 | &dst[ib + jb * ld_dst], |
2058 | ld_dst); |
2059 | } |
2060 | } |
2061 | } |
2062 | |
    // Specialize on small M - ib so that the compiler can inline
    // transpose_kernel_mxn_avx512 and unroll the loops whose iteration count
    // depends on M - ib.
    // Specializing on m helps more than on n in transpose_kernel_mxn_avx512
    // because more loops in that function have an iteration count that depends
    // on m.
2069 | switch (M - ib) { |
2070 | case 1: |
2071 | for (int64_t j = 0; j < N; ++j) { |
2072 | dst[ib + j * ld_dst] = src[ib * ld_src + j]; |
2073 | } |
2074 | break; |
2075 | case 2: |
2076 | for (jb = 0; jb + 4 <= N; jb += 4) { |
2077 | transpose_kernel_mxn_sse<2>( |
2078 | 4, |
2079 | &src[ib * ld_src + jb], |
2080 | ld_src, |
2081 | &dst[ib + jb * ld_dst], |
2082 | ld_dst); |
2083 | } |
2084 | if (jb < N) { |
2085 | transpose_kernel_mxn_sse<2>( |
2086 | N - jb, |
2087 | &src[ib * ld_src + jb], |
2088 | ld_src, |
2089 | &dst[ib + jb * ld_dst], |
2090 | ld_dst); |
2091 | } |
2092 | break; |
2093 | case 3: |
2094 | for (jb = 0; jb + 4 <= N; jb += 4) { |
2095 | transpose_kernel_mxn_sse<3>( |
2096 | 4, |
2097 | &src[ib * ld_src + jb], |
2098 | ld_src, |
2099 | &dst[ib + jb * ld_dst], |
2100 | ld_dst); |
2101 | } |
2102 | if (jb < N) { |
2103 | transpose_kernel_mxn_sse<3>( |
2104 | N - jb, |
2105 | &src[ib * ld_src + jb], |
2106 | ld_src, |
2107 | &dst[ib + jb * ld_dst], |
2108 | ld_dst); |
2109 | } |
2110 | break; |
2111 | case 4: |
2112 | for (jb = 0; jb + 4 <= N; jb += 4) { |
2113 | transpose_kernel_4x4_sse( |
2114 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
2115 | } |
2116 | if (jb < N) { |
2117 | transpose_kernel_mxn_sse<4>( |
2118 | N - jb, |
2119 | &src[ib * ld_src + jb], |
2120 | ld_src, |
2121 | &dst[ib + jb * ld_dst], |
2122 | ld_dst); |
2123 | } |
2124 | break; |
2125 | case 5: |
2126 | for (jb = 0; jb + 8 <= N; jb += 8) { |
2127 | transpose_kernel_mxn_avx2<5>( |
2128 | 8, |
2129 | &src[ib * ld_src + jb], |
2130 | ld_src, |
2131 | &dst[ib + jb * ld_dst], |
2132 | ld_dst); |
2133 | } |
2134 | if (jb < N) { |
2135 | transpose_kernel_mxn_avx2<5>( |
2136 | N - jb, |
2137 | &src[ib * ld_src + jb], |
2138 | ld_src, |
2139 | &dst[ib + jb * ld_dst], |
2140 | ld_dst); |
2141 | } |
2142 | break; |
2143 | case 6: |
2144 | for (jb = 0; jb + 8 <= N; jb += 8) { |
2145 | transpose_kernel_mxn_avx2<6>( |
2146 | 8, |
2147 | &src[ib * ld_src + jb], |
2148 | ld_src, |
2149 | &dst[ib + jb * ld_dst], |
2150 | ld_dst); |
2151 | } |
2152 | if (jb < N) { |
2153 | transpose_kernel_mxn_avx2<6>( |
2154 | N - jb, |
2155 | &src[ib * ld_src + jb], |
2156 | ld_src, |
2157 | &dst[ib + jb * ld_dst], |
2158 | ld_dst); |
2159 | } |
2160 | break; |
2161 | case 7: |
2162 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2163 | transpose_kernel_mxn_avx512<7>( |
2164 | 16, |
2165 | &src[ib * ld_src + jb], |
2166 | ld_src, |
2167 | &dst[ib + jb * ld_dst], |
2168 | ld_dst); |
2169 | } |
2170 | if (jb < N) { |
2171 | transpose_kernel_mxn_avx512<7>( |
2172 | N - jb, |
2173 | &src[ib * ld_src + jb], |
2174 | ld_src, |
2175 | &dst[ib + jb * ld_dst], |
2176 | ld_dst); |
2177 | } |
2178 | break; |
2179 | case 8: |
2180 | for (jb = 0; jb + 8 <= N; jb += 8) { |
2181 | transpose_kernel_8x8_avx2( |
2182 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
2183 | } |
2184 | if (jb < N) { |
2185 | transpose_kernel_mxn_avx2<8>( |
2186 | N - jb, |
2187 | &src[ib * ld_src + jb], |
2188 | ld_src, |
2189 | &dst[ib + jb * ld_dst], |
2190 | ld_dst); |
2191 | } |
2192 | break; |
2193 | case 9: |
2194 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2195 | transpose_kernel_mxn_avx512<9>( |
2196 | 16, |
2197 | &src[ib * ld_src + jb], |
2198 | ld_src, |
2199 | &dst[ib + jb * ld_dst], |
2200 | ld_dst); |
2201 | } |
2202 | if (jb < N) { |
2203 | transpose_kernel_mxn_avx512<9>( |
2204 | N - jb, |
2205 | &src[ib * ld_src + jb], |
2206 | ld_src, |
2207 | &dst[ib + jb * ld_dst], |
2208 | ld_dst); |
2209 | } |
2210 | break; |
2211 | case 10: |
2212 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2213 | transpose_kernel_mxn_avx512<10>( |
2214 | 16, |
2215 | &src[ib * ld_src + jb], |
2216 | ld_src, |
2217 | &dst[ib + jb * ld_dst], |
2218 | ld_dst); |
2219 | } |
2220 | if (jb < N) { |
2221 | transpose_kernel_mxn_avx512<10>( |
2222 | N - jb, |
2223 | &src[ib * ld_src + jb], |
2224 | ld_src, |
2225 | &dst[ib + jb * ld_dst], |
2226 | ld_dst); |
2227 | } |
2228 | break; |
2229 | case 11: |
2230 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2231 | transpose_kernel_mxn_avx512<11>( |
2232 | 16, |
2233 | &src[ib * ld_src + jb], |
2234 | ld_src, |
2235 | &dst[ib + jb * ld_dst], |
2236 | ld_dst); |
2237 | } |
2238 | if (jb < N) { |
2239 | transpose_kernel_mxn_avx512<11>( |
2240 | N - jb, |
2241 | &src[ib * ld_src + jb], |
2242 | ld_src, |
2243 | &dst[ib + jb * ld_dst], |
2244 | ld_dst); |
2245 | } |
2246 | break; |
2247 | case 12: |
2248 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2249 | transpose_kernel_mxn_avx512<12>( |
2250 | 16, |
2251 | &src[ib * ld_src + jb], |
2252 | ld_src, |
2253 | &dst[ib + jb * ld_dst], |
2254 | ld_dst); |
2255 | } |
2256 | if (jb < N) { |
2257 | transpose_kernel_mxn_avx512<12>( |
2258 | N - jb, |
2259 | &src[ib * ld_src + jb], |
2260 | ld_src, |
2261 | &dst[ib + jb * ld_dst], |
2262 | ld_dst); |
2263 | } |
2264 | break; |
2265 | case 13: |
2266 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2267 | transpose_kernel_mxn_avx512<13>( |
2268 | 16, |
2269 | &src[ib * ld_src + jb], |
2270 | ld_src, |
2271 | &dst[ib + jb * ld_dst], |
2272 | ld_dst); |
2273 | } |
2274 | if (jb < N) { |
2275 | transpose_kernel_mxn_avx512<13>( |
2276 | N - jb, |
2277 | &src[ib * ld_src + jb], |
2278 | ld_src, |
2279 | &dst[ib + jb * ld_dst], |
2280 | ld_dst); |
2281 | } |
2282 | break; |
2283 | case 14: |
2284 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2285 | transpose_kernel_mxn_avx512<14>( |
2286 | 16, |
2287 | &src[ib * ld_src + jb], |
2288 | ld_src, |
2289 | &dst[ib + jb * ld_dst], |
2290 | ld_dst); |
2291 | } |
2292 | if (jb < N) { |
2293 | transpose_kernel_mxn_avx512<14>( |
2294 | N - jb, |
2295 | &src[ib * ld_src + jb], |
2296 | ld_src, |
2297 | &dst[ib + jb * ld_dst], |
2298 | ld_dst); |
2299 | } |
2300 | break; |
2301 | case 15: |
2302 | for (jb = 0; jb + 16 <= N; jb += 16) { |
2303 | transpose_kernel_mxn_avx512<15>( |
2304 | 16, |
2305 | &src[ib * ld_src + jb], |
2306 | ld_src, |
2307 | &dst[ib + jb * ld_dst], |
2308 | ld_dst); |
2309 | } |
2310 | if (jb < N) { |
2311 | transpose_kernel_mxn_avx512<15>( |
2312 | N - jb, |
2313 | &src[ib * ld_src + jb], |
2314 | ld_src, |
2315 | &dst[ib + jb * ld_dst], |
2316 | ld_dst); |
2317 | } |
2318 | break; |
2319 | } |
2320 | } |
2321 | } |
2322 | |
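// uint16_t transpose entry point: contiguous fast paths when applicable,
// otherwise a blocked 16x16 transpose with the MREM/NREM template variants
// covering the row/column remainders.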
2323 | template <> |
2324 | void transpose_avx512( |
2325 | int64_t M, |
2326 | int64_t N, |
2327 | const uint16_t* src, |
2328 | int64_t ld_src, |
2329 | uint16_t* dst, |
2330 | int64_t ld_dst) { |
2331 | if (M == ld_dst && (M == 2 || M == 4)) { |
2332 | transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst); |
2333 | } else if (N == ld_src && (N == 2 || N == 4)) { |
2334 | transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst); |
2335 | } else { |
2336 | int64_t i = 0; |
2337 | for (; i < M / 16 * 16; i += 16) { |
2338 | int64_t j = 0; |
2339 | for (; j < N / 16 * 16; j += 16) { |
2340 | transpose_16x16_block<false, false>( |
2341 | src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst); |
2342 | } |
2343 | // handle j rem |
2344 | int nrem = N - j; |
2345 | if (nrem > 0) { |
2346 | transpose_16x16_block<false, true>( |
2347 | src + i * ld_src + j, |
2348 | ld_src, |
2349 | dst + j * ld_dst + i, |
2350 | ld_dst, |
2351 | 16, |
2352 | nrem); |
2353 | } |
2354 | } |
2355 | // handle i rem |
2356 | int mrem = M - i; |
2357 | if (mrem > 0) { |
      int64_t j = 0;
2359 | for (; j < N / 16 * 16; j += 16) { |
2360 | transpose_16x16_block<true, false>( |
2361 | src + i * ld_src + j, |
2362 | ld_src, |
2363 | dst + j * ld_dst + i, |
2364 | ld_dst, |
2365 | mrem, |
2366 | 16); |
2367 | } |
2368 | // handle j rem |
2369 | int nrem = N - j; |
2370 | transpose_16x16_block<true, true>( |
2371 | src + i * ld_src + j, |
2372 | ld_src, |
2373 | dst + j * ld_dst + i, |
2374 | ld_dst, |
2375 | mrem, |
2376 | nrem); |
2377 | } |
2378 | } |
2379 | } |
2380 | |
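// uint8_t transpose entry point: contiguous fast paths when applicable,
// otherwise a blocked 16x32 transpose with the MREM/NREM template variants
// covering the row/column remainders.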
2381 | template <> |
2382 | void transpose_avx512( |
2383 | int64_t M, |
2384 | int64_t N, |
2385 | const uint8_t* src, |
2386 | int64_t ld_src, |
2387 | uint8_t* dst, |
2388 | int64_t ld_dst) { |
2389 | if (M == ld_dst && (M == 2 || M == 4)) { |
2390 | transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst); |
2391 | } else if (N == ld_src && (N == 2 || N == 4)) { |
2392 | transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst); |
2393 | } else { |
2394 | int64_t i = 0; |
2395 | for (; i < M / 16 * 16; i += 16) { |
2396 | int64_t j = 0; |
2397 | for (; j < N / 32 * 32; j += 32) { |
2398 | transpose_16x32_block<false, false>( |
2399 | src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst); |
2400 | } |
2401 | // handle j rem |
2402 | int nrem = N - j; |
2403 | if (nrem > 0) { |
2404 | transpose_16x32_block<false, true>( |
2405 | src + i * ld_src + j, |
2406 | ld_src, |
2407 | dst + j * ld_dst + i, |
2408 | ld_dst, |
2409 | 16, |
2410 | nrem); |
2411 | } |
2412 | } |
2413 | |
2414 | // handle i rem |
2415 | int mrem = M - i; |
2416 | if (mrem > 0) { |
2417 | int64_t j = 0; |
2418 | for (; j < N / 32 * 32; j += 32) { |
2419 | transpose_16x32_block<true, false>( |
2420 | src + i * ld_src + j, |
2421 | ld_src, |
2422 | dst + j * ld_dst + i, |
2423 | ld_dst, |
2424 | mrem, |
2425 | 32); |
2426 | } |
2427 | // handle j rem |
2428 | int nrem = N - j; |
2429 | transpose_16x32_block<true, true>( |
2430 | src + i * ld_src + j, |
2431 | ld_src, |
2432 | dst + j * ld_dst + i, |
2433 | ld_dst, |
2434 | mrem, |
2435 | nrem); |
2436 | } |
2437 | } |
2438 | } |
2439 | |
2440 | } // namespace internal |
2441 | |
2442 | } // namespace fbgemm |
2443 | |