1#pragma once
2
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
5
#include <c10/macros/Macros.h>
#include <cmath>
#include <cstdint>
#include <cstring>
9
10#if defined(__CUDACC__) && !defined(USE_ROCM)
11#include <cuda_bf16.h>
12#endif
13
14#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
15#if defined(CL_SYCL_LANGUAGE_VERSION)
16#include <CL/sycl.hpp> // for SYCL 1.2.1
17#else
18#include <sycl/sycl.hpp> // for SYCL 2020
19#endif
20#include <ext/oneapi/bfloat16.hpp>
21#endif
22
23namespace c10 {
24
25namespace detail {
// Expands a raw bfloat16 bit pattern into the float it represents.
//
// bfloat16 is the upper half of an IEEE-754 binary32 value, so the 16 input
// bits are placed in the high half of a 32-bit word and the low 16 mantissa
// bits are left as zero. The conversion is exact.
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
  uint32_t widened = static_cast<uint32_t>(src) << 16;
  float out = 0;

#if defined(USE_ROCM)
  // memcpy is the strict-aliasing-safe way to reinterpret the bits, but it
  // fails in the HIP environment, so a pointer cast is used there instead.
  out = *reinterpret_cast<float*>(&widened);
#else
  std::memcpy(&out, &widened, sizeof(widened));
#endif

  return out;
}
44
// Returns the upper 16 bits of a float's binary32 encoding.
//
// The low 16 mantissa bits are simply dropped (truncation toward zero in
// magnitude), not rounded; use round_to_nearest_even for rounded conversion.
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
  uint32_t word = 0;

#if defined(USE_ROCM)
  // memcpy is the strict-aliasing-safe way to reinterpret the bits, but it
  // fails in the HIP environment, so a pointer cast is used there instead.
  word = *reinterpret_cast<uint32_t*>(&src);
#else
  std::memcpy(&word, &src, sizeof(word));
#endif

  // After the shift at most 16 significant bits remain, so the cast is lossless.
  return static_cast<uint16_t>(word >> 16);
}
59
// Converts a float to bfloat16 bits using round-to-nearest, ties-to-even.
//
// - Any NaN input (either sign) maps to the canonical quiet NaN 0x7FC0.
// - Otherwise, a bias of 0x7FFF plus the lowest retained bit (bit 16 of the
//   binary32 encoding) is added before keeping the upper 16 bits. Values above
//   the halfway point therefore round up, and exact ties round to the
//   neighbour whose last bit is 0.
// - +/-inf pass through unchanged. Finite values whose rounding carries into
//   the exponent field (e.g. FLT_MAX) become +/-inf.
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM)
  // NaN is the only value that compares unequal to itself.
  const bool is_nan = (src != src);
#elif defined(_MSC_VER)
  const bool is_nan = isnan(src);
#else
  const bool is_nan = std::isnan(src);
#endif
  if (is_nan) {
    // Sign 0, exponent all ones, top mantissa bit set.
    return UINT16_C(0x7FC0);
  }

  uint32_t U32 = 0;
#if defined(USE_ROCM)
  // memcpy is known to fail in the HIP environment (see the comments in
  // f32_from_bits / bits_from_f32), so keep the union pun ROCm relied on.
  union {
    uint32_t u32;
    float f32;
  } pun;
  pun.f32 = src;
  U32 = pun.u32;
#else
  // Reading an inactive union member is undefined behaviour in ISO C++;
  // memcpy is the well-defined way to copy the object representation.
  std::memcpy(&U32, &src, sizeof(U32));
#endif

  const uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
  // For every non-NaN input U32 <= 0xFF800000, so the addition cannot wrap.
  return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
80} // namespace detail
81
// Storage type for a bfloat16 value: a 2-byte, 2-byte-aligned wrapper around
// the raw bit pattern. The bits are the upper half of an IEEE-754 binary32
// float (1 sign bit, 8 exponent bits, 7 mantissa bits).
//
// The value conversions are only declared here; their definitions are
// presumably provided by c10/util/BFloat16-inl.h, which this header includes
// at its end.
struct alignas(2) BFloat16 {
  // Raw bfloat16 bit pattern.
  uint16_t x;

  // Empty tag type: passing BFloat16::from_bits() selects the constructor
  // that stores its argument verbatim instead of converting from float.
  struct from_bits_t {};
  static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
    return from_bits_t{};
  }

  // Defaulted, so default-initialization leaves `x` indeterminate. HIP wants
  // a __host__ __device__ tag on it; CUDA does not.
#if defined(USE_ROCM)
  C10_HOST_DEVICE BFloat16() = default;
#else
  BFloat16() = default;
#endif

  // Stores `raw` unchanged; no rounding or conversion takes place.
  constexpr C10_HOST_DEVICE BFloat16(unsigned short raw, from_bits_t)
      : x(raw) {}

  // Conversions to and from float.
  inline C10_HOST_DEVICE BFloat16(float value);
  inline C10_HOST_DEVICE operator float() const;

#if defined(__CUDACC__) && !defined(USE_ROCM)
  // Interop with CUDA's native __nv_bfloat16 (from cuda_bf16.h).
  inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
  explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif

#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  // Interop with the oneAPI SYCL bfloat16 extension type.
  inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
  explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
#endif
};
112
113} // namespace c10
114
115#include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
116