1 | #ifndef GLOW_SUPPORT_BFLOAT16_H |
2 | #define GLOW_SUPPORT_BFLOAT16_H |
3 | |
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>
7 | |
8 | namespace glow { |
9 | |
/// Software ("soft") bfloat16.
///
/// A value is stored as a 16-bit pattern that equals the upper half of the bit
/// pattern of a single-precision float: bit 15 is the sign, bits 14..7 the
/// exponent and bits 6..0 the mantissa. All arithmetic and comparisons are
/// carried out by widening the operands to float, operating in single
/// precision, and (for results of type bfloat16) narrowing back.
class alignas(2) bfloat16 {
  /// Raw bfloat16 bit pattern (upper 16 bits of the equivalent float).
  uint16_t storage_;

public:
  /// Narrows \p rhs to a bfloat16 bit pattern.
  ///
  /// The lower 16 bits of the float are discarded (truncation, no rounding).
  /// \returns the resulting 16-bit pattern.
  static uint16_t float32_to_bfloat16_storage(float rhs) {
    static_assert(sizeof(uint32_t) == sizeof(float),
                  "float must be 32 bits wide");
    // Copy the object representation with memcpy: reading a float through a
    // uint32_t lvalue would violate the strict-aliasing rule.
    uint32_t bits;
    std::memcpy(&bits, &rhs, sizeof(bits));
    uint16_t lhs = static_cast<uint16_t>(bits >> 16);

    // A NaN whose set mantissa bits all lived in the discarded half would
    // otherwise be truncated to an infinity pattern; emit a quiet NaN instead.
    if (std::isnan(rhs) && (lhs & 0x7fu) == 0) {
      lhs = 0x7fc0u; // qNaN
    }

    return lhs;
  }

  /// Widens \p rhs to float by placing its bits in the upper half of a 32-bit
  /// pattern and zero-filling the lower half.
  static float bfloat16_to_float32(bfloat16 rhs) {
    static_assert(sizeof(uint32_t) == sizeof(float),
                  "float must be 32 bits wide");
    const uint32_t bits = static_cast<uint32_t>(rhs.storage_) << 16;
    // memcpy instead of reinterpret_cast for the same aliasing reason as in
    // float32_to_bfloat16_storage().
    float lhs;
    std::memcpy(&lhs, &bits, sizeof(lhs));
    return lhs;
  }

  /// Builds a bfloat16 whose bit pattern is exactly \p rhs (no conversion).
  static bfloat16 bfloat16_from_uint16_storage(uint16_t rhs) {
    bfloat16 lhs;
    lhs.storage_ = rhs;
    return lhs;
  }

  /// \returns the std::fpclassify() category of \p rhs widened to float.
  static int fpclassify(bfloat16 rhs) {
    return std::fpclassify(bfloat16_to_float32(rhs));
  }

  /// \returns true unless every exponent bit is set (i.e. not inf and not NaN).
  static bool isfinite(bfloat16 rhs) {
    return (rhs.storage_ & 0x7f80u) != 0x7f80u;
  }

  /// \returns true when, ignoring the sign bit, all exponent bits are set and
  /// the mantissa is zero.
  static bool isinf(bfloat16 rhs) {
    return (rhs.storage_ & 0x7fffu) == 0x7f80u;
  }

  /// \returns true when \p rhs classifies as FP_NAN.
  static bool isnan(bfloat16 rhs) { return fpclassify(rhs) == FP_NAN; }

  /// \returns true when \p rhs classifies as FP_NORMAL.
  static bool isnormal(bfloat16 rhs) { return fpclassify(rhs) == FP_NORMAL; }

  /// \returns true when the sign bit of \p rhs is set.
  static bool signbit(bfloat16 rhs) {
    return (rhs.storage_ & 0x8000u) == 0x8000u;
  }

  /// Constructs a value with an all-zero bit pattern.
  bfloat16() : storage_{0} {}

  /// Copies the bit pattern. Defaulted so the type stays trivially copyable.
  bfloat16(const bfloat16 &rhs) = default;

  /// Implicitly converts a float, narrowing it with
  /// float32_to_bfloat16_storage().
  bfloat16(float rhs) : storage_{float32_to_bfloat16_storage(rhs)} {}

  /// \returns the raw 16-bit pattern.
  uint16_t storage() const { return storage_; }

  /// Widening conversion to float; every other conversion below goes through
  /// this one first.
  operator float() const { return bfloat16_to_float32(*this); }

  operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

  operator bool() const { return static_cast<bool>(static_cast<float>(*this)); }

  operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  /// Negation: flips the sign bit and leaves all other bits untouched. Does
  /// not modify the operand, hence const.
  bfloat16 operator-() const {
    return bfloat16_from_uint16_storage(
        static_cast<uint16_t>(storage_ ^ 0x8000u));
  }

  // Compound assignment with a bfloat16 operand: compute in float, then narrow
  // the result back into this object.

  bfloat16 &operator+=(const bfloat16 &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) +
                                           static_cast<float>(rhs));
    return *this;
  }

  bfloat16 &operator-=(const bfloat16 &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) -
                                           static_cast<float>(rhs));
    return *this;
  }

  bfloat16 &operator*=(const bfloat16 &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) *
                                           static_cast<float>(rhs));
    return *this;
  }

  bfloat16 &operator/=(const bfloat16 &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) /
                                           static_cast<float>(rhs));
    return *this;
  }

  // Compound assignment with a float operand: same scheme, the float operand
  // is used at full single precision before the result is narrowed.

  bfloat16 &operator+=(const float &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) + rhs);
    return *this;
  }

  bfloat16 &operator-=(const float &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) - rhs);
    return *this;
  }

  bfloat16 &operator*=(const float &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) * rhs);
    return *this;
  }

  bfloat16 &operator/=(const float &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) / rhs);
    return *this;
  }

  // bfloat16 (op) bfloat16 yields bfloat16, implemented via the compound
  // assignment operators on the by-value left operand.

  friend bfloat16 operator+(bfloat16 lhs, const bfloat16 &rhs) {
    lhs += rhs;
    return lhs;
  }

  friend bfloat16 operator-(bfloat16 lhs, const bfloat16 &rhs) {
    lhs -= rhs;
    return lhs;
  }

  friend bfloat16 operator*(bfloat16 lhs, const bfloat16 &rhs) {
    lhs *= rhs;
    return lhs;
  }

  friend bfloat16 operator/(bfloat16 lhs, const bfloat16 &rhs) {
    lhs /= rhs;
    return lhs;
  }

  // Mixed bfloat16/float arithmetic yields float: the result is not narrowed.

  friend float operator+(bfloat16 lhs, const float &rhs) {
    return static_cast<float>(lhs) + rhs;
  }

  friend float operator-(bfloat16 lhs, const float &rhs) {
    return static_cast<float>(lhs) - rhs;
  }

  friend float operator*(bfloat16 lhs, const float &rhs) {
    return static_cast<float>(lhs) * rhs;
  }

  friend float operator/(bfloat16 lhs, const float &rhs) {
    return static_cast<float>(lhs) / rhs;
  }

  friend float operator+(float lhs, const bfloat16 &rhs) {
    return lhs + static_cast<float>(rhs);
  }

  friend float operator-(float lhs, const bfloat16 &rhs) {
    return lhs - static_cast<float>(rhs);
  }

  friend float operator*(float lhs, const bfloat16 &rhs) {
    return lhs * static_cast<float>(rhs);
  }

  friend float operator/(float lhs, const bfloat16 &rhs) {
    return lhs / static_cast<float>(rhs);
  }

  // Comparisons are the float comparisons of the widened operands.

  friend bool operator<(const bfloat16 &lhs, const bfloat16 &rhs) {
    return static_cast<float>(lhs) < static_cast<float>(rhs);
  }

  friend bool operator>(const bfloat16 &lhs, const bfloat16 &rhs) {
    return rhs < lhs;
  }

  friend bool operator<=(const bfloat16 &lhs, const bfloat16 &rhs) {
    return static_cast<float>(lhs) <= static_cast<float>(rhs);
  }

  friend bool operator>=(const bfloat16 &lhs, const bfloat16 &rhs) {
    return rhs <= lhs;
  }

  friend bool operator==(const bfloat16 &lhs, const bfloat16 &rhs) {
    return static_cast<float>(lhs) == static_cast<float>(rhs);
  }

  friend bool operator!=(const bfloat16 &lhs, const bfloat16 &rhs) {
    return static_cast<float>(lhs) != static_cast<float>(rhs);
  }
};
235 | |
236 | /// Allow bfloat16_t to be passed to an ostream. |
237 | inline std::ostream &operator<<(std::ostream &os, const bfloat16 &b) { |
238 | os << static_cast<float>(b); |
239 | return os; |
240 | } |
241 | |
242 | using bfloat16_t = bfloat16; |
243 | static_assert(sizeof(bfloat16_t) == 2, "bfloat16_t must be 16 bits wide" ); |
244 | |
245 | } // namespace glow |
246 | |
247 | #endif // GLOW_SUPPORT_BFLOAT16_H |
248 | |