1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | |
9 | #if defined(__x86_64__) || defined(__i386__) || \ |
10 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
11 | #include <immintrin.h> |
12 | #endif |
13 | #include <cassert> |
14 | #include <cstdint> |
15 | |
16 | #include "./MaskAvx2.h" |
17 | |
18 | namespace fbgemm { |
19 | |
20 | namespace internal { |
21 | |
22 | #ifdef __AVX2__ |
// NOTE: Make sure every function defined here has static linkage, because
// this header is also included by UtilsAvx512.cc, which is compiled with the
// -mavx512f option.
25 | |
26 | // 4 * 4 = 16 instructions |
27 | static inline void transpose_kernel_4x4_sse( |
28 | const float* src, |
29 | int64_t ld_src, |
30 | float* dst, |
31 | int64_t ld_dst) { |
32 | // load from src to registers |
33 | // a : a0 a1 a2 a3 |
34 | // b : b0 b1 b2 b3 |
35 | // c : c0 c1 c2 c3 |
36 | // d : d0 d1 d2 d3 |
37 | __m128 a = _mm_loadu_ps(&src[0 * ld_src]); |
38 | __m128 b = _mm_loadu_ps(&src[1 * ld_src]); |
39 | __m128 c = _mm_loadu_ps(&src[2 * ld_src]); |
40 | __m128 d = _mm_loadu_ps(&src[3 * ld_src]); |
41 | |
  // transpose the 4x4 matrix of 32-bit elements using the SSE
  // _MM_TRANSPOSE4_PS macro
43 | // a : a0 b0 c0 d0 |
44 | // b : a1 b1 c1 d1 |
45 | // c : a2 b2 c2 d2 |
46 | // d : a3 b3 c3 d3 |
47 | _MM_TRANSPOSE4_PS(a, b, c, d); |
48 | |
49 | // store from registers to dst |
50 | _mm_storeu_ps(&dst[0 * ld_dst], a); |
51 | _mm_storeu_ps(&dst[1 * ld_dst], b); |
52 | _mm_storeu_ps(&dst[2 * ld_dst], c); |
53 | _mm_storeu_ps(&dst[3 * ld_dst], d); |
54 | } |
55 | |
// kernel for transposing an mxn block where m, n <= 4
57 | // M + (M + 1) / 2 * 2 + 2 * N instructions |
58 | template <unsigned M> |
59 | static void transpose_kernel_mxn_sse( |
60 | unsigned N, |
61 | const float* src, |
62 | int64_t ld_src, |
63 | float* dst, |
64 | int64_t ld_dst) { |
65 | // clang-format off |
66 | alignas(64) static const int masks[5][4] = { |
67 | { 0, 0, 0, 0, }, |
68 | { -1, 0, 0, 0, }, |
69 | { -1, -1, 0, 0, }, |
70 | { -1, -1, -1, 0, }, |
71 | { -1, -1, -1, -1, }, |
72 | }; |
73 | // clang-format on |
74 | |
75 | // load from src to registers |
76 | __m128i mask_v = _mm_load_si128(reinterpret_cast<const __m128i*>(masks[N])); |
77 | __m128 input[4]; |
78 | unsigned i; |
79 | for (i = 0; i < M; ++i) { |
80 | input[i] = _mm_maskload_ps(&src[i * ld_src], mask_v); |
81 | } |
82 | for (; i < 4; ++i) { |
    // Not really needed, but zero the rest to avoid an uninitialized-variable
    // warning. This shouldn't add much overhead because the xor can execute
    // in parallel with other instructions.
86 | input[i] = _mm_setzero_ps(); |
87 | } |
88 | |
89 | __m128 temp[4]; |
90 | for (i = 0; i < (M + 1) / 2; ++i) { |
91 | temp[2 * i] = _mm_unpacklo_ps(input[2 * i], input[2 * i + 1]); |
92 | temp[2 * i + 1] = _mm_unpackhi_ps(input[2 * i], input[2 * i + 1]); |
93 | } |
94 | for (i = i * 2; i < 4; ++i) { |
95 | temp[i] = _mm_setzero_ps(); |
96 | } |
97 | |
98 | mask_v = _mm_load_si128(reinterpret_cast<const __m128i*>(masks[M])); |
99 | for (i = 0; i < N; ++i) { |
100 | if (i % 2 == 0) { |
101 | input[i] = _mm_movelh_ps(temp[i / 2], temp[2 + i / 2]); |
102 | } else { |
103 | input[i] = _mm_movehl_ps(temp[2 + i / 2], temp[i / 2]); |
104 | } |
105 | _mm_maskstore_ps(&dst[i * ld_dst], mask_v, input[i]); |
106 | } |
107 | } |
108 | |
109 | // 8 * 5 = 40 instructions |
110 | static inline void transpose_kernel_8x8_avx2( |
111 | const float* src, |
112 | int64_t ld_src, |
113 | float* dst, |
114 | int64_t ld_dst) { |
115 | // load from src to registers |
116 | // a : a0 a1 a2 a3 a4 a5 a6 a7 |
117 | // b : b0 b1 b2 b3 b4 b5 b6 b7 |
118 | // c : c0 c1 c2 c3 c4 c5 c6 c7 |
119 | // d : d0 d1 d2 d3 d4 d5 d6 d7 |
120 | // e : e0 e1 e2 e3 e4 e5 e6 e7 |
121 | // f : f0 f1 f2 f3 f4 f5 f6 f7 |
122 | // g : g0 g1 g2 g3 g4 g5 g6 g7 |
123 | // h : h0 h1 h2 h3 h4 h5 h6 h7 |
124 | __m256 a = _mm256_loadu_ps(&src[0 * ld_src]); |
125 | __m256 b = _mm256_loadu_ps(&src[1 * ld_src]); |
126 | __m256 c = _mm256_loadu_ps(&src[2 * ld_src]); |
127 | __m256 d = _mm256_loadu_ps(&src[3 * ld_src]); |
128 | __m256 e = _mm256_loadu_ps(&src[4 * ld_src]); |
129 | __m256 f = _mm256_loadu_ps(&src[5 * ld_src]); |
130 | __m256 g = _mm256_loadu_ps(&src[6 * ld_src]); |
131 | __m256 h = _mm256_loadu_ps(&src[7 * ld_src]); |
132 | |
133 | __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367; |
134 | __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37; |
135 | // unpacking and interleaving 32-bit elements |
136 | // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5 |
137 | // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7 |
138 | // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5 |
139 | // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7 |
140 | // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5 |
141 | // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7 |
142 | // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5 |
143 | // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7 |
144 | ab0145 = _mm256_unpacklo_ps(a, b); |
145 | ab2367 = _mm256_unpackhi_ps(a, b); |
146 | cd0145 = _mm256_unpacklo_ps(c, d); |
147 | cd2367 = _mm256_unpackhi_ps(c, d); |
148 | ef0145 = _mm256_unpacklo_ps(e, f); |
149 | ef2367 = _mm256_unpackhi_ps(e, f); |
150 | gh0145 = _mm256_unpacklo_ps(g, h); |
151 | gh2367 = _mm256_unpackhi_ps(g, h); |
152 | |
153 | // shuffling the 32-bit elements |
154 | // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4 |
155 | // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5 |
156 | // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4 |
  // efgh15 : e1 f1 g1 h1 e5 f5 g5 h5
158 | // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6 |
159 | // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7 |
160 | // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6 |
161 | // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7 |
162 | abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44); |
163 | abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee); |
164 | efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44); |
165 | efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee); |
166 | abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44); |
167 | abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee); |
168 | efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44); |
169 | efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee); |
170 | |
171 | // shuffling 128-bit elements |
172 | // a : a0 b0 c0 d0 e0 f0 g0 h0 |
173 | // b : a1 b1 c1 d1 e1 f1 g1 h1 |
174 | // c : a2 b2 c2 d2 e2 f2 g2 h2 |
175 | // d : a3 b3 c3 d3 e3 f3 g3 h3 |
176 | // e : a4 b4 c4 d4 e4 f4 g4 h4 |
177 | // f : a5 b5 c5 d5 e5 f5 g5 h5 |
178 | // g : a6 b6 c6 d6 e6 f6 g6 h6 |
179 | // h : a7 b7 c7 d7 e7 f7 g7 h7 |
180 | a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02); |
181 | b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02); |
182 | c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02); |
183 | d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02); |
184 | e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13); |
185 | f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13); |
186 | g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13); |
187 | h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13); |
188 | |
189 | // store from registers to dst |
190 | _mm256_storeu_ps(&dst[0 * ld_dst], a); |
191 | _mm256_storeu_ps(&dst[1 * ld_dst], b); |
192 | _mm256_storeu_ps(&dst[2 * ld_dst], c); |
193 | _mm256_storeu_ps(&dst[3 * ld_dst], d); |
194 | _mm256_storeu_ps(&dst[4 * ld_dst], e); |
195 | _mm256_storeu_ps(&dst[5 * ld_dst], f); |
196 | _mm256_storeu_ps(&dst[6 * ld_dst], g); |
197 | _mm256_storeu_ps(&dst[7 * ld_dst], h); |
198 | } |
199 | |
// kernel for transposing an mxn block where m, n <= 8
201 | // M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + 2 * N instructions |
202 | template <unsigned M> |
203 | static void transpose_kernel_mxn_avx2( |
204 | unsigned N, |
205 | const float* src, |
206 | int64_t ld_src, |
207 | float* dst, |
208 | int64_t ld_dst) { |
209 | // load from src to registers |
210 | __m256i mask_v = _mm256_load_si256( |
211 | reinterpret_cast<const __m256i*>(internal::avx2_ps_or_epi32_masks[N])); |
212 | __m256 input[8]; |
213 | unsigned i; |
214 | for (i = 0; i < M; ++i) { |
215 | input[i] = _mm256_maskload_ps(&src[i * ld_src], mask_v); |
216 | } |
217 | for (; i < 8; ++i) { |
    // Not really needed, but zero the rest to avoid an uninitialized-variable
    // warning. This shouldn't add much overhead because the xor can execute
    // in parallel with other instructions.
221 | input[i] = _mm256_setzero_ps(); |
222 | } |
223 | |
224 | // unpacking and interleaving 32-bit elements |
225 | __m256 temp[8]; |
226 | for (i = 0; i < (M + 1) / 2; ++i) { |
227 | temp[2 * i] = _mm256_unpacklo_ps(input[2 * i], input[2 * i + 1]); |
228 | temp[2 * i + 1] = _mm256_unpackhi_ps(input[2 * i], input[2 * i + 1]); |
229 | } |
230 | for (i = i * 2; i < 8; ++i) { |
231 | temp[i] = _mm256_setzero_ps(); |
232 | } |
233 | |
234 | // shuffling the 32-bit elements |
235 | for (i = 0; i < (M + 3) / 4; ++i) { |
236 | input[4 * i] = _mm256_shuffle_ps(temp[4 * i], temp[4 * i + 2], 0x44); |
237 | input[4 * i + 1] = _mm256_shuffle_ps(temp[4 * i], temp[4 * i + 2], 0xee); |
238 | input[4 * i + 2] = |
239 | _mm256_shuffle_ps(temp[4 * i + 1], temp[4 * i + 3], 0x44); |
240 | input[4 * i + 3] = |
241 | _mm256_shuffle_ps(temp[4 * i + 1], temp[4 * i + 3], 0xee); |
242 | } |
243 | |
244 | // shuffling 128-bit elements |
245 | // store from registers to dst |
246 | mask_v = _mm256_load_si256( |
247 | reinterpret_cast<const __m256i*>(internal::avx2_ps_or_epi32_masks[M])); |
248 | for (i = 0; i < N; ++i) { |
249 | if (i < 4) { |
250 | temp[i] = _mm256_permute2f128_ps(input[4 + i], input[i], 0x02); |
251 | } else { |
252 | temp[i] = _mm256_permute2f128_ps(input[i], input[i - 4], 0x13); |
253 | } |
254 | _mm256_maskstore_ps(&dst[i * ld_dst], mask_v, temp[i]); |
255 | } |
256 | } |
257 | |
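// Within each 128-bit lane, gather every 4th byte together, i.e. transpose
// the lane viewed as a 4x4 matrix of bytes. For example, a lane holding
//   a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3
// becomes
//   a0 b0 c0 d0 a1 b1 c1 d1 a2 b2 c2 d2 a3 b3 c3 d3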
258 | inline __m256i permute_row(__m256i row) { |
259 | // clang-format off |
260 | row = _mm256_shuffle_epi8( |
261 | row, |
262 | _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, |
263 | 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0)); |
264 | // clang-format on |
265 | return row; |
266 | } |
267 | |
269 | inline static void transpose_kernel_8x32_avx2( |
270 | const uint8_t* src, |
271 | int64_t ld_src, |
272 | uint8_t* dst, |
273 | int64_t ld_dst) { |
274 | // load from src to registers |
275 | // a : a0 a1 a2 a3 a4 a5 a6 a7 ... a31 |
276 | // b : b0 b1 b2 b3 b4 b5 b6 b7 ... b31 |
277 | // c : c0 c1 c2 c3 c4 c5 c6 c7 ... c31 |
278 | // d : d0 d1 d2 d3 d4 d5 d6 d7 ... d31 |
279 | // e : e0 e1 e2 e3 e4 e5 e6 e7 ... e31 |
280 | // f : f0 f1 f2 f3 f4 f5 f6 f7 ... f31 |
281 | // g : g0 g1 g2 g3 g4 g5 g6 g7 ... g31 |
282 | // h : h0 h1 h2 h3 h4 h5 h6 h7 ... h31 |
283 | |
284 | // load from src |
285 | __m256i a = _mm256_loadu_si256( |
286 | reinterpret_cast<const __m256i*>((src) + (0 * ld_src))); |
287 | __m256i b = _mm256_loadu_si256( |
288 | reinterpret_cast<const __m256i*>((src) + (1 * ld_src))); |
289 | __m256i c = _mm256_loadu_si256( |
290 | reinterpret_cast<const __m256i*>((src) + (2 * ld_src))); |
291 | __m256i d = _mm256_loadu_si256( |
292 | reinterpret_cast<const __m256i*>((src) + (3 * ld_src))); |
293 | __m256i e = _mm256_loadu_si256( |
294 | reinterpret_cast<const __m256i*>((src) + (4 * ld_src))); |
295 | __m256i f = _mm256_loadu_si256( |
296 | reinterpret_cast<const __m256i*>((src) + (5 * ld_src))); |
297 | __m256i g = _mm256_loadu_si256( |
298 | reinterpret_cast<const __m256i*>((src) + (6 * ld_src))); |
299 | __m256i h = _mm256_loadu_si256( |
300 | reinterpret_cast<const __m256i*>((src) + (7 * ld_src))); |
301 | |
302 | // shuffle in stride of one: |
303 | // t0 : a0 -- a3, b0 -- b3, a4 -- a7, b4 -- b7, |
304 | // a16 -- a19, b16 -- b19, a20 -- a23, b20 -- b23 |
305 | |
306 | // t1 : a8 -- a11, b8 -- b11, a12 -- a15, b12 -- b15, |
307 | // a24 -- a27, b24 -- b27, a28 -- a31, b28 -- b31 |
308 | |
309 | // t2 : c0 -- c3, d0 -- d3, c4 -- c7, d4 -- d7, |
310 | // c16 -- c19, d16 -- d19, c20 -- c23, d20 -- d23 |
311 | |
312 | __m256i __t0 = _mm256_unpacklo_epi32(a, b); |
313 | __m256i __t1 = _mm256_unpackhi_epi32(a, b); |
314 | __m256i __t2 = _mm256_unpacklo_epi32(c, d); |
315 | __m256i __t3 = _mm256_unpackhi_epi32(c, d); |
316 | __m256i __t4 = _mm256_unpacklo_epi32(e, f); |
317 | __m256i __t5 = _mm256_unpackhi_epi32(e, f); |
318 | __m256i __t6 = _mm256_unpacklo_epi32(g, h); |
319 | __m256i __t7 = _mm256_unpackhi_epi32(g, h); |
320 | |
321 | // shuffle in stride of two: |
322 | // tt0: a0--a3, b0--b3, c0--c3, d0--d3, |
323 | // a16--a19, b16 -- b19, c16 -- c19, d16--d19 |
324 | |
  // tt1: a4 -- a7, b4 -- b7, c4 -- c7, d4 -- d7,
  //      a20--a23, b20--b23, c20--c23, d20--d23
327 | |
328 | // tt2: a8 -- a11, b8 -- b11, c8 -- c11, d8 -- d11, |
329 | // a24 -- a27, b24 -- b27, c24 -- c27, d24 -- d27 |
330 | |
331 | // tt3: a12 -- a15, b12 -- b15, c12--c15, d12--d15, |
332 | // a28--a31, b28--b31, c28--c31, d28--d31 |
333 | |
  // tt4: e0--e3, f0--f3, g0--g3, h0--h3,
  //      e16--e19, f16--f19, g16--g19, h16--h19
336 | __m256i __tt0 = _mm256_unpacklo_epi64(__t0, __t2); |
337 | __m256i __tt1 = _mm256_unpackhi_epi64(__t0, __t2); |
338 | __m256i __tt2 = _mm256_unpacklo_epi64(__t1, __t3); |
339 | __m256i __tt3 = _mm256_unpackhi_epi64(__t1, __t3); |
340 | __m256i __tt4 = _mm256_unpacklo_epi64(__t4, __t6); |
341 | __m256i __tt5 = _mm256_unpackhi_epi64(__t4, __t6); |
342 | __m256i __tt6 = _mm256_unpacklo_epi64(__t5, __t7); |
343 | __m256i __tt7 = _mm256_unpackhi_epi64(__t5, __t7); |
344 | |
  // permute: pack groups of 4 consecutive elements together
  // ttt0: a0--d0 a1--d1 a2--d2 a3--d3 a16-d16 a17-d17 a18-d18 a19-d19
347 | |
348 | // ttt1: a4--d4 a5--d5 a6--d6 a7--d7 a20-d20 a21-d21 a22-d22 a23-d23 |
349 | |
350 | // ttt2: a8--d8 a9--d9 a10--d10 a11--d11 a24-d24 a25-d25 a26-d26 a27-d27 |
351 | __m256i __ttt0 = permute_row(__tt0); |
352 | __m256i __ttt1 = permute_row(__tt1); |
353 | __m256i __ttt2 = permute_row(__tt2); |
354 | __m256i __ttt3 = permute_row(__tt3); |
355 | __m256i __ttt4 = permute_row(__tt4); |
356 | __m256i __ttt5 = permute_row(__tt5); |
357 | __m256i __ttt6 = permute_row(__tt6); |
358 | __m256i __ttt7 = permute_row(__tt7); |
359 | |
360 | // |
361 | // a: a0-h0 a1-h1 a16-h16 a17-h17 |
362 | // b: a2-h2 a3-h3 a18-h18 a19-h19 |
363 | |
364 | // c: a4-h4 a6-h6 a20-h20 a22-h22 (a-h)x(4-7) |
365 | // d: a5-h5 a7-h7 a21-h21 a23-h23 (a-h)x(20-23) |
366 | |
367 | // e: a8-h8 a9-h9 a24-h24 a25-h25 (a-h)x(8-11) |
368 | // f: a10-h10 a11-h11 a26-h26 a27-h27 (a-h)x(24-27) |
369 | |
370 | // g: (a-h)x(12-15) |
371 | // h: (a-h)x(28-31) |
372 | a = _mm256_unpacklo_epi32(__ttt0, __ttt4); |
373 | b = _mm256_unpackhi_epi32(__ttt0, __ttt4); |
374 | c = _mm256_unpacklo_epi32(__ttt1, __ttt5); |
375 | d = _mm256_unpackhi_epi32(__ttt1, __ttt5); |
376 | e = _mm256_unpacklo_epi32(__ttt2, __ttt6); |
377 | f = _mm256_unpackhi_epi32(__ttt2, __ttt6); |
378 | g = _mm256_unpacklo_epi32(__ttt3, __ttt7); |
379 | h = _mm256_unpackhi_epi32(__ttt3, __ttt7); |
380 | |
  // store the 32 transposed rows (8 bytes each) back to dst:
382 | |
383 | reinterpret_cast<uint64_t*>(dst)[0] = _mm256_extract_epi64(a, 0); |
384 | reinterpret_cast<uint64_t*>((dst) + ld_dst)[0] = _mm256_extract_epi64(a, 1); |
385 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 2)[0] = |
386 | _mm256_extract_epi64(b, 0); |
387 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 3)[0] = |
388 | _mm256_extract_epi64(b, 1); |
389 | |
390 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 4)[0] = |
391 | _mm256_extract_epi64(c, 0); |
392 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 5)[0] = |
393 | _mm256_extract_epi64(c, 1); |
394 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 6)[0] = |
395 | _mm256_extract_epi64(d, 0); |
396 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 7)[0] = |
397 | _mm256_extract_epi64(d, 1); |
398 | |
399 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 8)[0] = |
400 | _mm256_extract_epi64(e, 0); |
401 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 9)[0] = |
402 | _mm256_extract_epi64(e, 1); |
403 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 10)[0] = |
404 | _mm256_extract_epi64(f, 0); |
405 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 11)[0] = |
406 | _mm256_extract_epi64(f, 1); |
407 | |
408 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 12)[0] = |
409 | _mm256_extract_epi64(g, 0); |
410 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 13)[0] = |
411 | _mm256_extract_epi64(g, 1); |
412 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 14)[0] = |
413 | _mm256_extract_epi64(h, 0); |
414 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 15)[0] = |
415 | _mm256_extract_epi64(h, 1); |
416 | |
417 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 16)[0] = |
418 | _mm256_extract_epi64(a, 2); |
419 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 17)[0] = |
420 | _mm256_extract_epi64(a, 3); |
421 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 18)[0] = |
422 | _mm256_extract_epi64(b, 2); |
423 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 19)[0] = |
424 | _mm256_extract_epi64(b, 3); |
425 | |
426 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 20)[0] = |
427 | _mm256_extract_epi64(c, 2); |
428 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 21)[0] = |
429 | _mm256_extract_epi64(c, 3); |
430 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 22)[0] = |
431 | _mm256_extract_epi64(d, 2); |
432 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 23)[0] = |
433 | _mm256_extract_epi64(d, 3); |
434 | |
435 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 24)[0] = |
436 | _mm256_extract_epi64(e, 2); |
437 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 25)[0] = |
438 | _mm256_extract_epi64(e, 3); |
439 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 26)[0] = |
440 | _mm256_extract_epi64(f, 2); |
441 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 27)[0] = |
442 | _mm256_extract_epi64(f, 3); |
443 | |
444 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 28)[0] = |
445 | _mm256_extract_epi64(g, 2); |
446 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 29)[0] = |
447 | _mm256_extract_epi64(g, 3); |
448 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 30)[0] = |
449 | _mm256_extract_epi64(h, 2); |
450 | reinterpret_cast<uint64_t*>((dst) + ld_dst * 31)[0] = |
451 | _mm256_extract_epi64(h, 3); |
452 | } |
453 | |
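// Load mrem (<= 8) rows of nrem (<= 16) uint16_t elements each from src into
// r[0..mrem-1]. Full 16-element rows use plain unaligned loads; shorter rows
// use a 32-bit mask load for the even-sized prefix plus a scalar copy through
// a local buffer for an odd trailing element.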
454 | static inline void load_with_remainders_i16( |
455 | const uint16_t* src, |
456 | int64_t ld_src, |
457 | __m256i r[], |
458 | unsigned mrem, |
459 | unsigned nrem) { |
460 | if (nrem < 16) { |
461 | uint16_t local_buffer[16] = {0}; |
462 | __m256i mask_nrem_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
463 | internal::avx2_ps_or_epi32_masks[nrem / 2])); |
464 | unsigned half = nrem % 2; |
465 | for (unsigned i = 0; i < mrem; ++i) { |
466 | // mask load |
467 | r[i] = _mm256_maskload_epi32( |
468 | reinterpret_cast<const int*>(&src[i * ld_src]), mask_nrem_v); |
469 | if (half == 1) { |
470 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(&local_buffer[0]), r[i]); |
471 | local_buffer[nrem - 1] = src[i * ld_src + nrem - 1]; |
472 | r[i] = _mm256_loadu_si256( |
473 | reinterpret_cast<const __m256i*>(&local_buffer[0])); |
474 | } |
475 | } |
476 | } else { |
477 | for (unsigned i = 0; i < mrem; ++i) { |
478 | // normal load |
479 | r[i] = _mm256_loadu_si256( |
480 | reinterpret_cast<const __m256i*>(src + i * ld_src)); |
481 | } |
482 | } |
483 | } |
484 | |
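// Store nrem (<= 16) transposed rows of mrem (<= 8) uint16_t elements each to
// dst. Row i comes from the low 128-bit half of u[i % 8] for i < 8 and from
// the high half for i >= 8. Full 8-element rows use plain stores; shorter
// rows use a 32-bit mask store plus a scalar fix-up for an odd trailing
// element.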
485 | static inline void store_with_remainders_i16( |
486 | uint16_t* dst, |
487 | int64_t ld_dst, |
488 | __m256i u[], |
489 | unsigned mrem, |
490 | unsigned nrem) { |
491 | if (mrem < 8) { |
492 | uint16_t local_buffer[8] = {0}; |
493 | __m256i mask_mrem_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
494 | internal::avx2_ps_or_epi32_masks[mrem / 2])); |
495 | unsigned half = mrem % 2; |
496 | unsigned i = 0; |
497 | for (; i < nrem; i += 1) { |
498 | // mask store |
499 | int reg_idx = i % 8; |
500 | __m128i d; |
501 | if (i >= 8) { |
502 | d = _mm256_extractf128_si256(u[reg_idx], 1); |
503 | } else { |
504 | d = _mm256_extractf128_si256(u[reg_idx], 0); |
505 | } |
506 | _mm256_maskstore_epi32( |
507 | reinterpret_cast<int*>(dst + i * ld_dst), |
508 | mask_mrem_v, |
509 | _mm256_castsi128_si256(d)); |
510 | if (half == 1) { |
511 | _mm_storeu_si128(reinterpret_cast<__m128i*>(local_buffer), d); |
512 | (dst + i * ld_dst)[mrem - 1] = local_buffer[mrem - 1]; |
513 | } |
514 | } |
515 | |
516 | } else { |
517 | unsigned i = 0; |
518 | for (; i < nrem; i += 1) { |
519 | // normal store |
520 | unsigned reg_idx = i % 8; |
521 | if (i >= 8) { |
522 | _mm_storeu_si128( |
523 | reinterpret_cast<__m128i*>(dst + i * ld_dst), |
524 | _mm256_extractf128_si256(u[reg_idx], 1)); |
525 | } else { |
526 | _mm_storeu_si128( |
527 | reinterpret_cast<__m128i*>(dst + i * ld_dst), |
528 | _mm256_extractf128_si256(u[reg_idx], 0)); |
529 | } |
530 | } |
531 | } |
532 | } |
533 | |
534 | template <bool MREM = false, bool NREM = false> |
535 | inline static void transpose_kernel_8x16_avx2( |
536 | const uint16_t* src, |
537 | int64_t ld_src, |
538 | uint16_t* dst, |
539 | int64_t ld_dst, |
540 | unsigned mrem = 8, |
541 | unsigned nrem = 16) { |
542 | __m256i r[8]; |
543 | // load from src to registers |
544 | // a : a0 a1 a2 a3 a4 a5 a6 a7 ... a15 |
545 | // b : b0 b1 b2 b3 b4 b5 b6 b7 ... b15 |
546 | // c : c0 c1 c2 c3 c4 c5 c6 c7 ... c15 |
547 | // d : d0 d1 d2 d3 d4 d5 d6 d7 ... d15 |
548 | // e : e0 e1 e2 e3 e4 e5 e6 e7 ... e15 |
549 | // f : f0 f1 f2 f3 f4 f5 f6 f7 ... f15 |
550 | // g : g0 g1 g2 g3 g4 g5 g6 g7 ... g15 |
551 | // h : h0 h1 h2 h3 h4 h5 h6 h7 ... h15 |
552 | if (MREM || NREM) { |
553 | load_with_remainders_i16(src, ld_src, r, mrem, nrem); |
554 | } else { |
555 | r[0] = _mm256_loadu_si256( |
556 | reinterpret_cast<const __m256i*>((src) + (0 * ld_src))); |
557 | r[1] = _mm256_loadu_si256( |
558 | reinterpret_cast<const __m256i*>((src) + (1 * ld_src))); |
559 | r[2] = _mm256_loadu_si256( |
560 | reinterpret_cast<const __m256i*>((src) + (2 * ld_src))); |
561 | r[3] = _mm256_loadu_si256( |
562 | reinterpret_cast<const __m256i*>((src) + (3 * ld_src))); |
563 | r[4] = _mm256_loadu_si256( |
564 | reinterpret_cast<const __m256i*>((src) + (4 * ld_src))); |
565 | r[5] = _mm256_loadu_si256( |
566 | reinterpret_cast<const __m256i*>((src) + (5 * ld_src))); |
567 | r[6] = _mm256_loadu_si256( |
568 | reinterpret_cast<const __m256i*>((src) + (6 * ld_src))); |
569 | r[7] = _mm256_loadu_si256( |
570 | reinterpret_cast<const __m256i*>((src) + (7 * ld_src))); |
571 | } |
572 | // t0 : a0a1, b0b1, a2a3, b2b3, |
573 | // a8a9, b8b9, a10a11, b10b11 |
574 | |
575 | // t1 : a4a5, b4b5, a6a7, b6b7, |
576 | // a12a13, b12b13, a14a15, b14b15 |
577 | |
578 | // t2 : c0c1, d0d1, c2c3, d2d3, |
579 | // c8c9, d8d9, c10c11, d10d11 |
580 | |
581 | __m256i __t0 = _mm256_unpacklo_epi32(r[0], r[1]); |
582 | __m256i __t1 = _mm256_unpackhi_epi32(r[0], r[1]); |
583 | __m256i __t2 = _mm256_unpacklo_epi32(r[2], r[3]); |
584 | __m256i __t3 = _mm256_unpackhi_epi32(r[2], r[3]); |
585 | __m256i __t4 = _mm256_unpacklo_epi32(r[4], r[5]); |
586 | __m256i __t5 = _mm256_unpackhi_epi32(r[4], r[5]); |
587 | __m256i __t6 = _mm256_unpacklo_epi32(r[6], r[7]); |
588 | __m256i __t7 = _mm256_unpackhi_epi32(r[6], r[7]); |
589 | |
590 | // tt0: a0a1, b0b1, c0c1, d0d1, |
591 | // a9a9, b8b9, c8c9, d8d9 |
592 | |
593 | // tt1: a2a3, b2b3, c2c3, d2d3, |
594 | // a10a11, b10b11, c10c11, d10d11 |
595 | |
596 | // tt2: a4a5, b4b5, c4c5, d4d5, |
597 | // a12a13, b12b13, c12c13, d12d13 |
598 | |
599 | // tt3: a6a7, b6b7, c6c7, d6d7, |
600 | // a14a15, b14b15, c14c15, d14d15 |
601 | |
602 | // tt4: e0e1, f0f1, g0g1, h0h1, |
603 | // e9e9, f8f9, g8g9, h8h9 |
604 | __m256i __tt0 = _mm256_unpacklo_epi64(__t0, __t2); |
605 | __m256i __tt1 = _mm256_unpackhi_epi64(__t0, __t2); |
606 | __m256i __tt2 = _mm256_unpacklo_epi64(__t1, __t3); |
607 | __m256i __tt3 = _mm256_unpackhi_epi64(__t1, __t3); |
608 | __m256i __tt4 = _mm256_unpacklo_epi64(__t4, __t6); |
609 | __m256i __tt5 = _mm256_unpackhi_epi64(__t4, __t6); |
610 | __m256i __tt6 = _mm256_unpacklo_epi64(__t5, __t7); |
611 | __m256i __tt7 = _mm256_unpackhi_epi64(__t5, __t7); |
612 | |
613 | // t0: a0b0, a1b1, c0c1, d0d1, |
614 | // a8b8, a9b9, c8c9, d8d9 |
615 | __t0 = _mm256_shufflelo_epi16(__tt0, 0xD8); |
616 | __t1 = _mm256_shufflelo_epi16(__tt1, 0xD8); |
617 | __t2 = _mm256_shufflelo_epi16(__tt2, 0xD8); |
618 | __t3 = _mm256_shufflelo_epi16(__tt3, 0xD8); |
619 | __t4 = _mm256_shufflelo_epi16(__tt4, 0xD8); |
620 | __t5 = _mm256_shufflelo_epi16(__tt5, 0xD8); |
621 | __t6 = _mm256_shufflelo_epi16(__tt6, 0xD8); |
622 | __t7 = _mm256_shufflelo_epi16(__tt7, 0xD8); |
623 | |
624 | // tt0: a0b0, a1b1, c0d0, c1d1, |
625 | // a8b8, a9b9, c8d8, c9d9 |
626 | __tt0 = _mm256_shufflehi_epi16(__t0, 0xD8); |
627 | __tt1 = _mm256_shufflehi_epi16(__t1, 0xD8); |
628 | __tt2 = _mm256_shufflehi_epi16(__t2, 0xD8); |
629 | __tt3 = _mm256_shufflehi_epi16(__t3, 0xD8); |
630 | __tt4 = _mm256_shufflehi_epi16(__t4, 0xD8); |
631 | __tt5 = _mm256_shufflehi_epi16(__t5, 0xD8); |
632 | __tt6 = _mm256_shufflehi_epi16(__t6, 0xD8); |
633 | __tt7 = _mm256_shufflehi_epi16(__t7, 0xD8); |
634 | |
635 | // t0: a0b0, c0d0, a1b1, c1d1, |
636 | // a8b8, c8d8, a9b9, c9d9 |
637 | __t0 = _mm256_shuffle_epi32(__tt0, 0xD8); |
638 | __t1 = _mm256_shuffle_epi32(__tt1, 0xD8); |
639 | __t2 = _mm256_shuffle_epi32(__tt2, 0xD8); |
640 | __t3 = _mm256_shuffle_epi32(__tt3, 0xD8); |
641 | // t4: e0f0, g0h0, e1f1, g1h1, |
642 | // e8f8, g8h8, e9f9, g9h9 |
643 | __t4 = _mm256_shuffle_epi32(__tt4, 0xD8); |
644 | __t5 = _mm256_shuffle_epi32(__tt5, 0xD8); |
645 | __t6 = _mm256_shuffle_epi32(__tt6, 0xD8); |
646 | __t7 = _mm256_shuffle_epi32(__tt7, 0xD8); |
647 | |
648 | // r0: a0b0, c0d0, e0f0, g0h0, |
649 | // a8b8, c8d8, e8f8, g8h8 |
650 | r[0] = _mm256_unpacklo_epi64(__t0, __t4); // 0, 8 |
651 | // r1: a1b1, c1d1, e1f1, g1h1, |
652 | // a9b9, c9d9, e9f9, g9h9 |
653 | r[1] = _mm256_unpackhi_epi64(__t0, __t4); // 1, 9 |
654 | r[2] = _mm256_unpacklo_epi64(__t1, __t5); // 2, 10 |
655 | r[3] = _mm256_unpackhi_epi64(__t1, __t5); // 3, 11 |
656 | r[4] = _mm256_unpacklo_epi64(__t2, __t6); // 4, 12 |
657 | r[5] = _mm256_unpackhi_epi64(__t2, __t6); // 5, 13 |
658 | r[6] = _mm256_unpacklo_epi64(__t3, __t7); // 6, 14 |
659 | r[7] = _mm256_unpackhi_epi64(__t3, __t7); // 7, 15 |
660 | |
  // store the 16 transposed rows (8 uint16_t elements each) back to dst:
662 | if (MREM || NREM) { |
663 | store_with_remainders_i16(dst, ld_dst, r, mrem, nrem); |
664 | } else { |
665 | _mm_storeu_si128( |
666 | reinterpret_cast<__m128i*>(dst), _mm256_extractf128_si256(r[0], 0)); |
667 | _mm_storeu_si128( |
668 | reinterpret_cast<__m128i*>((dst) + ld_dst), |
669 | _mm256_extractf128_si256(r[1], 0)); |
670 | _mm_storeu_si128( |
671 | reinterpret_cast<__m128i*>((dst) + ld_dst * 2), |
672 | _mm256_extractf128_si256(r[2], 0)); |
673 | _mm_storeu_si128( |
674 | reinterpret_cast<__m128i*>((dst) + ld_dst * 3), |
675 | _mm256_extractf128_si256(r[3], 0)); |
676 | _mm_storeu_si128( |
677 | reinterpret_cast<__m128i*>((dst) + ld_dst * 4), |
678 | _mm256_extractf128_si256(r[4], 0)); |
679 | _mm_storeu_si128( |
680 | reinterpret_cast<__m128i*>((dst) + ld_dst * 5), |
681 | _mm256_extractf128_si256(r[5], 0)); |
682 | _mm_storeu_si128( |
683 | reinterpret_cast<__m128i*>((dst) + ld_dst * 6), |
684 | _mm256_extractf128_si256(r[6], 0)); |
685 | _mm_storeu_si128( |
686 | reinterpret_cast<__m128i*>((dst) + ld_dst * 7), |
687 | _mm256_extractf128_si256(r[7], 0)); |
688 | |
689 | _mm_storeu_si128( |
690 | reinterpret_cast<__m128i*>((dst) + ld_dst * 8), |
691 | _mm256_extractf128_si256(r[0], 1)); |
692 | _mm_storeu_si128( |
693 | reinterpret_cast<__m128i*>((dst) + ld_dst * 9), |
694 | _mm256_extractf128_si256(r[1], 1)); |
695 | _mm_storeu_si128( |
696 | reinterpret_cast<__m128i*>((dst) + ld_dst * 10), |
697 | _mm256_extractf128_si256(r[2], 1)); |
698 | _mm_storeu_si128( |
699 | reinterpret_cast<__m128i*>((dst) + ld_dst * 11), |
700 | _mm256_extractf128_si256(r[3], 1)); |
701 | _mm_storeu_si128( |
702 | reinterpret_cast<__m128i*>((dst) + ld_dst * 12), |
703 | _mm256_extractf128_si256(r[4], 1)); |
704 | _mm_storeu_si128( |
705 | reinterpret_cast<__m128i*>((dst) + ld_dst * 13), |
706 | _mm256_extractf128_si256(r[5], 1)); |
707 | _mm_storeu_si128( |
708 | reinterpret_cast<__m128i*>((dst) + ld_dst * 14), |
709 | _mm256_extractf128_si256(r[6], 1)); |
710 | _mm_storeu_si128( |
711 | reinterpret_cast<__m128i*>((dst) + ld_dst * 15), |
712 | _mm256_extractf128_si256(r[7], 1)); |
713 | } |
714 | } |
715 | |
// kernel for transposing an mxn block of uint8_t where m <= 8 and n <= 32
718 | template <unsigned M> |
719 | static void transpose_kernel_mxn_avx2_uint8( |
720 | unsigned N, |
721 | const uint8_t* src, |
722 | int64_t ld_src, |
723 | uint8_t* dst, |
724 | int64_t ld_dst) { |
725 | // load from src to registers |
726 | // first load masks |
727 | __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
728 | internal::avx2_ps_or_epi32_masks[N / 4])); |
729 | |
730 | __m256i input[8]; |
731 | unsigned i, j; |
732 | for (i = 0; i < M; ++i) { |
733 | uint8_t local_buffer[32] = {0}; |
734 | |
735 | // first load into local buffer with mask |
736 | input[i] = _mm256_maskload_epi32( |
737 | reinterpret_cast<const int*>(src + i * ld_src), mask_v); |
738 | |
739 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(&local_buffer[0]), input[i]); |
740 | |
741 | // fill in the local buffer with the remainder elements |
742 | for (j = N / 4 * 4; j < N; j++) |
743 | local_buffer[j] = src[i * ld_src + j]; |
744 | |
745 | // from local buffer to input registers |
746 | input[i] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(&local_buffer[0])); |
747 | } |
748 | |
749 | // for (; i < 8; ++i) { |
750 | // input[i] = _mm256_setzero_si256(); |
751 | //} |
752 | |
753 | // interleaving 8-bit elements |
754 | // e.g., temp[0] now becomes: a0 b0 a1 b1 a2 b2 ... |
755 | __m256i temp[8]; |
756 | for (i = 0; i < (M + 1) / 2; ++i) { |
757 | temp[2 * i] = _mm256_unpacklo_epi8(input[2 * i], input[2 * i + 1]); |
758 | temp[2 * i + 1] = _mm256_unpackhi_epi8(input[2 * i], input[2 * i + 1]); |
759 | } |
760 | for (i = i * 2; i < 8; ++i) { |
761 | temp[i] = _mm256_setzero_si256(); |
762 | } |
763 | |
764 | // interleaving 16-bit elements |
  // e.g., input[0] now becomes: a0 b0 c0 d0 a1 b1 c1 d1 ...
766 | for (i = 0; i < (M + 3) / 4; ++i) { |
767 | input[4 * i] = _mm256_unpacklo_epi16(temp[i * 4], temp[i * 4 + 2]); |
768 | input[4 * i + 1] = _mm256_unpackhi_epi16(temp[i * 4], temp[i * 4 + 2]); |
769 | input[4 * i + 2] = _mm256_unpacklo_epi16(temp[i * 4 + 1], temp[i * 4 + 3]); |
770 | input[4 * i + 3] = _mm256_unpackhi_epi16(temp[i * 4 + 1], temp[i * 4 + 3]); |
771 | } |
772 | |
773 | // interleaving 32-bit elements |
774 | // e.g., temp[0] now becomes a0 b0 c0 d0 e0 f0 g0 h0 ... |
775 | for (i = 0; i < 4 /*(M + 1) / 2*/; ++i) { |
776 | temp[2 * i] = _mm256_unpacklo_epi32(input[i], input[(i + 4)]); |
777 | temp[2 * i + 1] = _mm256_unpackhi_epi32(input[i], input[(i + 4)]); |
778 | } |
779 | |
  // retrieve the final result by extracting 64-bit chunks;
  // e.g., the 256-bit temp[0] contains
  //   bits 0-63:    a0 -- h0,
  //   bits 64-127:  a1 -- h1,
  //   bits 128-191: a16 -- h16,
  //   bits 192-255: a17 -- h17
786 | uint64_t t; |
787 | mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
788 | internal::avx2_ps_or_epi32_masks[M / 4])); |
789 | for (i = 0; i < N; ++i) { |
790 | if (i < 16) { |
791 | if (i % 2 == 0) |
792 | t = _mm256_extract_epi64(temp[i / 2], 0); |
793 | else |
794 | t = _mm256_extract_epi64(temp[i / 2], 1); |
795 | |
796 | } else { |
797 | if (i % 2 == 0) |
798 | t = _mm256_extract_epi64(temp[(i - 16) / 2], 2); |
799 | else |
800 | t = _mm256_extract_epi64(temp[(i - 16) / 2], 3); |
801 | } |
802 | __m256i t_vec = _mm256_set_epi64x(0, 0, 0, t); |
803 | _mm256_maskstore_epi32( |
804 | reinterpret_cast<int*>(dst + i * ld_dst), mask_v, t_vec); |
805 | for (j = M / 4 * 4; j < M; j++) { |
806 | dst[ld_dst * i + j] = ((t >> (8 * j)) & 255); |
807 | } |
808 | } |
809 | } |
810 | |
811 | #endif // __AVX2__ |
812 | |
813 | } // namespace internal |
814 | |
815 | } // namespace fbgemm |
816 | |