1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | |
9 | #include <cstdint> |
10 | #include <cstdlib> |
11 | #include <cstring> |
12 | |
13 | namespace fbgemm { |
14 | |
// Storage type for a half-precision value held as its raw 16-bit pattern.
// This is a plain integer alias: arithmetic on it is integer arithmetic.
using float16 = std::uint16_t;
// Storage type for a bfloat16 value held as its raw 16-bit pattern.
// It aliases the same underlying type as float16, so the two names are
// interchangeable to the compiler and cannot be used to select overloads.
using bfloat16 = std::uint16_t;
17 | |
// Converts a single-precision float to half-precision bits, rounding to
// nearest with ties going to the even mantissa.
//
// Special cases:
//   - any NaN input returns 0x7fff (the sign is dropped);
//   - magnitudes of 65520 and above (bit pattern > 0x477fefff), including
//     infinities, return a signed infinity;
//   - magnitudes of 2^-25 and below (bit pattern < 0x33000001), including
//     zeros, return a signed zero.
static inline float16 cpu_float2half_rn(float f) {
  static_assert(
      sizeof(unsigned int) == sizeof(float),
      "Programming error sizeof(unsigned int) != sizeof(float)");

  // Copy the bits out with memcpy: dereferencing a reinterpret_cast'ed
  // pointer to the float would violate the strict-aliasing rules.
  unsigned x;
  std::memcpy(&x, &f, sizeof(x));
  const unsigned u = x & 0x7fffffff; // magnitude bits, sign cleared

  // Get rid of +NaN/-NaN case first.
  if (u > 0x7f800000) {
    return static_cast<float16>(0x7fffU);
  }

  // Move the float sign (bit 31) to the half sign position (bit 15).
  const unsigned sign = (x >> 16) & 0x8000;

  // Get rid of +Inf/-Inf and everything that rounds up to them.
  if (u > 0x477fefff) {
    return static_cast<float16>(sign | 0x7c00U);
  }
  // Get rid of +0/-0 and everything that rounds down to them.
  if (u < 0x33000001) {
    return static_cast<float16>(sign);
  }

  unsigned exponent = (u >> 23) & 0xff;
  unsigned mantissa = u & 0x7fffff;
  unsigned shift;

  if (exponent > 0x70) {
    // Normal half: rebias the exponent and drop the low 13 mantissa bits.
    shift = 13;
    exponent -= 0x70;
  } else {
    // Subnormal half: make the implicit leading one explicit and shift it
    // further right the smaller the input is (shift ranges from 14 to 24).
    shift = 0x7e - exponent;
    exponent = 0;
    mantissa |= 0x800000;
  }

  const unsigned lsb = 1U << shift; // weight of one kept mantissa unit
  const unsigned half_lsb = lsb >> 1; // exactly half of that unit
  const unsigned remainder = mantissa & (lsb - 1); // bits being discarded

  // Round to nearest even.
  mantissa >>= shift;
  if (remainder > half_lsb || (remainder == half_lsb && (mantissa & 0x1))) {
    ++mantissa;
    if (!(mantissa & 0x3ff)) {
      // The 10-bit mantissa overflowed: carry into the exponent.
      ++exponent;
      mantissa = 0;
    }
  }

  return static_cast<float16>(sign | (exponent << 10) | mantissa);
}
79 | |
// Converts a single-precision float to half-precision bits, discarding the
// mantissa bits that do not fit (round toward zero).
//
// Special cases:
//   - any NaN input returns 0x7fff (the sign is dropped);
//   - bit patterns above 0x477fefff, including infinities, return a signed
//     infinity;
//   - bit patterns below 0x33000001, including zeros, return a signed zero.
//
// NOTE(review): the infinity cutoff is the same constant the round-to-nearest
// variant uses, so magnitudes from 65520 up return infinity here instead of
// truncating to 65504 - confirm that this is intended.
static inline float16 cpu_float2half_rz(float f) {
  static_assert(
      sizeof(unsigned int) == sizeof(float),
      "Programming error sizeof(unsigned int) != sizeof(float)");

  // Copy the bits out with memcpy: dereferencing a reinterpret_cast'ed
  // pointer to the float would violate the strict-aliasing rules.
  unsigned x;
  std::memcpy(&x, &f, sizeof(x));
  const unsigned u = x & 0x7fffffff; // magnitude bits, sign cleared

  // Get rid of +NaN/-NaN case first.
  if (u > 0x7f800000) {
    return static_cast<float16>(0x7fffU);
  }

  // Move the float sign (bit 31) to the half sign position (bit 15).
  const unsigned sign = (x >> 16) & 0x8000;

  // Get rid of +Inf/-Inf, +0/-0.
  if (u > 0x477fefff) {
    return static_cast<float16>(sign | 0x7c00U);
  }
  if (u < 0x33000001) {
    return static_cast<float16>(sign);
  }

  unsigned exponent = (u >> 23) & 0xff;
  unsigned mantissa = u & 0x7fffff;
  unsigned shift;

  if (exponent > 0x70) {
    // Normal half: rebias the exponent and drop the low 13 mantissa bits.
    shift = 13;
    exponent -= 0x70;
  } else {
    // Subnormal half: make the implicit leading one explicit and shift it
    // further right the smaller the input is (shift ranges from 14 to 24).
    shift = 0x7e - exponent;
    exponent = 0;
    mantissa |= 0x800000;
  }

  // Round to zero: the shifted-out bits are simply dropped.
  mantissa >>= shift;

  return static_cast<float16>(sign | (exponent << 10) | mantissa);
}
130 | |
// Expands a half-precision bit pattern into the equivalent single-precision
// float.
//
// Infinities keep their sign. Every NaN input maps to the single pattern
// 0x7fffffff (sign cleared, mantissa all ones). Subnormal halves are
// renormalized, since they are representable as normal floats.
static inline float cpu_half2float(float16 h) {
  unsigned sign_bit = (h >> 15) & 1;
  unsigned biased_exp = (h >> 10) & 0x1f;
  // Place the 10 half mantissa bits at the top of the 23-bit float mantissa.
  unsigned fraction = (h & 0x3ff) << 13;

  if (biased_exp == 0x1f) {
    // Inf (zero fraction) or NaN (non-zero fraction).
    if (fraction != 0) {
      sign_bit = 0;
      fraction = 0x7fffff;
    }
    biased_exp = 0xff;
  } else if (biased_exp != 0) {
    // Normal number: only the exponent bias changes (127 - 15 = 0x70).
    biased_exp += 0x70;
  } else if (fraction != 0) {
    // Subnormal: shift left until the leading one has moved out of the
    // mantissa field, lowering the exponent once per shift.
    biased_exp = 0x71;
    for (;;) {
      const bool leading_one_on_top = (fraction & 0x400000) != 0;
      fraction <<= 1;
      --biased_exp;
      if (leading_one_on_top) {
        break;
      }
    }
    fraction &= 0x7fffff; // the leading one is implicit in a normal float
  }
  // Otherwise the input is a zero: exponent and fraction both stay 0.

  const unsigned bits = (sign_bit << 31) | (biased_exp << 23) | fraction;
  float result;
  memcpy(&result, &bits, sizeof(bits));
  return result;
}
159 | |
// Expands a bfloat16 bit pattern into a single-precision float by placing the
// 16 input bits in the upper half of a 32-bit word and zero-filling the lower
// half.
static inline float cpu_bf162float(bfloat16 src) {
  // Widen before shifting so the shift is done on a 32-bit unsigned value.
  const uint32_t widened = static_cast<uint32_t>(src) << 16;
  float result;
  memcpy(&result, &widened, sizeof(result));
  return result;
}
167 | |
// Converts a single-precision float to bfloat16 bits (the upper 16 bits of
// the float after rounding).
//
// Non-NaN inputs are rounded by adding 0x8000 (half of the discarded range)
// before dropping the low 16 bits, so ties round away from zero and values
// just below the float maximum round up to infinity.
//
// NaN inputs are handled separately: the rounding add could carry out of the
// mantissa and turn a NaN into an infinity, or wrap past the sign bit and
// turn it into a zero. NaNs are instead truncated with the top mantissa bit
// forced on, which keeps the sign and guarantees the result is still a NaN.
static inline bfloat16 cpu_float2bfloat16(float src) {
  std::uint32_t temp;
  std::memcpy(&temp, &src, sizeof(temp));

  if ((temp & 0x7fffffffU) > 0x7f800000U) {
    // Exponent all ones with a non-zero mantissa: NaN.
    return static_cast<bfloat16>((temp >> 16) | 0x0040U);
  }

  return static_cast<bfloat16>((temp + (1U << 15)) >> 16);
}
173 | |
174 | } // namespace fbgemm |
175 | |