1#pragma once
2
3#include <c10/macros/Macros.h>
4#include <cstring>
5#include <limits>
6
7#ifdef __CUDACC__
8#include <cuda_fp16.h>
9#endif
10
11#ifdef __HIPCC__
12#include <hip/hip_fp16.h>
13#endif
14
15#if defined(CL_SYCL_LANGUAGE_VERSION)
16#include <CL/sycl.hpp> // for SYCL 1.2.1
17#elif defined(SYCL_LANGUAGE_VERSION)
18#include <sycl/sycl.hpp> // for SYCL 2020
19#endif
20
21C10_CLANG_DIAGNOSTIC_PUSH()
22#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
23C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
24#endif
25
26namespace c10 {
27
28/// Constructors
29
/// Builds a Half from a single-precision value, storing only the resulting
/// 16-bit pattern in `x`. The conversion routine is picked at compile time:
///  - CUDA / HIP device code: `__float2half`, reinterpreted as a short;
///  - SYCL device code: `sycl::half`, bit-cast to `uint16_t`;
///  - every other compilation (including host code):
///    `detail::fp16_ieee_from_fp32_value`, defined elsewhere in c10.
inline C10_HOST_DEVICE Half::Half(float value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  x = __half_as_short(__float2half(value));
#elif defined(__SYCL_DEVICE_ONLY__)
  x = sycl::bit_cast<uint16_t>(sycl::half(value));
#else
  x = detail::fp16_ieee_from_fp32_value(value);
#endif
}
41
42/// Implicit conversions
43
/// Widens the stored 16-bit pattern to a float. CUDA/HIP device code and SYCL
/// device code use the vendor half type's conversion; all other compilations
/// use `detail::fp16_ieee_to_fp32_value`.
inline C10_HOST_DEVICE Half::operator float() const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  // View the raw bits in place as the vendor __half type, then convert.
  const __half& bits_as_half = *reinterpret_cast<const __half*>(&x);
  return __half2float(bits_as_half);
#elif defined(__SYCL_DEVICE_ONLY__)
  const sycl::half bits_as_half = sycl::bit_cast<sycl::half>(x);
  return static_cast<float>(bits_as_half);
#else
  return detail::fp16_ieee_to_fp32_value(x);
#endif
}
53
#if defined(__CUDACC__) || defined(__HIPCC__)
// Interop with the CUDA/HIP `__half` type. Both directions copy the 16-bit
// representation verbatim; no numeric conversion is performed.
inline C10_HOST_DEVICE Half::Half(const __half& value)
    : x(*reinterpret_cast<const unsigned short*>(&value)) {}

inline C10_HOST_DEVICE Half::operator __half() const {
  const __half* const as_vendor_half = reinterpret_cast<const __half*>(&x);
  return *as_vendor_half;
}
#endif
62
#ifdef SYCL_LANGUAGE_VERSION
// Interop with `sycl::half`. Both directions copy the 16-bit representation
// verbatim; no numeric conversion is performed.
inline C10_HOST_DEVICE Half::Half(const sycl::half& value)
    : x(*reinterpret_cast<const unsigned short*>(&value)) {}

inline C10_HOST_DEVICE Half::operator sycl::half() const {
  const sycl::half* const as_sycl_half =
      reinterpret_cast<const sycl::half*>(&x);
  return *as_sycl_half;
}
#endif
71
72// CUDA intrinsics
73
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \
    (defined(__clang__) && defined(__CUDA__))
/// `__ldg` overload for Half. It views `ptr` as a `const __half*` and forwards
/// to the vendor `__ldg(const __half*)`, which argument-dependent lookup finds
/// in the global namespace. The returned `__half` is then wrapped back into a
/// Half. Declared only when __CUDA_ARCH__ is at least 350, or whenever clang
/// compiles CUDA.
inline __device__ Half __ldg(const Half* ptr) {
  const __half* const raw = reinterpret_cast<const __half*>(ptr);
  return __ldg(raw);
}
#endif
80
81/// Arithmetic
82
// Half (op) Half: both operands are widened to float, the operation is done
// in single precision, and the result is narrowed back to Half through the
// Half(float) constructor.

inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return Half(lhs + rhs);
}

inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return Half(lhs - rhs);
}

inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return Half(lhs * rhs);
}

// Division by zero is allowed to produce inf/NaN, so UBSan's float
// divide-by-zero check is disabled for this function.
inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b)
    __ubsan_ignore_float_divide_by_zero__ {
  const float numerator = static_cast<float>(a);
  const float denominator = static_cast<float>(b);
  return Half(numerator / denominator);
}
99
/// Unary negation.
/// - CUDA sm_53+ and HIP device code use the native `__hneg` intrinsic.
/// - SYCL device code negates a bit-identical `sycl::half`.
/// - All other compilations negate in float and convert back through
///   Half(float).
inline C10_HOST_DEVICE Half operator-(const Half& a) {
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
    defined(__HIP_DEVICE_COMPILE__)
  // `a` reaches __hneg through Half's implicit conversion to __half.
  return __hneg(a);
#elif defined(__SYCL_DEVICE_ONLY__)
  const sycl::half native = sycl::bit_cast<sycl::half>(a);
  return -native;
#else
  const float widened = static_cast<float>(a);
  return -widened;
#endif
}
110
// In-place Half arithmetic. Each operator evaluates the matching binary Half
// operator, stores that Half result into `a`, and returns a reference to `a`.

inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) {
  const Half updated = a + b;
  a = updated;
  return a;
}

inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) {
  const Half updated = a - b;
  a = updated;
  return a;
}

inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) {
  const Half updated = a * b;
  a = updated;
  return a;
}

inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) {
  const Half updated = a / b;
  a = updated;
  return a;
}
130
131/// Arithmetic with floats
132
// Mixed Half/float arithmetic. The Half operand is widened to float and the
// result stays a float; it is never rounded back to Half.

// Half on the left.
inline C10_HOST_DEVICE float operator+(Half a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs + b;
}
inline C10_HOST_DEVICE float operator-(Half a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs - b;
}
inline C10_HOST_DEVICE float operator*(Half a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs * b;
}
inline C10_HOST_DEVICE float operator/(Half a, float b)
    __ubsan_ignore_float_divide_by_zero__ {
  const float lhs = static_cast<float>(a);
  return lhs / b;
}

// Half on the right.
inline C10_HOST_DEVICE float operator+(float a, Half b) {
  const float rhs = static_cast<float>(b);
  return a + rhs;
}
inline C10_HOST_DEVICE float operator-(float a, Half b) {
  const float rhs = static_cast<float>(b);
  return a - rhs;
}
inline C10_HOST_DEVICE float operator*(float a, Half b) {
  const float rhs = static_cast<float>(b);
  return a * rhs;
}
inline C10_HOST_DEVICE float operator/(float a, Half b)
    __ubsan_ignore_float_divide_by_zero__ {
  const float rhs = static_cast<float>(b);
  return a / rhs;
}

// Accumulating a Half into a float; each returns a reference to `a`.
inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) {
  a += static_cast<float>(b);
  return a;
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) {
  a -= static_cast<float>(b);
  return a;
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) {
  a *= static_cast<float>(b);
  return a;
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) {
  a /= static_cast<float>(b);
  return a;
}
173
174/// Arithmetic with doubles
175
// Mixed Half/double arithmetic. The Half operand is widened to double and the
// result stays a double.

// Half on the left.
inline C10_HOST_DEVICE double operator+(Half a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs + b;
}
inline C10_HOST_DEVICE double operator-(Half a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs - b;
}
inline C10_HOST_DEVICE double operator*(Half a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs * b;
}
inline C10_HOST_DEVICE double operator/(Half a, double b)
    __ubsan_ignore_float_divide_by_zero__ {
  const double lhs = static_cast<double>(a);
  return lhs / b;
}

// Half on the right.
inline C10_HOST_DEVICE double operator+(double a, Half b) {
  const double rhs = static_cast<double>(b);
  return a + rhs;
}
inline C10_HOST_DEVICE double operator-(double a, Half b) {
  const double rhs = static_cast<double>(b);
  return a - rhs;
}
inline C10_HOST_DEVICE double operator*(double a, Half b) {
  const double rhs = static_cast<double>(b);
  return a * rhs;
}
inline C10_HOST_DEVICE double operator/(double a, Half b)
    __ubsan_ignore_float_divide_by_zero__ {
  const double rhs = static_cast<double>(b);
  return a / rhs;
}
203
204/// Arithmetic with ints
205
// Mixed Half/int arithmetic. The int operand is first converted to Half, and
// the Half (op) Half operators above do the work. Integers that Half cannot
// represent exactly therefore lose precision before the operation; overflow
// handling is that of the Half(float) constructor.

// Half on the left.
inline C10_HOST_DEVICE Half operator+(Half a, int b) {
  const Half rhs = static_cast<Half>(b);
  return a + rhs;
}
inline C10_HOST_DEVICE Half operator-(Half a, int b) {
  const Half rhs = static_cast<Half>(b);
  return a - rhs;
}
inline C10_HOST_DEVICE Half operator*(Half a, int b) {
  const Half rhs = static_cast<Half>(b);
  return a * rhs;
}
inline C10_HOST_DEVICE Half operator/(Half a, int b) {
  const Half rhs = static_cast<Half>(b);
  return a / rhs;
}

// Half on the right.
inline C10_HOST_DEVICE Half operator+(int a, Half b) {
  const Half lhs = static_cast<Half>(a);
  return lhs + b;
}
inline C10_HOST_DEVICE Half operator-(int a, Half b) {
  const Half lhs = static_cast<Half>(a);
  return lhs - b;
}
inline C10_HOST_DEVICE Half operator*(int a, Half b) {
  const Half lhs = static_cast<Half>(a);
  return lhs * b;
}
inline C10_HOST_DEVICE Half operator/(int a, Half b) {
  const Half lhs = static_cast<Half>(a);
  return lhs / b;
}
231
/// Arithmetic with int64_t
233
// Mixed Half/int64_t arithmetic. The integer operand is converted to Half
// (the implicit int-to-float narrowing involved is why this file silences
// -Wimplicit-int-float-conversion), and the Half (op) Half operators above do
// the work.

// Half on the left.
inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) {
  const Half rhs = static_cast<Half>(b);
  return a + rhs;
}
inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) {
  const Half rhs = static_cast<Half>(b);
  return a - rhs;
}
inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) {
  const Half rhs = static_cast<Half>(b);
  return a * rhs;
}
inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) {
  const Half rhs = static_cast<Half>(b);
  return a / rhs;
}

// Half on the right.
inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) {
  const Half lhs = static_cast<Half>(a);
  return lhs + b;
}
inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) {
  const Half lhs = static_cast<Half>(a);
  return lhs - b;
}
inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) {
  const Half lhs = static_cast<Half>(a);
  return lhs * b;
}
inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) {
  const Half lhs = static_cast<Half>(a);
  return lhs / b;
}
259
260/// NOTE: we do not define comparisons directly and instead rely on the implicit
261/// conversion from c10::Half to float.
262
263} // namespace c10
264
265namespace std {
266
// numeric_limits for c10::Half. The constants below describe an IEEE 754
// binary16 value: 11 significant bits (10 stored plus the implicit leading
// one) and an exponent range of 2^-14 .. 2^15 for normal numbers. Rounding,
// denormal and trapping properties are forwarded from `float`.
template <>
class numeric_limits<c10::Half> {
 public:
  // Classification.
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool is_iec559 = true;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;

  // Special values that can be represented.
  static constexpr bool has_infinity = true;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = true;

  // Properties forwarded from float.
  static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
  static constexpr auto has_denorm_loss =
      numeric_limits<float>::has_denorm_loss;
  static constexpr auto round_style = numeric_limits<float>::round_style;
  static constexpr auto traps = numeric_limits<float>::traps;
  static constexpr auto tinyness_before =
      numeric_limits<float>::tinyness_before;

  // Precision and exponent range.
  static constexpr int radix = 2;
  static constexpr int digits = 11; // 10 stored mantissa bits + implicit one
  static constexpr int digits10 = 3;
  static constexpr int max_digits10 = 5;
  static constexpr int min_exponent = -13; // smallest normal is 2^-14
  static constexpr int min_exponent10 = -4;
  static constexpr int max_exponent = 16; // largest finite is below 2^16
  static constexpr int max_exponent10 = 4;

  // Representative values, built directly from their 16-bit patterns.
  static constexpr c10::Half min() { // 2^-14, smallest positive normal
    return c10::Half(0x0400, c10::Half::from_bits());
  }
  static constexpr c10::Half lowest() { // -65504
    return c10::Half(0xFBFF, c10::Half::from_bits());
  }
  static constexpr c10::Half max() { // 65504
    return c10::Half(0x7BFF, c10::Half::from_bits());
  }
  static constexpr c10::Half epsilon() { // 2^-10
    return c10::Half(0x1400, c10::Half::from_bits());
  }
  static constexpr c10::Half round_error() { // 0.5
    return c10::Half(0x3800, c10::Half::from_bits());
  }
  static constexpr c10::Half infinity() { // +inf
    return c10::Half(0x7C00, c10::Half::from_bits());
  }
  static constexpr c10::Half quiet_NaN() { // quiet bit (0x0200) set
    return c10::Half(0x7E00, c10::Half::from_bits());
  }
  static constexpr c10::Half signaling_NaN() { // quiet bit clear, payload != 0
    return c10::Half(0x7D00, c10::Half::from_bits());
  }
  static constexpr c10::Half denorm_min() { // 2^-24, smallest subnormal
    return c10::Half(0x0001, c10::Half::from_bits());
  }
};
323
324} // namespace std
325
326C10_CLANG_DIAGNOSTIC_POP()
327