1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#if defined(__x86_64__) || defined(__i386__) || \
8 (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
9#include <immintrin.h>
10#endif
11#include "fbgemm/FbgemmConvert.h"
12
13namespace fbgemm {
14
15namespace {
16
// Converts 16 packed fp32 lanes into 16 packed bfloat16 bit patterns.
//
// Every lane is treated as a raw 32-bit integer: half of the discarded range
// (1 << 15) is added and only the upper 16 bits are kept, so the dropped low
// bits round to nearest instead of being truncated. The addition is plain
// integer arithmetic on the bit pattern; no input value is special-cased and
// a carry simply propagates into the higher bits of the lane.
inline __m256i QuantizeBfloat16Avx512(const __m512& x0) {
  const __m512i rounding_bias = _mm512_set1_epi32(1 << 15);
  const __m512i raw_bits = _mm512_castps_si512(x0);
  const __m512i biased = _mm512_add_epi32(raw_bits, rounding_bias);
  // Logical shift: the surviving 16 bits land in the low half of each lane.
  const __m512i upper_halves = _mm512_srli_epi32(biased, 16);
  // Narrow each 32-bit lane to its low 16 bits (16 x 16 bit in 256 bits).
  return _mm512_cvtepi32_epi16(upper_halves);
}
24
25inline void FloatToBfloat16KernelAvx512(const float* src, bfloat16* dst) {
26 // One float m512i -> One bfloat16 m256i
27 const __m512 src_reg0 = _mm512_loadu_ps(src);
28 __m256i dst_reg0 = QuantizeBfloat16Avx512(src_reg0);
29 _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg0);
30}
31
32inline void Bfloat16ToFloatKernelAvx512(const bfloat16* src, float* dst) {
33 // One bfloat16 m256i -> One float m512i
34 const __m256i src_reg =
35 _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(src));
36 __m512i dst_reg_bf16 = _mm512_cvtepu16_epi32(src_reg);
37 __m512i dst_reg = _mm512_slli_epi32(dst_reg_bf16, 16);
38 _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), dst_reg);
39}
40
41} // namespace
42
43void FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size) {
44 size_t i = 0;
45 for (i = 0; i + 16 <= size; i += 16) {
46 FloatToBfloat16KernelAvx512(src + i, dst + i);
47 }
48 FloatToBfloat16_avx2(src + i, dst + i, size - i);
49}
50
51void Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size) {
52 size_t i = 0;
53 for (i = 0; i + 16 <= size; i += 16) {
54 Bfloat16ToFloatKernelAvx512(src + i, dst + i);
55 }
56 Bfloat16ToFloat_avx2(src + i, dst + i, size - i);
57}
58
59} // namespace fbgemm
60