/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
7#pragma once
8
9#include <stdexcept>
10#include "fbgemm/Types.h"
11#include "fbgemm/Utils.h"
12
13namespace fbgemm {
14
15/**
16 * @ Transform all entries in a matrix from fp32 to bfloat16: reference
17 * implementation.
18 *
19 */
20FBGEMM_API void
21FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size);
22
23/**
24 * @ Transform all entries in a matrix from bfloat16 to fp32: reference
25 * implementation.
26 *
27 */
28FBGEMM_API void
29Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size);
30
31/**
32 * @ Transform all entries in a matrix from fp32 to bfloat16: simd
33 * implementation.
34 *
35 */
36FBGEMM_API void
37FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size);
38
39/**
40 * @ Transform all entries in a matrix from bfloat16 to fp32: simd
41 * implementation.
42 *
43 */
44FBGEMM_API void
45Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size);
46
47/**
48 * @brief AVX2 implementation to convert fp32 numbers to bf16 numbers.
49 *
50 */
51FBGEMM_API void
52FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size);
53
54/**
55 * @brief AVX512 implementation to convert fp32 numbers to bf16 numbers.
56 *
57 */
58FBGEMM_API void
59FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size);
60
61/**
62 * @brief AVX2 implementation to convert bf16 numbers to fp32 numbers.
63 *
64 */
65FBGEMM_API void
66Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size);
67
68/**
69 * @brief AVX512 implementation to convert bf16 numbers to fp32 numbers.
70 *
71 */
72FBGEMM_API void
73Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size);
74
75/**
76 * @ Transform all entries in a matrix from fp32 to float16: reference
77 * implementation.
78 *
79 * @param do_clip if true we saturate to fp16 min and max instead of generating
80 * infinities.
81 */
82FBGEMM_API void FloatToFloat16_ref(
83 const float* src,
84 float16* dst,
85 size_t size,
86 bool do_clip = false);
87
88/**
89 * @ Transform all entries in a matrix from float16 to fp32: reference
90 * implementation.
91 *
92 */
93FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, size_t size);
94
95/**
96 * @ Transform all entries in a matrix from fp32 to float16: simd
97 * implementation.
98 *
99 * @param do_clip if true we saturate to fp16 min and max instead of generating
100 * infinities.
101 */
102FBGEMM_API void FloatToFloat16_simd(
103 const float* src,
104 float16* dst,
105 size_t size,
106 bool do_clip = false);
107
108/**
109 * @ Transform all entries in a matrix from float16 to fp32: simd
110 * implementation.
111 *
112 */
113FBGEMM_API void
114Float16ToFloat_simd(const float16* src, float* dst, size_t size);
115
116/**
117 * @brief AVX2 implementation to convert fp32 numbers to fp16 numbers.
118 *
119 */
120FBGEMM_API void FloatToFloat16_avx2(
121 const float* src,
122 float16* dst,
123 size_t size,
124 bool do_clip = false);
125
126/**
127 * @brief AVX512 implementation to convert fp32 numbers to fp16 numbers.
128 *
129 */
130FBGEMM_API void FloatToFloat16_avx512(
131 const float* src,
132 float16* dst,
133 size_t size,
134 bool do_clip = false);
135
136/**
137 * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
138 *
139 */
140FBGEMM_API void
141Float16ToFloat_avx2(const float16* src, float* dst, size_t size);
142
143/**
144 * @brief AVX512 implementation to convert fp16 numbers to fp32 numbers.
145 *
146 */
147FBGEMM_API void
148Float16ToFloat_avx512(const float16* src, float* dst, size_t size);
149
150/**
151 * @brief Transform all entries in a matrix from fp32 to float16 and back to
152 * fp32.
153 */
154FBGEMM_API void RoundToFloat16(
155 const float* input,
156 float* output,
157 size_t size,
158 bool clamp = false,
159 bool clamp_denorms = false);
160
161/**
162 * @brief Quantize float32 to float8. The code is a copy of float_to_hfp8() in
163 * fbgemm_gpu/quantize_ops_utils.h
164 */
165FBGEMM_API void FloatToFloat8_ref(
166 const float input,
167 uint8_t* output,
168 int exponent_bits,
169 int exponent_bias);
170
171/**
172 * @brief Dequantize float8 to float32. The code is a copy of hf8_to_float() in
173 * fbgemm_gpu/quantize_ops_utils.h
174 */
175FBGEMM_API void Float8ToFloat_ref(
176 const uint8_t input,
177 float* output,
178 int exponent_bits,
179 int exponent_bias);
180
181} // namespace fbgemm
182