1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "./TransposeUtils.h" |
9 | #include <cstring> |
10 | #include "fbgemm/Utils.h" |
11 | |
12 | namespace fbgemm { |
13 | |
/**
 * @brief Scalar reference transpose of a strided 2-D array.
 *
 * Reads src as an M x N row-major matrix whose consecutive rows are ld_src
 * elements apart, and writes its transpose into dst as an N x M row-major
 * matrix whose consecutive rows are ld_dst elements apart. In index form:
 * dst[j * ld_dst + i] = src[i * ld_src + j] for 0 <= i < M and 0 <= j < N.
 *
 * @param M      number of rows read from src (columns written to dst)
 * @param N      number of columns read from src (rows written to dst)
 * @param src    input elements
 * @param ld_src element distance between consecutive src rows
 * @param dst    output elements
 * @param ld_dst element distance between consecutive dst rows
 */
template <typename T>
void transpose_ref(
    int64_t M,
    int64_t N,
    const T* src,
    int64_t ld_src,
    T* dst,
    int64_t ld_dst) {
  // Walk the output in row-major order: one output row per source column.
  for (int64_t out_row = 0; out_row < N; ++out_row) {
    for (int64_t out_col = 0; out_col < M; ++out_col) {
      const int64_t src_idx = out_col * ld_src + out_row;
      const int64_t dst_idx = out_col + out_row * ld_dst;
      dst[dst_idx] = src[src_idx];
    }
  }
}
28 | |
29 | template <typename T> |
30 | void transpose_simd( |
31 | int64_t M, |
32 | int64_t N, |
33 | const T* src, |
34 | int64_t ld_src, |
35 | T* dst, |
36 | int64_t ld_dst) { |
37 | if (M == 0 || N == 0) { |
38 | return; |
39 | } |
40 | if ((M == 1 && ld_dst == 1) || (N == 1 && ld_src == 1)) { |
41 | if (dst != src) { |
42 | // sizeof must be first operand force dims promotion to OS-bitness type |
43 | memcpy(dst, src, sizeof(T) * M * N); |
44 | } |
45 | return; |
46 | } |
47 | static const auto iset = fbgemmInstructionSet(); |
48 | // Run time CPU detection |
49 | if (isZmm(iset)) { |
50 | internal::transpose_avx512<T>(M, N, src, ld_src, dst, ld_dst); |
51 | } else if (isYmm(iset)) { |
52 | internal::transpose_avx2<T>(M, N, src, ld_src, dst, ld_dst); |
53 | } else { |
54 | transpose_ref<T>(M, N, src, ld_src, dst, ld_dst); |
55 | } |
56 | } |
57 | |
58 | template void transpose_ref<float>( |
59 | int64_t M, |
60 | int64_t N, |
61 | const float* src, |
62 | int64_t ld_src, |
63 | float* dst, |
64 | int64_t ld_dst); |
65 | |
66 | template void transpose_ref<uint16_t>( |
67 | int64_t M, |
68 | int64_t N, |
69 | const uint16_t* src, |
70 | int64_t ld_src, |
71 | uint16_t* dst, |
72 | int64_t ld_dst); |
73 | |
74 | template void transpose_ref<uint8_t>( |
75 | int64_t M, |
76 | int64_t N, |
77 | const uint8_t* src, |
78 | int64_t ld_src, |
79 | uint8_t* dst, |
80 | int64_t ld_dst); |
81 | |
82 | template FBGEMM_API void transpose_simd<float>( |
83 | int64_t M, |
84 | int64_t N, |
85 | const float* src, |
86 | int64_t ld_src, |
87 | float* dst, |
88 | int64_t ld_dst); |
89 | |
90 | template FBGEMM_API void transpose_simd<uint8_t>( |
91 | int64_t M, |
92 | int64_t N, |
93 | const uint8_t* src, |
94 | int64_t ld_src, |
95 | uint8_t* dst, |
96 | int64_t ld_dst); |
97 | |
98 | template FBGEMM_API void transpose_simd<uint16_t>( |
99 | int64_t M, |
100 | int64_t N, |
101 | const uint16_t* src, |
102 | int64_t ld_src, |
103 | uint16_t* dst, |
104 | int64_t ld_dst); |
105 | } // namespace fbgemm |
106 | |