1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/FbgemmConvert.h"
9
10#include "./RefImplementations.h"
11
12#ifdef USE_MKL
13#include <mkl.h>
14#endif
15
16#ifdef USE_BLAS
17#if __APPLE__
18// not sure whether need to differentiate TARGET_OS_MAC or TARGET_OS_IPHONE,
19// etc.
20#include <Accelerate/Accelerate.h>
21#else
22#include <cblas.h>
23#endif
24#endif
25
#include <cpuinfo.h>
#include <cmath>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
30
31namespace fbgemm {
32
33void FloatToFloat16_simd(
34 const float* src,
35 float16* dst,
36 size_t size,
37 bool do_clip) {
38 // Run time CPU detection
39 if (cpuinfo_initialize()) {
40 if (fbgemmHasAvx512Support()) {
41 FloatToFloat16_avx512(src, dst, size, do_clip);
42 } else if (fbgemmHasAvx2Support()) {
43 FloatToFloat16_avx2(src, dst, size, do_clip);
44 } else {
45 FloatToFloat16_ref(src, dst, size, do_clip);
46 return;
47 }
48 } else {
49 throw std::runtime_error("Failed to initialize cpuinfo!");
50 }
51}
52
53void Float16ToFloat_simd(const float16* src, float* dst, size_t size) {
54 // Run time CPU detection
55 if (cpuinfo_initialize()) {
56 if (fbgemmHasAvx512Support()) {
57 Float16ToFloat_avx512(src, dst, size);
58 } else if (fbgemmHasAvx2Support()) {
59 Float16ToFloat_avx2(src, dst, size);
60 } else {
61 Float16ToFloat_ref(src, dst, size);
62 return;
63 }
64 } else {
65 throw std::runtime_error("Failed to initialize cpuinfo!");
66 }
67}
68
69void RoundToFloat16(
70 const float* input,
71 float* output,
72 size_t size,
73 bool clamp,
74 bool clamp_denorms) {
75 std::vector<fbgemm::float16> data_fp16(size);
76 FloatToFloat16_simd(input, &(data_fp16[0]), size, /*do_clip=*/clamp);
77 Float16ToFloat_simd(&(data_fp16[0]), output, size);
78 if (clamp_denorms) {
79 // FloatToFloat16_simd always preserve fp16 denorm, so we need to manually
80 // clamp.
81 union epsilon_t {
82 float f;
83 uint32_t i;
84 };
85 union epsilon_t epsilon;
86 epsilon.i = 0x38800000u; // 1 / 16384
87 for (size_t i = 0; i < size; ++i) {
88 if (std::abs(output[i]) < epsilon.f) {
89 output[i] = 0.0;
90 }
91 }
92 }
93}
94
95} // namespace fbgemm
96