1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
#include <array>
#include <cassert>
#include <cmath>
#include <stdexcept>
#include <tuple>
#include <utility>

#include "./FbgemmFP16UKernelsAvx2.h"
#include "./FbgemmFP16UKernelsAvx512.h"
#include "./FbgemmFP16UKernelsAvx512_256.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmFPCommon.h"
17
18namespace fbgemm {
19
20namespace {
21// optimized kernels to cover all cases
22// 2 in ?x2 should be the same as kernel_ncol_blocks.
23// Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to
24// the restrictions of ymm register numbers (16).
25constexpr kernel_array_t<float16> kernel_fp16_avx2 = {
26 nullptr,
27 gemmkernel_1x2_Avx2_fp16_fA0fB0fC0,
28 gemmkernel_2x2_Avx2_fp16_fA0fB0fC0,
29 gemmkernel_3x2_Avx2_fp16_fA0fB0fC0,
30 gemmkernel_4x2_Avx2_fp16_fA0fB0fC0,
31 gemmkernel_5x2_Avx2_fp16_fA0fB0fC0,
32 gemmkernel_6x2_Avx2_fp16_fA0fB0fC0};
33
34constexpr kernel_array_t<float16> kernel_fp16_avx512_256 = {
35 nullptr,
36 gemmkernel_1x2_Avx2_fp16_fA0fB0fC0,
37 gemmkernel_2x2_Avx2_fp16_fA0fB0fC0,
38 gemmkernel_3x2_Avx2_fp16_fA0fB0fC0,
39 gemmkernel_4x2_Avx2_fp16_fA0fB0fC0,
40 gemmkernel_5x2_Avx2_fp16_fA0fB0fC0,
41 gemmkernel_6x2_Avx2_fp16_fA0fB0fC0,
42 gemmkernel_7x2_Avx512_256_fp16_fA0fB0fC0,
43 gemmkernel_8x2_Avx512_256_fp16_fA0fB0fC0,
44 gemmkernel_9x2_Avx512_256_fp16_fA0fB0fC0,
45 gemmkernel_10x2_Avx512_256_fp16_fA0fB0fC0,
46 gemmkernel_11x2_Avx512_256_fp16_fA0fB0fC0,
47 gemmkernel_12x2_Avx512_256_fp16_fA0fB0fC0,
48 gemmkernel_13x2_Avx512_256_fp16_fA0fB0fC0,
49 gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0};
50
51constexpr kernel_array_t<float16> kernel_fp16_avx512 = {
52#ifndef __aarch64__
53 nullptr,
54 gemmkernel_1x2_Avx512_fp16_fA0fB0fC0,
55 gemmkernel_2x2_Avx512_fp16_fA0fB0fC0,
56 gemmkernel_3x2_Avx512_fp16_fA0fB0fC0,
57 gemmkernel_4x2_Avx512_fp16_fA0fB0fC0,
58 gemmkernel_5x2_Avx512_fp16_fA0fB0fC0,
59 gemmkernel_6x2_Avx512_fp16_fA0fB0fC0,
60 gemmkernel_7x2_Avx512_fp16_fA0fB0fC0,
61 gemmkernel_8x2_Avx512_fp16_fA0fB0fC0,
62 gemmkernel_9x2_Avx512_fp16_fA0fB0fC0,
63 gemmkernel_10x2_Avx512_fp16_fA0fB0fC0,
64 gemmkernel_11x2_Avx512_fp16_fA0fB0fC0,
65 gemmkernel_12x2_Avx512_fp16_fA0fB0fC0,
66 gemmkernel_13x2_Avx512_fp16_fA0fB0fC0,
67 gemmkernel_14x2_Avx512_fp16_fA0fB0fC0
68#else
69 nullptr
70#endif
71};
72
73} // namespace
74
75template <>
76const isa_descriptor<float16>& getIsaHandlers(inst_set_t isa, float16) {
77 static isa_descriptor<float16> avx2_descriptor =
78 std::make_tuple(kernel_fp16_avx2, partition_avx2);
79 static isa_descriptor<float16> avx512_descriptor =
80 std::make_tuple(kernel_fp16_avx512, partition_avx512);
81 static isa_descriptor<float16> avx512_256_descriptor =
82 std::make_tuple(kernel_fp16_avx512_256, partition_avx512);
83
84 switch (isa) {
85 case inst_set_t::anyarch:
86 case inst_set_t::avx2:
87 return avx2_descriptor;
88
89 case inst_set_t::avx512:
90 case inst_set_t::avx512_vnni:
91 return avx512_descriptor;
92
93 case inst_set_t::avx512_ymm:
94 case inst_set_t::avx512_vnni_ymm:
95 return avx512_256_descriptor;
96 }
97
98 throw std::runtime_error("Unsupported uArch");
99}
100
101#ifdef FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
102template <>
103FBGEMM_API void ref_kernel<float16>(
104 int kernel_nrows,
105 GemmParams<float16>* gp,
106 const float* C_base,
107 int m_total,
108 int n_total,
109 int simd_len) {
110 int kernel_ncol_blocks = 2;
111 int block_col_size = simd_len * kernel_ncol_blocks;
112 for (int jb = 0; jb < gp->b_block_cols; ++jb) {
113 for (int k = 0; k < gp->k; ++k) {
114 for (int i = 0; i < kernel_nrows; ++i) {
115 float a = gp->A[i + k * kernel_nrows];
116 for (int j = 0; j < block_col_size; ++j) {
117 float* C_ptr =
118 gp->C + i * (gp->ldc / sizeof(float)) + jb * block_col_size + j;
119 assert(C_ptr < C_base + m_total * n_total);
120 float b =
121 cpu_half2float(gp->B[(jb * gp->k + k) * block_col_size + j]);
122 if (k == 0) {
123 if (gp->beta) {
124 *C_ptr = std::fma(a, b, (gp->beta) * (*C_ptr));
125 } else {
126 *C_ptr = a * b;
127 }
128 } else {
129 *C_ptr = std::fma(a, b, *C_ptr);
130 }
131 }
132 }
133 }
134 }
135}
136#endif // FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
137
138template FBGEMM_API void cblas_gemm_compute(
139 const matrix_op_t transa,
140 const int m,
141 const float* A,
142 const PackedGemmMatrixB<float16>& Bp,
143 const float beta,
144 float* C,
145 int thread_id,
146 int num_threads);
147
148} // namespace fbgemm
149