/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif

#include "./TransposeUtils.h"
#include "./TransposeUtilsAvx2.h"

namespace fbgemm {

namespace internal {

template <>
void transpose_avx2(
    int64_t M,
    int64_t N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  int64_t ib = 0, jb = 0;
  if (N % 8 > 0 && N % 8 < 4) {
    // If the remainder has n < 4 columns, we use the SSE kernel for the
    // remainder because it requires 2 * (2 * 4 + 2 * N) = 16 + 4N instructions
    // instead of 3 * 8 + 2 * N = 24 + 2N instructions in the masked AVX2
    // kernel.
    for (ib = 0; ib + 8 <= M; ib += 8) {
      for (jb = 0; jb + 8 <= N; jb += 8) {
        transpose_kernel_8x8_avx2(
            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      for (int64_t i = ib; i < ib + 8; i += 4) {
        transpose_kernel_mxn_sse<4>(
            N - jb,
            &src[i * ld_src + jb],
            ld_src,
            &dst[i + jb * ld_dst],
            ld_dst);
      }
    }
  } else if (N % 8 == 4) {
    // If the remainder has 4 columns, we use the SSE kernel for the remainder
    // because it requires 2 * 16 = 32 instructions instead of 3 * 8 + 2 * 4 =
    // 32 instructions + looping overhead needed in the masked AVX2 kernel.
    for (ib = 0; ib + 8 <= M; ib += 8) {
      for (jb = 0; jb + 8 <= N; jb += 8) {
        transpose_kernel_8x8_avx2(
            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      for (int64_t i = ib; i < ib + 8; i += 4) {
        transpose_kernel_4x4_sse(
            &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
      }
    }
  } else {
    for (ib = 0; ib + 8 <= M; ib += 8) {
      for (jb = 0; jb + 8 <= N; jb += 8) {
        transpose_kernel_8x8_avx2(
            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2<8>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
    }
  }

  // Specialization for small M - ib cases so that the compiler can inline
  // transpose_kernel_mxn_avx2 and unroll the loops whose iteration count
  // depends on M - ib.
  // Specialization for m helps more than for n in transpose_kernel_mxn_avx2
  // because we have more loops in that function whose iteration count depends
  // on m.
  switch (M - ib) {
    case 1:
      for (int64_t j = 0; j < N; ++j) {
        dst[ib + j * ld_dst] = src[ib * ld_src + j];
      }
      break;
    case 2:
      for (jb = 0; jb + 4 <= N; jb += 4) {
        transpose_kernel_mxn_sse<2>(
            4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_sse<2>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 3:
      for (jb = 0; jb + 4 <= N; jb += 4) {
        transpose_kernel_mxn_sse<3>(
            4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_sse<3>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 4:
      for (jb = 0; jb + 4 <= N; jb += 4) {
        transpose_kernel_4x4_sse(
            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_sse<4>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 5:
      for (jb = 0; jb + 8 <= N; jb += 8) {
        transpose_kernel_mxn_avx2<5>(
            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2<5>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 6:
      for (jb = 0; jb + 8 <= N; jb += 8) {
        transpose_kernel_mxn_avx2<6>(
            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2<6>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 7:
      for (jb = 0; jb + 8 <= N; jb += 8) {
        transpose_kernel_mxn_avx2<7>(
            8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2<7>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
  }
}
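
// Illustrative usage sketch (added for exposition; not part of the FBGEMM API
// or its tests, and the function name is hypothetical). With M = 11 and
// N = 13, the 8x8 AVX2 kernel covers the top-left 8x8 block, the column-masked
// mxn AVX2 kernel covers the remaining 5 columns of rows 0-7, and the case-3
// branch of the switch above covers the last 3 rows.
[[maybe_unused]] static void transpose_avx2_float_usage_sketch() {
  constexpr int64_t kM = 11;
  constexpr int64_t kN = 13;
  float src[kM * kN] = {}; // row-major M x N, leading dimension N
  float dst[kN * kM] = {}; // row-major N x M, leading dimension M
  transpose_avx2(kM, kN, src, kN, dst, kM);
}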

template <>
void transpose_avx2(
    int64_t M,
    int64_t N,
    const uint8_t* src,
    int64_t ld_src,
    uint8_t* dst,
    int64_t ld_dst) {
  int64_t ib = 0, jb = 0;
  if (M >= 8) {
    for (ib = 0; ib + 8 <= M; ib += 8) {
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_8x32_avx2(
            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }

      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<8>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
    }
  }

  // Specialization for small M - ib cases
  switch (M - ib) {
    case 1:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<1>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<1>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 2:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<2>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<2>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 3:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<3>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<3>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 4:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<4>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<4>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 5:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<5>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<5>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 6:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<6>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<6>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
    case 7:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<7>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<7>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
      break;
  }
}
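
// Scalar reference sketch (added for exposition; not part of the FBGEMM API,
// and the name is hypothetical). Every transpose_avx2 specialization in this
// file implements this postcondition: src is read as an M x N matrix with
// leading dimension ld_src, and dst is written as its N x M transpose with
// leading dimension ld_dst.
template <typename T>
static void transpose_ref_sketch(
    int64_t M,
    int64_t N,
    const T* src,
    int64_t ld_src,
    T* dst,
    int64_t ld_dst) {
  for (int64_t i = 0; i < M; ++i) {
    for (int64_t j = 0; j < N; ++j) {
      // Element (i, j) of src becomes element (j, i) of dst.
      dst[i + j * ld_dst] = src[i * ld_src + j];
    }
  }
}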

template <>
void transpose_avx2(
    int64_t M,
    int64_t N,
    const uint16_t* src,
    int64_t ld_src,
    uint16_t* dst,
    int64_t ld_dst) {
  int64_t i = 0;
  for (; i < M / 8 * 8; i += 8) {
    int64_t j = 0;
    for (; j < N / 16 * 16; j += 16) {
      transpose_kernel_8x16_avx2<false, false>(
          src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst);
    }
    // handle j rem
    unsigned nrem = N - j;
    if (nrem > 0) {
      transpose_kernel_8x16_avx2<false, true>(
          src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 8, nrem);
    }
  }

  // handle i rem
  unsigned mrem = M - i;
  if (mrem > 0) {
    int64_t j = 0;
    for (; j < N / 16 * 16; j += 16) {
      transpose_kernel_8x16_avx2<true, false>(
          src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16);
    }
    // handle j rem; skip the masked call when N is a multiple of 16, matching
    // the guarded column-remainder handling in the main loop above.
    unsigned nrem = N - j;
    if (nrem > 0) {
      transpose_kernel_8x16_avx2<true, true>(
          src + i * ld_src + j,
          ld_src,
          dst + j * ld_dst + i,
          ld_dst,
          mrem,
          nrem);
    }
  }
}
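
// Illustrative usage sketch for the uint16_t path (added for exposition; not
// part of the FBGEMM API or its tests, and the name is hypothetical). With
// M = 10 and N = 20, the 8x16 kernel covers rows 0-7 and columns 0-15, the
// column-masked variant covers the remaining 4 columns of those rows, and the
// row-masked variants cover the last 2 rows (mrem = 2).
[[maybe_unused]] static void transpose_avx2_uint16_usage_sketch() {
  constexpr int64_t kM = 10;
  constexpr int64_t kN = 20;
  uint16_t src[kM * kN] = {}; // row-major M x N, leading dimension N
  uint16_t dst[kN * kM] = {}; // row-major N x M, leading dimension M
  transpose_avx2(kM, kN, src, kN, dst, kM);
}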

} // namespace internal

} // namespace fbgemm