1 | #pragma once |
2 | |
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
5 | |
6 | #include <c10/macros/Macros.h> |
7 | #include <cmath> |
8 | #include <cstring> |
9 | |
10 | #if defined(__CUDACC__) && !defined(USE_ROCM) |
11 | #include <cuda_bf16.h> |
12 | #endif |
13 | |
14 | #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) |
15 | #if defined(CL_SYCL_LANGUAGE_VERSION) |
16 | #include <CL/sycl.hpp> // for SYCL 1.2.1 |
17 | #else |
18 | #include <sycl/sycl.hpp> // for SYCL 2020 |
19 | #endif |
20 | #include <ext/oneapi/bfloat16.hpp> |
21 | #endif |
22 | |
23 | namespace c10 { |
24 | |
25 | namespace detail { |
// Reinterprets a 16-bit pattern as the upper half of a 32-bit float.
//
// `src` is moved into bits 31..16 of a 32-bit word, the lower 16 bits are
// left as zero, and the resulting word is returned reinterpreted as a float.
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
  // Widen first so that the shift happens in 32 bits.
  uint32_t widened = static_cast<uint32_t>(src) << 16;
  float value = 0;

#if defined(USE_ROCM)
  // We should be using memcpy in order to respect the strict aliasing rule
  // but it fails in the HIP environment.
  value = *reinterpret_cast<float*>(&widened);
#else
  // Bit-for-bit copy of the 32-bit word into the float.
  std::memcpy(&value, &widened, sizeof(widened));
#endif

  return value;
}
44 | |
// Returns the upper 16 bits of the 32-bit representation of `src`.
//
// The lower 16 bits are simply discarded (truncation); no rounding is
// applied here.
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
  uint32_t raw = 0;

#if defined(USE_ROCM)
  // We should be using memcpy in order to respect the strict aliasing rule
  // but it fails in the HIP environment.
  raw = *reinterpret_cast<uint32_t*>(&src);
#else
  // Bit-for-bit copy of the float into a 32-bit word.
  std::memcpy(&raw, &src, sizeof(raw));
#endif

  // Keep only the high half of the word.
  return static_cast<uint16_t>(raw >> 16);
}
59 | |
// Converts `src` to a 16-bit pattern made of the upper half of its 32-bit
// representation, rounding the discarded lower half to nearest, ties to even.
//
// Returns 0x7FC0 for any NaN input. For every other input, a bias of 0x7FFF
// plus the lowest kept bit is added to the 32-bit word before the upper
// 16 bits are taken, so exact ties round towards the even 16-bit value.
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM)
  if (src != src) {
#elif defined(_MSC_VER)
  if (isnan(src)) {
#else
  if (std::isnan(src)) {
#endif
    // NaN: return a fixed pattern instead of rounding the payload.
    return UINT16_C(0x7FC0);
  }

  // Fetch the raw bits of `src`. Reading the inactive member of a union is
  // undefined behaviour in C++, so copy the bytes instead.
  uint32_t bits = 0;
#if defined(USE_ROCM)
  // We should be using memcpy in order to respect the strict aliasing rule
  // but it fails in the HIP environment.
  bits = *reinterpret_cast<uint32_t*>(&src);
#else
  std::memcpy(&bits, &src, sizeof(bits));
#endif

  // 0x7FFF rounds everything strictly above the halfway point up; adding the
  // lowest kept bit makes an exact tie round up only when that bit is 1.
  const uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
80 | } // namespace detail |
81 | |
// 16-bit floating-point value stored as a raw bit pattern in `x`.
//
// Conversions to and from float (and, when available, the CUDA and SYCL
// bfloat16 types) are only declared here; their definitions are not part of
// this struct body.
struct alignas(2) BFloat16 {
  // Raw 16-bit storage.
  uint16_t x;

  // HIP wants __host__ __device__ tag, CUDA does not
#if defined(USE_ROCM)
  C10_HOST_DEVICE
#endif
  BFloat16() = default;

  // Tag type selecting the "construct from raw bits" constructor.
  struct from_bits_t {};

  // Returns a tag value to pass as the second constructor argument.
  static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
    return from_bits_t{};
  }

  // Stores `bits` verbatim, without any numeric conversion.
  constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
      : x(bits) {}

  // Implicit conversions from and to float.
  inline C10_HOST_DEVICE BFloat16(float value);
  inline C10_HOST_DEVICE operator float() const;

#if defined(__CUDACC__) && !defined(USE_ROCM)
  // Interop with the CUDA bfloat16 type; conversion back is explicit.
  inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
  explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif

#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  // Interop with the SYCL bfloat16 type; conversion back is explicit.
  inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
  explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
#endif
};
112 | |
113 | } // namespace c10 |
114 | |
115 | #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep |
116 | |