1#ifndef GLOW_SUPPORT_BFLOAT16_H
2#define GLOW_SUPPORT_BFLOAT16_H
3
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>
7
8namespace glow {
9
/// Soft bfloat16.
///
/// A value is stored as the upper 16 bits of an IEEE-754 binary32 float
/// (sign bit, 8 exponent bits, 7 mantissa bits). All arithmetic is carried out
/// in single precision and the result is narrowed back to bfloat16.
///
/// Narrowing from float truncates (rounds toward zero) by dropping the low 16
/// bits of the float. A NaN whose surviving mantissa bits would all be zero is
/// replaced by the quiet NaN 0x7fc0 so it does not turn into an infinity.
///
/// The class declares no copy/move members, so it is trivially copyable.
class alignas(2) bfloat16 {
  uint16_t storage_;

public:
  /// Returns the bfloat16 bit pattern for \p rhs (truncating conversion).
  static uint16_t float32_to_bfloat16_storage(float rhs) {
    uint32_t rhsu;
    static_assert(sizeof(rhsu) == sizeof(rhs), "float must be 32 bits wide");
    // memcpy is the well-defined way to reinterpret an object's bytes;
    // dereferencing a reinterpret_cast'ed uint32_t* that points at a float
    // violates the strict aliasing rule (undefined behavior).
    std::memcpy(&rhsu, &rhs, sizeof(rhsu));
    uint16_t lhs = static_cast<uint16_t>(rhsu >> 16);

    // Dropping the low mantissa bits of a NaN can leave an all-zero mantissa,
    // i.e. an infinity bit pattern; substitute a quiet NaN instead.
    if (std::isnan(rhs) && (lhs & 0x7fu) == 0) {
      lhs = 0x7fc0u; // qNaN
    }

    return lhs;
  }

  /// Widens \p rhs to float by placing its bits in the upper half of a
  /// binary32 word (exact: no rounding is involved).
  static float bfloat16_to_float32(bfloat16 rhs) {
    const uint32_t lhsu = static_cast<uint32_t>(rhs.storage_) << 16;
    float lhs;
    // See float32_to_bfloat16_storage: memcpy avoids strict-aliasing UB.
    std::memcpy(&lhs, &lhsu, sizeof(lhs));
    return lhs;
  }

  /// Builds a bfloat16 directly from a raw 16-bit pattern (no conversion).
  static bfloat16 bfloat16_from_uint16_storage(uint16_t rhs) {
    bfloat16 lhs;
    lhs.storage_ = rhs;
    return lhs;
  }

  /// Classifies \p rhs via std::fpclassify on its float value.
  static int fpclassify(bfloat16 rhs) {
    return std::fpclassify(bfloat16_to_float32(rhs));
  }

  /// True unless the exponent bits are all ones (infinity or NaN).
  static bool isfinite(bfloat16 rhs) {
    return (rhs.storage_ & 0x7f80u) != 0x7f80u;
  }

  /// True when the exponent bits are all ones and the mantissa is zero.
  static bool isinf(bfloat16 rhs) {
    return (rhs.storage_ & 0x7fffu) == 0x7f80u;
  }

  static bool isnan(bfloat16 rhs) { return fpclassify(rhs) == FP_NAN; }

  static bool isnormal(bfloat16 rhs) { return fpclassify(rhs) == FP_NORMAL; }

  /// True when the sign bit is set (including -0 and negative NaNs).
  static bool signbit(bfloat16 rhs) {
    return (rhs.storage_ & 0x8000u) == 0x8000u;
  }

  /// Zero-initialised (+0.0).
  bfloat16() : storage_{0} {}

  /// Implicit narrowing from float; see the class comment for rounding.
  bfloat16(float rhs) { storage_ = float32_to_bfloat16_storage(rhs); }

  /// Raw 16-bit representation.
  uint16_t storage() const { return storage_; }

  operator float() const { return bfloat16_to_float32(*this); }

  operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

  operator bool() const { return static_cast<bool>(static_cast<float>(*this)); }

  // NOTE: the integer conversions below go through float; as with any
  // float-to-integer conversion, NaN or out-of-range values are undefined
  // behavior.

  operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  /// Negation flips the sign bit only (exact, also for NaN/inf/zero).
  /// const so that it can be applied to const operands.
  bfloat16 operator-() const {
    return bfloat16_from_uint16_storage(
        static_cast<uint16_t>(storage_ ^ 0x8000u));
  }

  // Compound assignment: compute in float, then narrow back to bfloat16.

  bfloat16 &operator+=(const bfloat16 &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) +
                                           static_cast<float>(rhs));
    return *this;
  }

  bfloat16 &operator-=(const bfloat16 &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) -
                                           static_cast<float>(rhs));
    return *this;
  }

  bfloat16 &operator*=(const bfloat16 &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) *
                                           static_cast<float>(rhs));
    return *this;
  }

  bfloat16 &operator/=(const bfloat16 &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) /
                                           static_cast<float>(rhs));
    return *this;
  }

  bfloat16 &operator+=(const float &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) + rhs);
    return *this;
  }

  bfloat16 &operator-=(const float &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) - rhs);
    return *this;
  }

  bfloat16 &operator*=(const float &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) * rhs);
    return *this;
  }

  bfloat16 &operator/=(const float &rhs) {
    storage_ = float32_to_bfloat16_storage(static_cast<float>(*this) / rhs);
    return *this;
  }

  // bfloat16 (op) bfloat16 yields bfloat16.

  friend bfloat16 operator+(bfloat16 lhs, const bfloat16 &rhs) {
    lhs += rhs;
    return lhs;
  }

  friend bfloat16 operator-(bfloat16 lhs, const bfloat16 &rhs) {
    lhs -= rhs;
    return lhs;
  }

  friend bfloat16 operator*(bfloat16 lhs, const bfloat16 &rhs) {
    lhs *= rhs;
    return lhs;
  }

  friend bfloat16 operator/(bfloat16 lhs, const bfloat16 &rhs) {
    lhs /= rhs;
    return lhs;
  }

  // Mixed bfloat16/float arithmetic yields float (no narrowing).

  friend float operator+(bfloat16 lhs, const float &rhs) {
    return static_cast<float>(lhs) + rhs;
  }

  friend float operator-(bfloat16 lhs, const float &rhs) {
    return static_cast<float>(lhs) - rhs;
  }

  friend float operator*(bfloat16 lhs, const float &rhs) {
    return static_cast<float>(lhs) * rhs;
  }

  friend float operator/(bfloat16 lhs, const float &rhs) {
    return static_cast<float>(lhs) / rhs;
  }

  friend float operator+(float lhs, const bfloat16 &rhs) {
    return lhs + static_cast<float>(rhs);
  }

  friend float operator-(float lhs, const bfloat16 &rhs) {
    return lhs - static_cast<float>(rhs);
  }

  friend float operator*(float lhs, const bfloat16 &rhs) {
    return lhs * static_cast<float>(rhs);
  }

  friend float operator/(float lhs, const bfloat16 &rhs) {
    return lhs / static_cast<float>(rhs);
  }

  // Comparisons use float semantics: NaN compares unequal to everything and
  // +0 == -0.

  friend bool operator<(const bfloat16 &lhs, const bfloat16 &rhs) {
    return static_cast<float>(lhs) < static_cast<float>(rhs);
  }

  friend bool operator>(const bfloat16 &lhs, const bfloat16 &rhs) {
    return rhs < lhs;
  }

  friend bool operator<=(const bfloat16 &lhs, const bfloat16 &rhs) {
    return static_cast<float>(lhs) <= static_cast<float>(rhs);
  }

  friend bool operator>=(const bfloat16 &lhs, const bfloat16 &rhs) {
    return rhs <= lhs;
  }

  friend bool operator==(const bfloat16 &lhs, const bfloat16 &rhs) {
    return static_cast<float>(lhs) == static_cast<float>(rhs);
  }

  friend bool operator!=(const bfloat16 &lhs, const bfloat16 &rhs) {
    return static_cast<float>(lhs) != static_cast<float>(rhs);
  }
};
235
236/// Allow bfloat16_t to be passed to an ostream.
237inline std::ostream &operator<<(std::ostream &os, const bfloat16 &b) {
238 os << static_cast<float>(b);
239 return os;
240}
241
242using bfloat16_t = bfloat16;
243static_assert(sizeof(bfloat16_t) == 2, "bfloat16_t must be 16 bits wide");
244
245} // namespace glow
246
247#endif // GLOW_SUPPORT_BFLOAT16_H
248