1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #if defined(__x86_64__) || defined(__i386__) || \ |
8 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
9 | #include <immintrin.h> |
10 | #endif |
11 | #include "fbgemm/FbgemmConvert.h" |
12 | |
13 | namespace fbgemm { |
14 | |
15 | namespace { |
16 | |
// Converts 16 packed fp32 values into 16 packed bf16 values.
//
// Works directly on the IEEE-754 bit patterns. A bias of 0x8000 (half of the
// weight of the lowest bit that survives in bf16) is added to every 32-bit
// lane. Each lane is then shifted right by 16, keeping the sign, the exponent
// and the top 7 mantissa bits. For finite inputs this rounds to nearest, with
// ties rounded away from zero (NOT ties-to-even).
//
// NOTE(review): NaN inputs are not special-cased.
//   - 0x7F800001 becomes 0x7F80 (+inf).
//   - 0x7FFF8000 carries into the sign bit and becomes 0x8000 (-0.0).
// Presumably the AVX2 tail path uses the same scheme; confirm before relying
// on NaN propagation.
inline __m256i QuantizeBfloat16Avx512(const __m512& x0) {
  const __m512i bits = _mm512_castps_si512(x0);
  const __m512i half_bf16_lsb = _mm512_set1_epi32(1 << 15);
  const __m512i biased = _mm512_add_epi32(bits, half_bf16_lsb);
  // Logical shift: each lane ends up with its bf16 pattern in bits [0, 16)
  // and zeros above.
  const __m512i upper_halves = _mm512_srli_epi32(biased, 16);
  // Narrowing int32 -> int16 truncates, which loses nothing here because
  // the shift already cleared the top 16 bits of every lane.
  return _mm512_cvtepi32_epi16(upper_halves);
}
24 | |
25 | inline void FloatToBfloat16KernelAvx512(const float* src, bfloat16* dst) { |
26 | // One float m512i -> One bfloat16 m256i |
27 | const __m512 src_reg0 = _mm512_loadu_ps(src); |
28 | __m256i dst_reg0 = QuantizeBfloat16Avx512(src_reg0); |
29 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg0); |
30 | } |
31 | |
32 | inline void Bfloat16ToFloatKernelAvx512(const bfloat16* src, float* dst) { |
33 | // One bfloat16 m256i -> One float m512i |
34 | const __m256i src_reg = |
35 | _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(src)); |
36 | __m512i dst_reg_bf16 = _mm512_cvtepu16_epi32(src_reg); |
37 | __m512i dst_reg = _mm512_slli_epi32(dst_reg_bf16, 16); |
38 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), dst_reg); |
39 | } |
40 | |
41 | } // namespace |
42 | |
43 | void FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size) { |
44 | size_t i = 0; |
45 | for (i = 0; i + 16 <= size; i += 16) { |
46 | FloatToBfloat16KernelAvx512(src + i, dst + i); |
47 | } |
48 | FloatToBfloat16_avx2(src + i, dst + i, size - i); |
49 | } |
50 | |
51 | void Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size) { |
52 | size_t i = 0; |
53 | for (i = 0; i + 16 <= size; i += 16) { |
54 | Bfloat16ToFloatKernelAvx512(src + i, dst + i); |
55 | } |
56 | Bfloat16ToFloat_avx2(src + i, dst + i, size - i); |
57 | } |
58 | |
59 | } // namespace fbgemm |
60 | |