1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #if defined(__x86_64__) || defined(__i386__) || \ |
8 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
9 | #include <immintrin.h> |
10 | #endif |
11 | #include "./TransposeUtils.h" |
12 | #include "./TransposeUtilsAvx2.h" |
13 | |
14 | namespace fbgemm { |
15 | |
16 | namespace internal { |
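
// Every transpose_avx2 specialization in this file computes the plain
// out-of-place transpose of an M x N row-major matrix. A scalar sketch of
// the same semantics (illustrative only, not part of this file's API):
//
//   for (int64_t i = 0; i < M; ++i) {
//     for (int64_t j = 0; j < N; ++j) {
//       dst[j * ld_dst + i] = src[i * ld_src + j];
//     }
//   }
//
// The kernels below produce this result in 8-row tiles, 8/16/32 columns at
// a time depending on the element width.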
17 | |
18 | template <> |
19 | void transpose_avx2( |
20 | int64_t M, |
21 | int64_t N, |
22 | const float* src, |
23 | int64_t ld_src, |
24 | float* dst, |
25 | int64_t ld_dst) { |
26 | int64_t ib = 0, jb = 0; |
27 | if (N % 8 > 0 && N % 8 < 4) { |
28 | // If the remainder has n < 4 columns, we use the SSE kernel for the |
    // remainder because it requires 2 * (2 * 4 + 2 * n) = 16 + 4n instructions
    // instead of 3 * 8 + 2 * n = 24 + 2n instructions in the masked AVX2
    // kernel.
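    // Spelled out per remainder width n (checking the counts above):
    //   n = 1: SSE 16 + 4 = 20   vs  masked AVX2 24 + 2 = 26
    //   n = 2: SSE 16 + 8 = 24   vs  masked AVX2 24 + 4 = 28
    //   n = 3: SSE 16 + 12 = 28  vs  masked AVX2 24 + 6 = 30
    // so the SSE path wins for every n < 4.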
32 | for (ib = 0; ib + 8 <= M; ib += 8) { |
33 | for (jb = 0; jb + 8 <= N; jb += 8) { |
34 | transpose_kernel_8x8_avx2( |
35 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
36 | } |
37 | for (int64_t i = ib; i < ib + 8; i += 4) { |
38 | transpose_kernel_mxn_sse<4>( |
39 | N - jb, |
40 | &src[i * ld_src + jb], |
41 | ld_src, |
42 | &dst[i + jb * ld_dst], |
43 | ld_dst); |
44 | } |
45 | } |
46 | } else if (N % 8 == 4) { |
    // If the remainder has exactly 4 columns, we use the SSE kernel for the
    // remainder because it requires 2 * 16 = 32 instructions, whereas the
    // masked AVX2 kernel needs 3 * 8 + 2 * 4 = 32 instructions plus looping
    // overhead.
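    // Example: with N = 12, the 8x8 AVX2 kernel covers columns 0..7 and the
    // 4x4 SSE kernel runs twice per 8-row block (rows ib..ib+3 and
    // ib+4..ib+7) to cover columns 8..11.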
50 | for (ib = 0; ib + 8 <= M; ib += 8) { |
51 | for (jb = 0; jb + 8 <= N; jb += 8) { |
52 | transpose_kernel_8x8_avx2( |
53 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
54 | } |
55 | for (int64_t i = ib; i < ib + 8; i += 4) { |
56 | transpose_kernel_4x4_sse( |
57 | &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst); |
58 | } |
59 | } |
  } else {
    // Remainder of 0 or 5 to 7 columns: the masked AVX2 kernel handles the
    // whole remainder in a single call per 8-row block.
61 | for (ib = 0; ib + 8 <= M; ib += 8) { |
62 | for (jb = 0; jb + 8 <= N; jb += 8) { |
63 | transpose_kernel_8x8_avx2( |
64 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
65 | } |
66 | if (jb < N) { |
67 | transpose_kernel_mxn_avx2<8>( |
68 | N - jb, |
69 | &src[ib * ld_src + jb], |
70 | ld_src, |
71 | &dst[ib + jb * ld_dst], |
72 | ld_dst); |
73 | } |
74 | } |
75 | } |
76 | |
  // Specialization for small M - ib cases so that the compiler can inline
  // transpose_kernel_mxn_avx2 and unroll the loops whose iteration count
  // depends on M - ib.
  // Specializing on m helps more than on n in transpose_kernel_mxn_avx2
  // because that function has more loops whose iteration count depends
  // on m.
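  // For example, with M = 13 the blocked loops above stop at ib = 8, and the
  // five leftover rows dispatch to the transpose_kernel_mxn_avx2<5> case
  // below.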
83 | switch (M - ib) { |
84 | case 1: |
85 | for (int64_t j = 0; j < N; ++j) { |
86 | dst[ib + j * ld_dst] = src[ib * ld_src + j]; |
87 | } |
88 | break; |
89 | case 2: |
90 | for (jb = 0; jb + 4 <= N; jb += 4) { |
91 | transpose_kernel_mxn_sse<2>( |
92 | 4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
93 | } |
94 | if (jb < N) { |
95 | transpose_kernel_mxn_sse<2>( |
96 | N - jb, |
97 | &src[ib * ld_src + jb], |
98 | ld_src, |
99 | &dst[ib + jb * ld_dst], |
100 | ld_dst); |
101 | } |
102 | break; |
103 | case 3: |
104 | for (jb = 0; jb + 4 <= N; jb += 4) { |
105 | transpose_kernel_mxn_sse<3>( |
106 | 4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
107 | } |
108 | if (jb < N) { |
109 | transpose_kernel_mxn_sse<3>( |
110 | N - jb, |
111 | &src[ib * ld_src + jb], |
112 | ld_src, |
113 | &dst[ib + jb * ld_dst], |
114 | ld_dst); |
115 | } |
116 | break; |
117 | case 4: |
118 | for (jb = 0; jb + 4 <= N; jb += 4) { |
119 | transpose_kernel_4x4_sse( |
120 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
121 | } |
122 | if (jb < N) { |
123 | transpose_kernel_mxn_sse<4>( |
124 | N - jb, |
125 | &src[ib * ld_src + jb], |
126 | ld_src, |
127 | &dst[ib + jb * ld_dst], |
128 | ld_dst); |
129 | } |
130 | break; |
131 | case 5: |
132 | for (jb = 0; jb + 8 <= N; jb += 8) { |
133 | transpose_kernel_mxn_avx2<5>( |
134 | 8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
135 | } |
136 | if (jb < N) { |
137 | transpose_kernel_mxn_avx2<5>( |
138 | N - jb, |
139 | &src[ib * ld_src + jb], |
140 | ld_src, |
141 | &dst[ib + jb * ld_dst], |
142 | ld_dst); |
143 | } |
144 | break; |
145 | case 6: |
146 | for (jb = 0; jb + 8 <= N; jb += 8) { |
147 | transpose_kernel_mxn_avx2<6>( |
148 | 8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
149 | } |
150 | if (jb < N) { |
151 | transpose_kernel_mxn_avx2<6>( |
152 | N - jb, |
153 | &src[ib * ld_src + jb], |
154 | ld_src, |
155 | &dst[ib + jb * ld_dst], |
156 | ld_dst); |
157 | } |
158 | break; |
159 | case 7: |
160 | for (jb = 0; jb + 8 <= N; jb += 8) { |
161 | transpose_kernel_mxn_avx2<7>( |
162 | 8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
163 | } |
164 | if (jb < N) { |
165 | transpose_kernel_mxn_avx2<7>( |
166 | N - jb, |
167 | &src[ib * ld_src + jb], |
168 | ld_src, |
169 | &dst[ib + jb * ld_dst], |
170 | ld_dst); |
171 | } |
172 | break; |
173 | } |
174 | } |
175 | |
176 | template <> |
177 | void transpose_avx2( |
178 | int64_t M, |
179 | int64_t N, |
180 | const uint8_t* src, |
181 | int64_t ld_src, |
182 | uint8_t* dst, |
183 | int64_t ld_dst) { |
184 | int64_t ib = 0, jb = 0; |
185 | if (M >= 8) { |
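    // Each 8-row block is transposed 32 columns at a time. For example, with
    // N = 70, two 8x32 tiles cover columns 0..63 and the masked uint8 kernel
    // covers the remaining 6 columns.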
186 | for (ib = 0; ib + 8 <= M; ib += 8) { |
187 | for (jb = 0; jb + 32 <= N; jb += 32) { |
188 | transpose_kernel_8x32_avx2( |
189 | &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
190 | } |
191 | |
192 | if (jb < N) { |
193 | transpose_kernel_mxn_avx2_uint8<8>( |
194 | N - jb, |
195 | &src[ib * ld_src + jb], |
196 | ld_src, |
197 | &dst[ib + jb * ld_dst], |
198 | ld_dst); |
199 | } |
200 | } |
201 | } |
202 | |
  // Specialization for small M - ib cases, for the same inlining and
  // unrolling reasons as in the float version above.
204 | switch (M - ib) { |
205 | case 1: |
206 | for (jb = 0; jb + 32 <= N; jb += 32) { |
207 | transpose_kernel_mxn_avx2_uint8<1>( |
208 | 32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
209 | } |
210 | |
211 | if (jb < N) |
212 | transpose_kernel_mxn_avx2_uint8<1>( |
213 | N - jb, |
214 | &src[ib * ld_src + jb], |
215 | ld_src, |
216 | &dst[ib + jb * ld_dst], |
217 | ld_dst); |
218 | |
219 | break; |
220 | case 2: |
221 | for (jb = 0; jb + 32 <= N; jb += 32) { |
222 | transpose_kernel_mxn_avx2_uint8<2>( |
223 | 32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
224 | } |
225 | if (jb < N) |
226 | transpose_kernel_mxn_avx2_uint8<2>( |
227 | N - jb, |
228 | &src[ib * ld_src + jb], |
229 | ld_src, |
230 | &dst[ib + jb * ld_dst], |
231 | ld_dst); |
232 | break; |
233 | case 3: |
234 | for (jb = 0; jb + 32 <= N; jb += 32) { |
235 | transpose_kernel_mxn_avx2_uint8<3>( |
236 | 32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
237 | } |
238 | if (jb < N) |
239 | transpose_kernel_mxn_avx2_uint8<3>( |
240 | N - jb, |
241 | &src[ib * ld_src + jb], |
242 | ld_src, |
243 | &dst[ib + jb * ld_dst], |
244 | ld_dst); |
245 | break; |
246 | case 4: |
247 | for (jb = 0; jb + 32 <= N; jb += 32) { |
248 | transpose_kernel_mxn_avx2_uint8<4>( |
249 | 32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
250 | } |
251 | if (jb < N) |
252 | transpose_kernel_mxn_avx2_uint8<4>( |
253 | N - jb, |
254 | &src[ib * ld_src + jb], |
255 | ld_src, |
256 | &dst[ib + jb * ld_dst], |
257 | ld_dst); |
258 | break; |
259 | case 5: |
260 | for (jb = 0; jb + 32 <= N; jb += 32) { |
261 | transpose_kernel_mxn_avx2_uint8<5>( |
262 | 32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
263 | } |
264 | if (jb < N) |
265 | transpose_kernel_mxn_avx2_uint8<5>( |
266 | N - jb, |
267 | &src[ib * ld_src + jb], |
268 | ld_src, |
269 | &dst[ib + jb * ld_dst], |
270 | ld_dst); |
271 | break; |
272 | case 6: |
273 | for (jb = 0; jb + 32 <= N; jb += 32) { |
274 | transpose_kernel_mxn_avx2_uint8<6>( |
275 | 32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
276 | } |
277 | if (jb < N) |
278 | transpose_kernel_mxn_avx2_uint8<6>( |
279 | N - jb, |
280 | &src[ib * ld_src + jb], |
281 | ld_src, |
282 | &dst[ib + jb * ld_dst], |
283 | ld_dst); |
284 | break; |
285 | case 7: |
286 | for (jb = 0; jb + 32 <= N; jb += 32) { |
287 | transpose_kernel_mxn_avx2_uint8<7>( |
288 | 32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); |
289 | } |
290 | if (jb < N) |
291 | transpose_kernel_mxn_avx2_uint8<7>( |
292 | N - jb, |
293 | &src[ib * ld_src + jb], |
294 | ld_src, |
295 | &dst[ib + jb * ld_dst], |
296 | ld_dst); |
297 | break; |
298 | } |
299 | } |
300 | |
301 | template <> |
302 | void transpose_avx2( |
303 | int64_t M, |
304 | int64_t N, |
305 | const uint16_t* src, |
306 | int64_t ld_src, |
307 | uint16_t* dst, |
308 | int64_t ld_dst) { |
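  // The two boolean template parameters of transpose_kernel_8x16_avx2 appear
  // to select masked handling of a partial row block and a partial column
  // block, respectively, with the trailing mrem/nrem arguments giving the
  // actual counts; this reading is inferred from the call sites below.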
309 | int64_t i = 0; |
310 | for (; i < M / 8 * 8; i += 8) { |
311 | int64_t j = 0; |
312 | for (; j < N / 16 * 16; j += 16) { |
313 | transpose_kernel_8x16_avx2<false, false>( |
314 | src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst); |
315 | } |
    // Handle the remaining N - j columns with the column-masked kernel.
317 | unsigned nrem = N - j; |
318 | if (nrem > 0) { |
319 | transpose_kernel_8x16_avx2<false, true>( |
320 | src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 8, nrem); |
321 | } |
322 | } |
323 | |
  // Handle the remaining M - i rows with the row-masked kernel.
325 | unsigned mrem = M - i; |
326 | if (mrem > 0) { |
327 | int64_t j = 0; |
328 | for (; j < N / 16 * 16; j += 16) { |
329 | transpose_kernel_8x16_avx2<true, false>( |
330 | src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16); |
331 | } |
    // Handle the remaining columns of the row-remainder block; guard against
    // nrem == 0 for consistency with the column tail above.
    unsigned nrem = N - j;
    if (nrem > 0) {
      transpose_kernel_8x16_avx2<true, true>(
          src + i * ld_src + j,
          ld_src,
          dst + j * ld_dst + i,
          ld_dst,
          mrem,
          nrem);
    }
336 | } |
337 | } |
338 | |
339 | } // namespace internal |
340 | |
341 | } // namespace fbgemm |
342 | |