1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include <array> |
9 | #include <cmath> |
10 | #include <utility> |
11 | |
12 | #include "./FbgemmFP16UKernelsAvx2.h" |
13 | #include "./FbgemmFP16UKernelsAvx512.h" |
14 | #include "./FbgemmFP16UKernelsAvx512_256.h" |
15 | #include "fbgemm/Fbgemm.h" |
16 | #include "fbgemm/FbgemmFPCommon.h" |
17 | |
18 | namespace fbgemm { |
19 | |
20 | namespace { |
21 | // optimized kernels to cover all cases |
22 | // 2 in ?x2 should be the same as kernel_ncol_blocks. |
23 | // Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to |
24 | // the restrictions of ymm register numbers (16). |
25 | constexpr kernel_array_t<float16> kernel_fp16_avx2 = { |
26 | nullptr, |
27 | gemmkernel_1x2_Avx2_fp16_fA0fB0fC0, |
28 | gemmkernel_2x2_Avx2_fp16_fA0fB0fC0, |
29 | gemmkernel_3x2_Avx2_fp16_fA0fB0fC0, |
30 | gemmkernel_4x2_Avx2_fp16_fA0fB0fC0, |
31 | gemmkernel_5x2_Avx2_fp16_fA0fB0fC0, |
32 | gemmkernel_6x2_Avx2_fp16_fA0fB0fC0}; |
33 | |
34 | constexpr kernel_array_t<float16> kernel_fp16_avx512_256 = { |
35 | nullptr, |
36 | gemmkernel_1x2_Avx2_fp16_fA0fB0fC0, |
37 | gemmkernel_2x2_Avx2_fp16_fA0fB0fC0, |
38 | gemmkernel_3x2_Avx2_fp16_fA0fB0fC0, |
39 | gemmkernel_4x2_Avx2_fp16_fA0fB0fC0, |
40 | gemmkernel_5x2_Avx2_fp16_fA0fB0fC0, |
41 | gemmkernel_6x2_Avx2_fp16_fA0fB0fC0, |
42 | gemmkernel_7x2_Avx512_256_fp16_fA0fB0fC0, |
43 | gemmkernel_8x2_Avx512_256_fp16_fA0fB0fC0, |
44 | gemmkernel_9x2_Avx512_256_fp16_fA0fB0fC0, |
45 | gemmkernel_10x2_Avx512_256_fp16_fA0fB0fC0, |
46 | gemmkernel_11x2_Avx512_256_fp16_fA0fB0fC0, |
47 | gemmkernel_12x2_Avx512_256_fp16_fA0fB0fC0, |
48 | gemmkernel_13x2_Avx512_256_fp16_fA0fB0fC0, |
49 | gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0}; |
50 | |
51 | constexpr kernel_array_t<float16> kernel_fp16_avx512 = { |
52 | #ifndef __aarch64__ |
53 | nullptr, |
54 | gemmkernel_1x2_Avx512_fp16_fA0fB0fC0, |
55 | gemmkernel_2x2_Avx512_fp16_fA0fB0fC0, |
56 | gemmkernel_3x2_Avx512_fp16_fA0fB0fC0, |
57 | gemmkernel_4x2_Avx512_fp16_fA0fB0fC0, |
58 | gemmkernel_5x2_Avx512_fp16_fA0fB0fC0, |
59 | gemmkernel_6x2_Avx512_fp16_fA0fB0fC0, |
60 | gemmkernel_7x2_Avx512_fp16_fA0fB0fC0, |
61 | gemmkernel_8x2_Avx512_fp16_fA0fB0fC0, |
62 | gemmkernel_9x2_Avx512_fp16_fA0fB0fC0, |
63 | gemmkernel_10x2_Avx512_fp16_fA0fB0fC0, |
64 | gemmkernel_11x2_Avx512_fp16_fA0fB0fC0, |
65 | gemmkernel_12x2_Avx512_fp16_fA0fB0fC0, |
66 | gemmkernel_13x2_Avx512_fp16_fA0fB0fC0, |
67 | gemmkernel_14x2_Avx512_fp16_fA0fB0fC0 |
68 | #else |
69 | nullptr |
70 | #endif |
71 | }; |
72 | |
73 | } // namespace |
74 | |
75 | template <> |
76 | const isa_descriptor<float16>& getIsaHandlers(inst_set_t isa, float16) { |
77 | static isa_descriptor<float16> avx2_descriptor = |
78 | std::make_tuple(kernel_fp16_avx2, partition_avx2); |
79 | static isa_descriptor<float16> avx512_descriptor = |
80 | std::make_tuple(kernel_fp16_avx512, partition_avx512); |
81 | static isa_descriptor<float16> avx512_256_descriptor = |
82 | std::make_tuple(kernel_fp16_avx512_256, partition_avx512); |
83 | |
84 | switch (isa) { |
85 | case inst_set_t::anyarch: |
86 | case inst_set_t::avx2: |
87 | return avx2_descriptor; |
88 | |
89 | case inst_set_t::avx512: |
90 | case inst_set_t::avx512_vnni: |
91 | return avx512_descriptor; |
92 | |
93 | case inst_set_t::avx512_ymm: |
94 | case inst_set_t::avx512_vnni_ymm: |
95 | return avx512_256_descriptor; |
96 | } |
97 | |
98 | throw std::runtime_error("Unsupported uArch" ); |
99 | } |
100 | |
#ifdef FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
// Scalar reference implementation of the fp16 micro-kernel; only compiled when
// FBGEMM_FP16_FALLBACK_TO_REF_KERNEL is defined.
//
// For every column block jb, every depth step k, every row i and every column
// j inside the block it accumulates A(i, k) * B(jb, k, j) into C:
//   - A is read as gp->A[i + k * kernel_nrows].
//   - B is read as gp->B[(jb * gp->k + k) * block_col_size + j] and converted
//     to float with cpu_half2float.
//   - C rows are gp->ldc / sizeof(float) floats apart, i.e. gp->ldc is used as
//     a stride in bytes.
// On the first depth step (k == 0) the previous C value is scaled by gp->beta,
// or discarded entirely when gp->beta is zero; later steps accumulate.
//
// kernel_nrows: number of rows of A/C processed.
// gp:           pointers and sizes for A, B, C (k, ldc, b_block_cols, beta).
// C_base, m_total, n_total: only used by the bounds assert on the C pointer.
// simd_len:     a column block is simd_len * 2 floats wide.
template <>
FBGEMM_API void ref_kernel<float16>(
    int kernel_nrows,
    GemmParams<float16>* gp,
    const float* C_base,
    int m_total,
    int n_total,
    int simd_len) {
  // Matches the "x2" of the optimized kernels.
  int kernel_ncol_blocks = 2;
  int block_col_size = simd_len * kernel_ncol_blocks;

  for (int jb = 0; jb < gp->b_block_cols; ++jb) {
    for (int k = 0; k < gp->k; ++k) {
      // Start of the B row for this (block, depth) pair.
      const auto b_row_start = (jb * gp->k + k) * block_col_size;

      for (int i = 0; i < kernel_nrows; ++i) {
        float a_val = gp->A[i + k * kernel_nrows];

        for (int j = 0; j < block_col_size; ++j) {
          float* dst =
              gp->C + i * (gp->ldc / sizeof(float)) + jb * block_col_size + j;
          assert(dst < C_base + m_total * n_total);

          float b_val = cpu_half2float(gp->B[b_row_start + j]);

          if (k != 0) {
            // Later depth steps: plain accumulation.
            *dst = std::fma(a_val, b_val, *dst);
            continue;
          }

          // First depth step: apply beta to the existing value, or overwrite
          // it when beta is zero.
          if (gp->beta) {
            *dst = std::fma(a_val, b_val, (gp->beta) * (*dst));
          } else {
            *dst = a_val * b_val;
          }
        }
      }
    }
  }
}
#endif // FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
137 | |
138 | template FBGEMM_API void cblas_gemm_compute( |
139 | const matrix_op_t transa, |
140 | const int m, |
141 | const float* A, |
142 | const PackedGemmMatrixB<float16>& Bp, |
143 | const float beta, |
144 | float* C, |
145 | int thread_id, |
146 | int num_threads); |
147 | |
148 | } // namespace fbgemm |
149 | |