1#pragma once
2
3#include <c10/macros/Macros.h>
4#include <limits>
5
6C10_CLANG_DIAGNOSTIC_PUSH()
7#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
8C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
9#endif
10
11#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
12#if defined(CL_SYCL_LANGUAGE_VERSION)
13#include <CL/sycl.hpp> // for SYCL 1.2.1
14#else
15#include <sycl/sycl.hpp> // for SYCL 2020
16#endif
17#include <ext/oneapi/bfloat16.hpp>
18#endif
19
20namespace c10 {
21
22/// Constructors
23inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
24 :
25#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
26 __CUDA_ARCH__ >= 800
27 x(__bfloat16_as_ushort(__float2bfloat16(value)))
28#elif defined(__SYCL_DEVICE_ONLY__) && \
29 defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
30 x(sycl::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
31#else
32 // RNE by default
33 x(detail::round_to_nearest_even(value))
34#endif
35{
36}
37
38/// Implicit conversions
39inline C10_HOST_DEVICE BFloat16::operator float() const {
40#if defined(__CUDACC__) && !defined(USE_ROCM)
41 return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
42#elif defined(__SYCL_DEVICE_ONLY__) && \
43 defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
44 return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
45#else
46 return detail::f32_from_bits(x);
47#endif
48}
49
50#if defined(__CUDACC__) && !defined(USE_ROCM)
51inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
52 x = *reinterpret_cast<const unsigned short*>(&value);
53}
54inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
55 return *reinterpret_cast<const __nv_bfloat16*>(&x);
56}
57#endif
58
59#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
60inline C10_HOST_DEVICE BFloat16::BFloat16(
61 const sycl::ext::oneapi::bfloat16& value) {
62 x = *reinterpret_cast<const unsigned short*>(&value);
63}
64inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
65 return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
66}
67#endif
68
69// CUDA intrinsics
70
71#if defined(__CUDACC__) || defined(__HIPCC__)
72inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
73#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
74 return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
75#else
76 return *ptr;
77#endif
78}
79#endif
80
81/// Arithmetic
82
83inline C10_HOST_DEVICE BFloat16
84operator+(const BFloat16& a, const BFloat16& b) {
85 return static_cast<float>(a) + static_cast<float>(b);
86}
87
88inline C10_HOST_DEVICE BFloat16
89operator-(const BFloat16& a, const BFloat16& b) {
90 return static_cast<float>(a) - static_cast<float>(b);
91}
92
93inline C10_HOST_DEVICE BFloat16
94operator*(const BFloat16& a, const BFloat16& b) {
95 return static_cast<float>(a) * static_cast<float>(b);
96}
97
98inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
99 __ubsan_ignore_float_divide_by_zero__ {
100 return static_cast<float>(a) / static_cast<float>(b);
101}
102
103inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
104 return -static_cast<float>(a);
105}
106
107inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
108 a = a + b;
109 return a;
110}
111
112inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
113 a = a - b;
114 return a;
115}
116
117inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
118 a = a * b;
119 return a;
120}
121
122inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
123 a = a / b;
124 return a;
125}
126
127inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
128 a.x = a.x | b.x;
129 return a;
130}
131
132inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
133 a.x = a.x ^ b.x;
134 return a;
135}
136
137inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
138 a.x = a.x & b.x;
139 return a;
140}
141
142/// Arithmetic with floats
143
144inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
145 return static_cast<float>(a) + b;
146}
147inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
148 return static_cast<float>(a) - b;
149}
150inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
151 return static_cast<float>(a) * b;
152}
153inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
154 return static_cast<float>(a) / b;
155}
156
157inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
158 return a + static_cast<float>(b);
159}
160inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
161 return a - static_cast<float>(b);
162}
163inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
164 return a * static_cast<float>(b);
165}
166inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
167 return a / static_cast<float>(b);
168}
169
170inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
171 return a += static_cast<float>(b);
172}
173inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
174 return a -= static_cast<float>(b);
175}
176inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
177 return a *= static_cast<float>(b);
178}
179inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
180 return a /= static_cast<float>(b);
181}
182
183/// Arithmetic with doubles
184
185inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
186 return static_cast<double>(a) + b;
187}
188inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
189 return static_cast<double>(a) - b;
190}
191inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
192 return static_cast<double>(a) * b;
193}
194inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
195 return static_cast<double>(a) / b;
196}
197
198inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
199 return a + static_cast<double>(b);
200}
201inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
202 return a - static_cast<double>(b);
203}
204inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
205 return a * static_cast<double>(b);
206}
207inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
208 return a / static_cast<double>(b);
209}
210
211/// Arithmetic with ints
212
213inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
214 return a + static_cast<BFloat16>(b);
215}
216inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
217 return a - static_cast<BFloat16>(b);
218}
219inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
220 return a * static_cast<BFloat16>(b);
221}
222inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
223 return a / static_cast<BFloat16>(b);
224}
225
226inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
227 return static_cast<BFloat16>(a) + b;
228}
229inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
230 return static_cast<BFloat16>(a) - b;
231}
232inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
233 return static_cast<BFloat16>(a) * b;
234}
235inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
236 return static_cast<BFloat16>(a) / b;
237}
238
/// Arithmetic with int64_t
240
241inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
242 return a + static_cast<BFloat16>(b);
243}
244inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
245 return a - static_cast<BFloat16>(b);
246}
247inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
248 return a * static_cast<BFloat16>(b);
249}
250inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
251 return a / static_cast<BFloat16>(b);
252}
253
254inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
255 return static_cast<BFloat16>(a) + b;
256}
257inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
258 return static_cast<BFloat16>(a) - b;
259}
260inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
261 return static_cast<BFloat16>(a) * b;
262}
263inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
264 return static_cast<BFloat16>(a) / b;
265}
266
267// Overloading < and > operators, because std::max and std::min use them.
268
269inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
270 return float(lhs) > float(rhs);
271}
272
273inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
274 return float(lhs) < float(rhs);
275}
276
277} // namespace c10
278
279namespace std {
280
281template <>
282class numeric_limits<c10::BFloat16> {
283 public:
284 static constexpr bool is_signed = true;
285 static constexpr bool is_specialized = true;
286 static constexpr bool is_integer = false;
287 static constexpr bool is_exact = false;
288 static constexpr bool has_infinity = true;
289 static constexpr bool has_quiet_NaN = true;
290 static constexpr bool has_signaling_NaN = true;
291 static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
292 static constexpr auto has_denorm_loss =
293 numeric_limits<float>::has_denorm_loss;
294 static constexpr auto round_style = numeric_limits<float>::round_style;
295 static constexpr bool is_iec559 = false;
296 static constexpr bool is_bounded = true;
297 static constexpr bool is_modulo = false;
298 static constexpr int digits = 8;
299 static constexpr int digits10 = 2;
300 static constexpr int max_digits10 = 4;
301 static constexpr int radix = 2;
302 static constexpr int min_exponent = -125;
303 static constexpr int min_exponent10 = -37;
304 static constexpr int max_exponent = 128;
305 static constexpr int max_exponent10 = 38;
306 static constexpr auto traps = numeric_limits<float>::traps;
307 static constexpr auto tinyness_before =
308 numeric_limits<float>::tinyness_before;
309
310 static constexpr c10::BFloat16 min() {
311 return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
312 }
313 static constexpr c10::BFloat16 lowest() {
314 return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
315 }
316 static constexpr c10::BFloat16 max() {
317 return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
318 }
319 static constexpr c10::BFloat16 epsilon() {
320 return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
321 }
322 static constexpr c10::BFloat16 round_error() {
323 return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
324 }
325 static constexpr c10::BFloat16 infinity() {
326 return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
327 }
328 static constexpr c10::BFloat16 quiet_NaN() {
329 return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
330 }
331 static constexpr c10::BFloat16 signaling_NaN() {
332 return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
333 }
334 static constexpr c10::BFloat16 denorm_min() {
335 return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
336 }
337};
338
339} // namespace std
340
341C10_CLANG_DIAGNOSTIC_POP()
342