1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#if defined(__x86_64__) || defined(__i386__) || \
8 (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
9#include <immintrin.h>
10#endif
11#include "fbgemm/FbgemmConvert.h"
12
13namespace fbgemm {
14
15namespace {
16
17inline void FloatToFloat16KernelAvx2(const float* src, float16* dst) {
18 __m256 float_vector = _mm256_loadu_ps(src);
19 __m128i half_vector = _mm256_cvtps_ph(
20 float_vector, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
21 _mm_storeu_si128((__m128i*)dst, half_vector);
22}
23
24inline void FloatToFloat16KernelAvx2WithClip(const float* src, float16* dst) {
25 constexpr float FP16_MAX = 65504.f;
26 __m256 neg_fp16_max_vector = _mm256_set1_ps(-FP16_MAX);
27 __m256 pos_fp16_max_vector = _mm256_set1_ps(FP16_MAX);
28
29 __m256 float_vector = _mm256_loadu_ps(src);
30
31 // Do the clipping.
32 float_vector = _mm256_max_ps(
33 neg_fp16_max_vector, _mm256_min_ps(float_vector, pos_fp16_max_vector));
34
35 __m128i half_vector = _mm256_cvtps_ph(
36 float_vector, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
37 _mm_storeu_si128((__m128i*)dst, half_vector);
38}
39
40inline void Float16ToFloatKernelAvx2(const float16* src, float* dst) {
41 __m128i half_vector = _mm_loadu_si128((__m128i*)src);
42 __m256 float_vector = _mm256_cvtph_ps(half_vector);
43 _mm256_storeu_ps(dst, float_vector);
44}
45
46} // namespace
47
48void FloatToFloat16_avx2(
49 const float* src,
50 float16* dst,
51 size_t size,
52 bool do_clip) {
53 if (do_clip) {
54 size_t i = 0;
55 for (i = 0; i + 8 <= size; i += 8) {
56 FloatToFloat16KernelAvx2WithClip(src + i, dst + i);
57 }
58 FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
59 } else {
60 size_t i = 0;
61 for (i = 0; i + 8 <= size; i += 8) {
62 FloatToFloat16KernelAvx2(src + i, dst + i);
63 }
64 FloatToFloat16_ref(src + i, dst + i, size - i);
65 }
66}
67
68void Float16ToFloat_avx2(const float16* src, float* dst, size_t size) {
69 size_t i = 0;
70 for (i = 0; i + 8 <= size; i += 8) {
71 Float16ToFloatKernelAvx2(src + i, dst + i);
72 }
73 Float16ToFloat_ref(src + i, dst + i, size - i);
74}
75
76} // namespace fbgemm
77