1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #if defined(__x86_64__) || defined(__i386__) || \ |
8 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
9 | #include <immintrin.h> |
10 | #endif |
11 | #include "fbgemm/FbgemmConvert.h" |
12 | |
13 | namespace fbgemm { |
14 | |
15 | namespace { |
16 | |
// Converts 16 floats (x0 holds elements 0-7, x1 holds elements 8-15) into 16
// bfloat16 bit patterns, packed in element order into one 256-bit register.
//
// Rounding: 0x8000 (2^15) is added to each float's raw bit pattern before the
// top 16 bits are kept. That is round-to-nearest with exact ties rounded away
// from zero (not ties-to-even). NaNs are not special-cased: a NaN whose low
// mantissa bits are set can carry into the exponent/sign bits and stop being a
// NaN. This intentionally matches the scalar tail path used by the callers.
inline __m256i QuantizeBfloat16Avx2(const __m256& x0, const __m256& x1) {
  const __m256i round_bias = _mm256_set1_epi32(1 << 15);
  const __m256i lo_bits = _mm256_castps_si256(x0);
  const __m256i hi_bits = _mm256_castps_si256(x1);

  // Each 32-bit lane now holds its bfloat16 value in the low 16 bits, with
  // the upper 16 bits zero.
  const __m256i lo_bf16 =
      _mm256_srli_epi32(_mm256_add_epi32(lo_bits, round_bias), 16);
  const __m256i hi_bf16 =
      _mm256_srli_epi32(_mm256_add_epi32(hi_bits, round_bias), 16);

  // AVX2 has no _mm256_cvtepi32_epi16. packus narrows 32 -> 16 bits (never
  // saturating here, since every lane is already <= 0xFFFF) but operates per
  // 128-bit lane, yielding 64-bit chunks ordered
  // [lo 0-3, hi 0-3, lo 4-7, hi 4-7]. Permuting the chunks with 0xd8
  // (order 0, 2, 1, 3) restores the original element order.
  const __m256i packed = _mm256_packus_epi32(lo_bf16, hi_bf16);
  return _mm256_permute4x64_epi64(packed, 0xd8);
}
29 | |
30 | inline void FloatToBfloat16KernelAvx2(const float* src, bfloat16* dst) { |
31 | // Two float m256i -> One bfloat16 m256i |
32 | const __m256 src_reg0 = _mm256_loadu_ps(src); |
33 | const __m256 src_reg1 = _mm256_loadu_ps(src + 8); |
34 | __m256i dst_reg = QuantizeBfloat16Avx2(src_reg0, src_reg1); |
35 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg); |
36 | } |
37 | |
38 | inline void Bfloat16ToFloatKernelAvx2(const bfloat16* src, float* dst) { |
39 | // One bfloat16 m128i -> One float m256i |
40 | const __m128i src_reg = |
41 | _mm_lddqu_si128(reinterpret_cast<const __m128i*>(src)); |
42 | __m256i dst_reg_bf16 = _mm256_cvtepu16_epi32(src_reg); |
43 | __m256i dst_reg = _mm256_slli_epi32(dst_reg_bf16, 16); |
44 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), dst_reg); |
45 | } |
46 | |
47 | } // namespace |
48 | |
49 | void FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size) { |
50 | size_t i = 0; |
51 | for (i = 0; i + 8 * 2 <= size; i += 8 * 2) { |
52 | FloatToBfloat16KernelAvx2(src + i, dst + i); |
53 | } |
54 | FloatToBfloat16_ref(src + i, dst + i, size - i); |
55 | } |
56 | |
57 | void Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size) { |
58 | size_t i = 0; |
59 | for (i = 0; i + 8 <= size; i += 8) { |
60 | Bfloat16ToFloatKernelAvx2(src + i, dst + i); |
61 | } |
62 | Bfloat16ToFloat_ref(src + i, dst + i, size - i); |
63 | } |
64 | |
65 | } // namespace fbgemm |
66 | |