1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #if defined(__x86_64__) || defined(__i386__) || \ |
8 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
9 | #include <immintrin.h> |
10 | #endif |
11 | #include "fbgemm/FbgemmConvert.h" |
12 | |
13 | namespace fbgemm { |
14 | |
15 | namespace { |
16 | |
17 | inline void FloatToFloat16KernelAvx2(const float* src, float16* dst) { |
18 | __m256 float_vector = _mm256_loadu_ps(src); |
19 | __m128i half_vector = _mm256_cvtps_ph( |
20 | float_vector, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
21 | _mm_storeu_si128((__m128i*)dst, half_vector); |
22 | } |
23 | |
24 | inline void FloatToFloat16KernelAvx2WithClip(const float* src, float16* dst) { |
25 | constexpr float FP16_MAX = 65504.f; |
26 | __m256 neg_fp16_max_vector = _mm256_set1_ps(-FP16_MAX); |
27 | __m256 pos_fp16_max_vector = _mm256_set1_ps(FP16_MAX); |
28 | |
29 | __m256 float_vector = _mm256_loadu_ps(src); |
30 | |
31 | // Do the clipping. |
32 | float_vector = _mm256_max_ps( |
33 | neg_fp16_max_vector, _mm256_min_ps(float_vector, pos_fp16_max_vector)); |
34 | |
35 | __m128i half_vector = _mm256_cvtps_ph( |
36 | float_vector, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
37 | _mm_storeu_si128((__m128i*)dst, half_vector); |
38 | } |
39 | |
40 | inline void Float16ToFloatKernelAvx2(const float16* src, float* dst) { |
41 | __m128i half_vector = _mm_loadu_si128((__m128i*)src); |
42 | __m256 float_vector = _mm256_cvtph_ps(half_vector); |
43 | _mm256_storeu_ps(dst, float_vector); |
44 | } |
45 | |
46 | } // namespace |
47 | |
48 | void FloatToFloat16_avx2( |
49 | const float* src, |
50 | float16* dst, |
51 | size_t size, |
52 | bool do_clip) { |
53 | if (do_clip) { |
54 | size_t i = 0; |
55 | for (i = 0; i + 8 <= size; i += 8) { |
56 | FloatToFloat16KernelAvx2WithClip(src + i, dst + i); |
57 | } |
58 | FloatToFloat16_ref(src + i, dst + i, size - i, do_clip); |
59 | } else { |
60 | size_t i = 0; |
61 | for (i = 0; i + 8 <= size; i += 8) { |
62 | FloatToFloat16KernelAvx2(src + i, dst + i); |
63 | } |
64 | FloatToFloat16_ref(src + i, dst + i, size - i); |
65 | } |
66 | } |
67 | |
68 | void Float16ToFloat_avx2(const float16* src, float* dst, size_t size) { |
69 | size_t i = 0; |
70 | for (i = 0; i + 8 <= size; i += 8) { |
71 | Float16ToFloatKernelAvx2(src + i, dst + i); |
72 | } |
73 | Float16ToFloat_ref(src + i, dst + i, size - i); |
74 | } |
75 | |
76 | } // namespace fbgemm |
77 | |