1 | #pragma once |
2 | |
3 | #include <c10/macros/Macros.h> |
4 | #include <cstring> |
5 | #include <limits> |
6 | |
7 | #ifdef __CUDACC__ |
8 | #include <cuda_fp16.h> |
9 | #endif |
10 | |
11 | #ifdef __HIPCC__ |
12 | #include <hip/hip_fp16.h> |
13 | #endif |
14 | |
15 | #if defined(CL_SYCL_LANGUAGE_VERSION) |
16 | #include <CL/sycl.hpp> // for SYCL 1.2.1 |
17 | #elif defined(SYCL_LANGUAGE_VERSION) |
18 | #include <sycl/sycl.hpp> // for SYCL 2020 |
19 | #endif |
20 | |
21 | C10_CLANG_DIAGNOSTIC_PUSH() |
22 | #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") |
23 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion" ) |
24 | #endif |
25 | |
26 | namespace c10 { |
27 | |
28 | /// Constructors |
29 | |
30 | inline C10_HOST_DEVICE Half::Half(float value) |
31 | : |
32 | #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) |
33 | x(__half_as_short(__float2half(value))) |
34 | #elif defined(__SYCL_DEVICE_ONLY__) |
35 | x(sycl::bit_cast<uint16_t>(sycl::half(value))) |
36 | #else |
37 | x(detail::fp16_ieee_from_fp32_value(value)) |
38 | #endif |
39 | { |
40 | } |
41 | |
42 | /// Implicit conversions |
43 | |
44 | inline C10_HOST_DEVICE Half::operator float() const { |
45 | #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) |
46 | return __half2float(*reinterpret_cast<const __half*>(&x)); |
47 | #elif defined(__SYCL_DEVICE_ONLY__) |
48 | return float(sycl::bit_cast<sycl::half>(x)); |
49 | #else |
50 | return detail::fp16_ieee_to_fp32_value(x); |
51 | #endif |
52 | } |
53 | |
54 | #if defined(__CUDACC__) || defined(__HIPCC__) |
55 | inline C10_HOST_DEVICE Half::Half(const __half& value) { |
56 | x = *reinterpret_cast<const unsigned short*>(&value); |
57 | } |
58 | inline C10_HOST_DEVICE Half::operator __half() const { |
59 | return *reinterpret_cast<const __half*>(&x); |
60 | } |
61 | #endif |
62 | |
63 | #ifdef SYCL_LANGUAGE_VERSION |
64 | inline C10_HOST_DEVICE Half::Half(const sycl::half& value) { |
65 | x = *reinterpret_cast<const unsigned short*>(&value); |
66 | } |
67 | inline C10_HOST_DEVICE Half::operator sycl::half() const { |
68 | return *reinterpret_cast<const sycl::half*>(&x); |
69 | } |
70 | #endif |
71 | |
72 | // CUDA intrinsics |
73 | |
74 | #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \ |
75 | (defined(__clang__) && defined(__CUDA__)) |
76 | inline __device__ Half __ldg(const Half* ptr) { |
77 | return __ldg(reinterpret_cast<const __half*>(ptr)); |
78 | } |
79 | #endif |
80 | |
81 | /// Arithmetic |
82 | |
83 | inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) { |
84 | return static_cast<float>(a) + static_cast<float>(b); |
85 | } |
86 | |
87 | inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) { |
88 | return static_cast<float>(a) - static_cast<float>(b); |
89 | } |
90 | |
91 | inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { |
92 | return static_cast<float>(a) * static_cast<float>(b); |
93 | } |
94 | |
95 | inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) |
96 | __ubsan_ignore_float_divide_by_zero__ { |
97 | return static_cast<float>(a) / static_cast<float>(b); |
98 | } |
99 | |
100 | inline C10_HOST_DEVICE Half operator-(const Half& a) { |
101 | #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ |
102 | defined(__HIP_DEVICE_COMPILE__) |
103 | return __hneg(a); |
104 | #elif defined(__SYCL_DEVICE_ONLY__) |
105 | return -sycl::bit_cast<sycl::half>(a); |
106 | #else |
107 | return -static_cast<float>(a); |
108 | #endif |
109 | } |
110 | |
111 | inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { |
112 | a = a + b; |
113 | return a; |
114 | } |
115 | |
116 | inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { |
117 | a = a - b; |
118 | return a; |
119 | } |
120 | |
121 | inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { |
122 | a = a * b; |
123 | return a; |
124 | } |
125 | |
126 | inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { |
127 | a = a / b; |
128 | return a; |
129 | } |
130 | |
131 | /// Arithmetic with floats |
132 | |
133 | inline C10_HOST_DEVICE float operator+(Half a, float b) { |
134 | return static_cast<float>(a) + b; |
135 | } |
136 | inline C10_HOST_DEVICE float operator-(Half a, float b) { |
137 | return static_cast<float>(a) - b; |
138 | } |
139 | inline C10_HOST_DEVICE float operator*(Half a, float b) { |
140 | return static_cast<float>(a) * b; |
141 | } |
142 | inline C10_HOST_DEVICE float operator/(Half a, float b) |
143 | __ubsan_ignore_float_divide_by_zero__ { |
144 | return static_cast<float>(a) / b; |
145 | } |
146 | |
147 | inline C10_HOST_DEVICE float operator+(float a, Half b) { |
148 | return a + static_cast<float>(b); |
149 | } |
150 | inline C10_HOST_DEVICE float operator-(float a, Half b) { |
151 | return a - static_cast<float>(b); |
152 | } |
153 | inline C10_HOST_DEVICE float operator*(float a, Half b) { |
154 | return a * static_cast<float>(b); |
155 | } |
156 | inline C10_HOST_DEVICE float operator/(float a, Half b) |
157 | __ubsan_ignore_float_divide_by_zero__ { |
158 | return a / static_cast<float>(b); |
159 | } |
160 | |
161 | inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) { |
162 | return a += static_cast<float>(b); |
163 | } |
164 | inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) { |
165 | return a -= static_cast<float>(b); |
166 | } |
167 | inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) { |
168 | return a *= static_cast<float>(b); |
169 | } |
170 | inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) { |
171 | return a /= static_cast<float>(b); |
172 | } |
173 | |
174 | /// Arithmetic with doubles |
175 | |
176 | inline C10_HOST_DEVICE double operator+(Half a, double b) { |
177 | return static_cast<double>(a) + b; |
178 | } |
179 | inline C10_HOST_DEVICE double operator-(Half a, double b) { |
180 | return static_cast<double>(a) - b; |
181 | } |
182 | inline C10_HOST_DEVICE double operator*(Half a, double b) { |
183 | return static_cast<double>(a) * b; |
184 | } |
185 | inline C10_HOST_DEVICE double operator/(Half a, double b) |
186 | __ubsan_ignore_float_divide_by_zero__ { |
187 | return static_cast<double>(a) / b; |
188 | } |
189 | |
190 | inline C10_HOST_DEVICE double operator+(double a, Half b) { |
191 | return a + static_cast<double>(b); |
192 | } |
193 | inline C10_HOST_DEVICE double operator-(double a, Half b) { |
194 | return a - static_cast<double>(b); |
195 | } |
196 | inline C10_HOST_DEVICE double operator*(double a, Half b) { |
197 | return a * static_cast<double>(b); |
198 | } |
199 | inline C10_HOST_DEVICE double operator/(double a, Half b) |
200 | __ubsan_ignore_float_divide_by_zero__ { |
201 | return a / static_cast<double>(b); |
202 | } |
203 | |
204 | /// Arithmetic with ints |
205 | |
206 | inline C10_HOST_DEVICE Half operator+(Half a, int b) { |
207 | return a + static_cast<Half>(b); |
208 | } |
209 | inline C10_HOST_DEVICE Half operator-(Half a, int b) { |
210 | return a - static_cast<Half>(b); |
211 | } |
212 | inline C10_HOST_DEVICE Half operator*(Half a, int b) { |
213 | return a * static_cast<Half>(b); |
214 | } |
215 | inline C10_HOST_DEVICE Half operator/(Half a, int b) { |
216 | return a / static_cast<Half>(b); |
217 | } |
218 | |
219 | inline C10_HOST_DEVICE Half operator+(int a, Half b) { |
220 | return static_cast<Half>(a) + b; |
221 | } |
222 | inline C10_HOST_DEVICE Half operator-(int a, Half b) { |
223 | return static_cast<Half>(a) - b; |
224 | } |
225 | inline C10_HOST_DEVICE Half operator*(int a, Half b) { |
226 | return static_cast<Half>(a) * b; |
227 | } |
228 | inline C10_HOST_DEVICE Half operator/(int a, Half b) { |
229 | return static_cast<Half>(a) / b; |
230 | } |
231 | |
/// Arithmetic with int64_t
233 | |
234 | inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) { |
235 | return a + static_cast<Half>(b); |
236 | } |
237 | inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) { |
238 | return a - static_cast<Half>(b); |
239 | } |
240 | inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) { |
241 | return a * static_cast<Half>(b); |
242 | } |
243 | inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) { |
244 | return a / static_cast<Half>(b); |
245 | } |
246 | |
247 | inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) { |
248 | return static_cast<Half>(a) + b; |
249 | } |
250 | inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) { |
251 | return static_cast<Half>(a) - b; |
252 | } |
253 | inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) { |
254 | return static_cast<Half>(a) * b; |
255 | } |
256 | inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) { |
257 | return static_cast<Half>(a) / b; |
258 | } |
259 | |
260 | /// NOTE: we do not define comparisons directly and instead rely on the implicit |
261 | /// conversion from c10::Half to float. |
262 | |
263 | } // namespace c10 |
264 | |
265 | namespace std { |
266 | |
267 | template <> |
268 | class numeric_limits<c10::Half> { |
269 | public: |
270 | static constexpr bool is_specialized = true; |
271 | static constexpr bool is_signed = true; |
272 | static constexpr bool is_integer = false; |
273 | static constexpr bool is_exact = false; |
274 | static constexpr bool has_infinity = true; |
275 | static constexpr bool has_quiet_NaN = true; |
276 | static constexpr bool has_signaling_NaN = true; |
277 | static constexpr auto has_denorm = numeric_limits<float>::has_denorm; |
278 | static constexpr auto has_denorm_loss = |
279 | numeric_limits<float>::has_denorm_loss; |
280 | static constexpr auto round_style = numeric_limits<float>::round_style; |
281 | static constexpr bool is_iec559 = true; |
282 | static constexpr bool is_bounded = true; |
283 | static constexpr bool is_modulo = false; |
284 | static constexpr int digits = 11; |
285 | static constexpr int digits10 = 3; |
286 | static constexpr int max_digits10 = 5; |
287 | static constexpr int radix = 2; |
288 | static constexpr int min_exponent = -13; |
289 | static constexpr int min_exponent10 = -4; |
290 | static constexpr int max_exponent = 16; |
291 | static constexpr int max_exponent10 = 4; |
292 | static constexpr auto traps = numeric_limits<float>::traps; |
293 | static constexpr auto tinyness_before = |
294 | numeric_limits<float>::tinyness_before; |
295 | static constexpr c10::Half min() { |
296 | return c10::Half(0x0400, c10::Half::from_bits()); |
297 | } |
298 | static constexpr c10::Half lowest() { |
299 | return c10::Half(0xFBFF, c10::Half::from_bits()); |
300 | } |
301 | static constexpr c10::Half max() { |
302 | return c10::Half(0x7BFF, c10::Half::from_bits()); |
303 | } |
304 | static constexpr c10::Half epsilon() { |
305 | return c10::Half(0x1400, c10::Half::from_bits()); |
306 | } |
307 | static constexpr c10::Half round_error() { |
308 | return c10::Half(0x3800, c10::Half::from_bits()); |
309 | } |
310 | static constexpr c10::Half infinity() { |
311 | return c10::Half(0x7C00, c10::Half::from_bits()); |
312 | } |
313 | static constexpr c10::Half quiet_NaN() { |
314 | return c10::Half(0x7E00, c10::Half::from_bits()); |
315 | } |
316 | static constexpr c10::Half signaling_NaN() { |
317 | return c10::Half(0x7D00, c10::Half::from_bits()); |
318 | } |
319 | static constexpr c10::Half denorm_min() { |
320 | return c10::Half(0x0001, c10::Half::from_bits()); |
321 | } |
322 | }; |
323 | |
324 | } // namespace std |
325 | |
326 | C10_CLANG_DIAGNOSTIC_POP() |
327 | |