1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | |
9 | #include <stdexcept> |
10 | #include "fbgemm/Types.h" |
11 | #include "fbgemm/Utils.h" |
12 | |
13 | namespace fbgemm { |
14 | |
15 | /** |
16 | * @ Transform all entries in a matrix from fp32 to bfloat16: reference |
17 | * implementation. |
18 | * |
19 | */ |
20 | FBGEMM_API void |
21 | FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size); |
22 | |
23 | /** |
24 | * @ Transform all entries in a matrix from bfloat16 to fp32: reference |
25 | * implementation. |
26 | * |
27 | */ |
28 | FBGEMM_API void |
29 | Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size); |
30 | |
31 | /** |
32 | * @ Transform all entries in a matrix from fp32 to bfloat16: simd |
33 | * implementation. |
34 | * |
35 | */ |
36 | FBGEMM_API void |
37 | FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size); |
38 | |
39 | /** |
40 | * @ Transform all entries in a matrix from bfloat16 to fp32: simd |
41 | * implementation. |
42 | * |
43 | */ |
44 | FBGEMM_API void |
45 | Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size); |
46 | |
47 | /** |
48 | * @brief AVX2 implementation to convert fp32 numbers to bf16 numbers. |
49 | * |
50 | */ |
51 | FBGEMM_API void |
52 | FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size); |
53 | |
54 | /** |
55 | * @brief AVX512 implementation to convert fp32 numbers to bf16 numbers. |
56 | * |
57 | */ |
58 | FBGEMM_API void |
59 | FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size); |
60 | |
61 | /** |
62 | * @brief AVX2 implementation to convert bf16 numbers to fp32 numbers. |
63 | * |
64 | */ |
65 | FBGEMM_API void |
66 | Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size); |
67 | |
68 | /** |
69 | * @brief AVX512 implementation to convert bf16 numbers to fp32 numbers. |
70 | * |
71 | */ |
72 | FBGEMM_API void |
73 | Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size); |
74 | |
75 | /** |
76 | * @ Transform all entries in a matrix from fp32 to float16: reference |
77 | * implementation. |
78 | * |
79 | * @param do_clip if true we saturate to fp16 min and max instead of generating |
80 | * infinities. |
81 | */ |
82 | FBGEMM_API void FloatToFloat16_ref( |
83 | const float* src, |
84 | float16* dst, |
85 | size_t size, |
86 | bool do_clip = false); |
87 | |
88 | /** |
89 | * @ Transform all entries in a matrix from float16 to fp32: reference |
90 | * implementation. |
91 | * |
92 | */ |
93 | FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, size_t size); |
94 | |
95 | /** |
96 | * @ Transform all entries in a matrix from fp32 to float16: simd |
97 | * implementation. |
98 | * |
99 | * @param do_clip if true we saturate to fp16 min and max instead of generating |
100 | * infinities. |
101 | */ |
102 | FBGEMM_API void FloatToFloat16_simd( |
103 | const float* src, |
104 | float16* dst, |
105 | size_t size, |
106 | bool do_clip = false); |
107 | |
108 | /** |
109 | * @ Transform all entries in a matrix from float16 to fp32: simd |
110 | * implementation. |
111 | * |
112 | */ |
113 | FBGEMM_API void |
114 | Float16ToFloat_simd(const float16* src, float* dst, size_t size); |
115 | |
116 | /** |
117 | * @brief AVX2 implementation to convert fp32 numbers to fp16 numbers. |
118 | * |
119 | */ |
120 | FBGEMM_API void FloatToFloat16_avx2( |
121 | const float* src, |
122 | float16* dst, |
123 | size_t size, |
124 | bool do_clip = false); |
125 | |
126 | /** |
127 | * @brief AVX512 implementation to convert fp32 numbers to fp16 numbers. |
128 | * |
129 | */ |
130 | FBGEMM_API void FloatToFloat16_avx512( |
131 | const float* src, |
132 | float16* dst, |
133 | size_t size, |
134 | bool do_clip = false); |
135 | |
136 | /** |
137 | * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers. |
138 | * |
139 | */ |
140 | FBGEMM_API void |
141 | Float16ToFloat_avx2(const float16* src, float* dst, size_t size); |
142 | |
143 | /** |
144 | * @brief AVX512 implementation to convert fp16 numbers to fp32 numbers. |
145 | * |
146 | */ |
147 | FBGEMM_API void |
148 | Float16ToFloat_avx512(const float16* src, float* dst, size_t size); |
149 | |
150 | /** |
151 | * @brief Transform all entries in a matrix from fp32 to float16 and back to |
152 | * fp32. |
153 | */ |
154 | FBGEMM_API void RoundToFloat16( |
155 | const float* input, |
156 | float* output, |
157 | size_t size, |
158 | bool clamp = false, |
159 | bool clamp_denorms = false); |
160 | |
161 | /** |
162 | * @brief Quantize float32 to float8. The code is a copy of float_to_hfp8() in |
163 | * fbgemm_gpu/quantize_ops_utils.h |
164 | */ |
165 | FBGEMM_API void FloatToFloat8_ref( |
166 | const float input, |
167 | uint8_t* output, |
168 | int exponent_bits, |
169 | int exponent_bias); |
170 | |
171 | /** |
172 | * @brief Dequantize float8 to float32. The code is a copy of hf8_to_float() in |
173 | * fbgemm_gpu/quantize_ops_utils.h |
174 | */ |
175 | FBGEMM_API void Float8ToFloat_ref( |
176 | const uint8_t input, |
177 | float* output, |
178 | int exponent_bits, |
179 | int exponent_bias); |
180 | |
181 | } // namespace fbgemm |
182 | |