1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "./TransposeUtils.h"
9#include <cstring>
10#include "fbgemm/Utils.h"
11
12namespace fbgemm {
13
// Portable scalar transpose.
//
// Reads element (i, j) of the source at src[i * ld_src + j] for 0 <= i < M and
// 0 <= j < N, and stores it at dst[j * ld_dst + i]. ld_src and ld_dst are the
// element strides between consecutive rows of the source and destination.
template <typename T>
void transpose_ref(
    int64_t M,
    int64_t N,
    const T* src,
    int64_t ld_src,
    T* dst,
    int64_t ld_dst) {
  // With an empty extent neither loop below would execute, so bail out before
  // doing any offset arithmetic.
  if (M <= 0 || N <= 0) {
    return;
  }
  // Iterate over destination rows so that writes within a row are contiguous.
  for (int64_t out_row = 0; out_row < N; ++out_row) {
    const int64_t dst_base = out_row * ld_dst;
    for (int64_t out_col = 0; out_col < M; ++out_col) {
      dst[dst_base + out_col] = src[out_col * ld_src + out_row];
    }
  }
}
28
29template <typename T>
30void transpose_simd(
31 int64_t M,
32 int64_t N,
33 const T* src,
34 int64_t ld_src,
35 T* dst,
36 int64_t ld_dst) {
37 if (M == 0 || N == 0) {
38 return;
39 }
40 if ((M == 1 && ld_dst == 1) || (N == 1 && ld_src == 1)) {
41 if (dst != src) {
42 // sizeof must be first operand force dims promotion to OS-bitness type
43 memcpy(dst, src, sizeof(T) * M * N);
44 }
45 return;
46 }
47 static const auto iset = fbgemmInstructionSet();
48 // Run time CPU detection
49 if (isZmm(iset)) {
50 internal::transpose_avx512<T>(M, N, src, ld_src, dst, ld_dst);
51 } else if (isYmm(iset)) {
52 internal::transpose_avx2<T>(M, N, src, ld_src, dst, ld_dst);
53 } else {
54 transpose_ref<T>(M, N, src, ld_src, dst, ld_dst);
55 }
56}
57
58template void transpose_ref<float>(
59 int64_t M,
60 int64_t N,
61 const float* src,
62 int64_t ld_src,
63 float* dst,
64 int64_t ld_dst);
65
66template void transpose_ref<uint16_t>(
67 int64_t M,
68 int64_t N,
69 const uint16_t* src,
70 int64_t ld_src,
71 uint16_t* dst,
72 int64_t ld_dst);
73
74template void transpose_ref<uint8_t>(
75 int64_t M,
76 int64_t N,
77 const uint8_t* src,
78 int64_t ld_src,
79 uint8_t* dst,
80 int64_t ld_dst);
81
82template FBGEMM_API void transpose_simd<float>(
83 int64_t M,
84 int64_t N,
85 const float* src,
86 int64_t ld_src,
87 float* dst,
88 int64_t ld_dst);
89
90template FBGEMM_API void transpose_simd<uint8_t>(
91 int64_t M,
92 int64_t N,
93 const uint8_t* src,
94 int64_t ld_src,
95 uint8_t* dst,
96 int64_t ld_dst);
97
98template FBGEMM_API void transpose_simd<uint16_t>(
99 int64_t M,
100 int64_t N,
101 const uint16_t* src,
102 int64_t ld_src,
103 uint16_t* dst,
104 int64_t ld_dst);
105} // namespace fbgemm
106