1 | #pragma once |
2 | |
3 | /// Defines the Half type (half-precision floating-point) including conversions |
4 | /// to standard C types and basic arithmetic operations. Note that arithmetic |
5 | /// operations are implemented by converting to floating point and |
6 | /// performing the operation in float32, instead of using CUDA half intrinsics. |
7 | /// Most uses of this type within ATen are memory bound, including the |
8 | /// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. |
9 | /// If you are writing a compute bound kernel, you can use the CUDA half |
10 | /// intrinsics directly on the Half type from device code. |
11 | |
12 | #include <c10/macros/Macros.h> |
13 | #include <c10/util/C++17.h> |
14 | #include <c10/util/TypeSafeSignMath.h> |
15 | #include <c10/util/complex.h> |
16 | #include <type_traits> |
17 | |
18 | #if defined(__cplusplus) && (__cplusplus >= 201103L) |
19 | #include <cmath> |
20 | #include <cstdint> |
21 | #elif !defined(__OPENCL_VERSION__) |
22 | #include <math.h> |
23 | #include <stdint.h> |
24 | #endif |
25 | |
26 | #ifdef _MSC_VER |
27 | #include <intrin.h> |
28 | #endif |
29 | |
30 | #include <complex> |
31 | #include <cstdint> |
32 | #include <cstring> |
33 | #include <iosfwd> |
34 | #include <limits> |
35 | #include <sstream> |
36 | #include <stdexcept> |
37 | #include <string> |
38 | #include <utility> |
39 | |
40 | #ifdef __CUDACC__ |
41 | #include <cuda_fp16.h> |
42 | #endif |
43 | |
44 | #ifdef __HIPCC__ |
45 | #include <hip/hip_fp16.h> |
46 | #endif |
47 | |
48 | #if defined(CL_SYCL_LANGUAGE_VERSION) |
49 | #include <CL/sycl.hpp> // for SYCL 1.2.1 |
50 | #elif defined(SYCL_LANGUAGE_VERSION) |
51 | #include <sycl/sycl.hpp> // for SYCL 2020 |
52 | #endif |
53 | |
54 | // Standard check for compiling CUDA with clang |
55 | #if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__) |
56 | #define C10_DEVICE_HOST_FUNCTION __device__ __host__ |
57 | #else |
58 | #define C10_DEVICE_HOST_FUNCTION |
59 | #endif |
60 | |
61 | #include <typeinfo> // operator typeid |
62 | |
63 | namespace c10 { |
64 | |
65 | namespace detail { |
66 | |
67 | C10_DEVICE_HOST_FUNCTION inline float fp32_from_bits(uint32_t w) { |
68 | #if defined(__OPENCL_VERSION__) |
69 | return as_float(w); |
70 | #elif defined(__CUDA_ARCH__) |
71 | return __uint_as_float((unsigned int)w); |
72 | #elif defined(__INTEL_COMPILER) |
73 | return _castu32_f32(w); |
74 | #else |
75 | union { |
76 | uint32_t as_bits; |
77 | float as_value; |
78 | } fp32 = {w}; |
79 | return fp32.as_value; |
80 | #endif |
81 | } |
82 | |
83 | C10_DEVICE_HOST_FUNCTION inline uint32_t fp32_to_bits(float f) { |
84 | #if defined(__OPENCL_VERSION__) |
85 | return as_uint(f); |
86 | #elif defined(__CUDA_ARCH__) |
87 | return (uint32_t)__float_as_uint(f); |
88 | #elif defined(__INTEL_COMPILER) |
89 | return _castf32_u32(f); |
90 | #else |
91 | union { |
92 | float as_value; |
93 | uint32_t as_bits; |
94 | } fp32 = {f}; |
95 | return fp32.as_bits; |
96 | #endif |
97 | } |
98 | |
/*
 * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
 * representation, to a 32-bit floating-point number in IEEE single-precision
 * format, in bit representation.
 *
 * @note The implementation doesn't use any floating-point operations.
 */
inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
  /*
   * Extend the half-precision floating-point number to 32 bits and shift to the
   * upper part of the 32-bit word:
   *      +---+-----+------------+-------------------+
   *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
   *      +---+-----+------------+-------------------+
   * Bits  31  26-30    16-25            0-15
   *
   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa,
   * 0 - zero bits.
   */
  const uint32_t w = (uint32_t)h << 16;
  /*
   * Extract the sign of the input number into the high bit of the 32-bit word:
   *
   *      +---+----------------------------------+
   *      | S |0000000 00000000 00000000 00000000|
   *      +---+----------------------------------+
   * Bits  31                 0-30
   */
  const uint32_t sign = w & UINT32_C(0x80000000);
  /*
   * Extract mantissa and biased exponent of the input number into the bits
   * 0-30 of the 32-bit word:
   *
   *      +---+-----+------------+-------------------+
   *      | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
   *      +---+-----+------------+-------------------+
   * Bits  31  26-30    16-25            0-15
   */
  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
  /*
   * Renorm shift is the number of bits to shift mantissa left to make the
   * half-precision number normalized. If the initial number is normalized, some
   * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case
   * renorm_shift == 0. If the number is denormalized, renorm_shift > 0. Note
   * that if we shift denormalized nonsign by renorm_shift, the unit bit of
   * mantissa will shift into exponent, turning the biased exponent into 1, and
   * making mantissa normalized (i.e. without leading 1).
   *
   * __builtin_clz(0) and _BitScanReverse on 0 are undefined, and nonsign is 0
   * for +0.0h / -0.0h. Bits 0-15 of nonsign are always zero, so OR-ing in
   * bit 0 leaves the leading-zero count unchanged for every non-zero input
   * while keeping the zero case well-defined (clz == 31, renorm_shift == 26).
   * The zero result is discarded through zero_mask below anyway.
   */
  const uint32_t clz_input = nonsign | UINT32_C(1);
#ifdef _MSC_VER
  unsigned long nonsign_bsr;
  _BitScanReverse(&nonsign_bsr, (unsigned long)clz_input);
  uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
  uint32_t renorm_shift = (uint32_t)__builtin_clz(clz_input);
#endif
  renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
  /*
   * Iff half-precision number has exponent of 31 (0x1F, all exponent bits
   * set), the addition overflows it into bit 31, and the subsequent arithmetic
   * shift turns the high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the
   * half-precision number had exponent of 31 (i.e. was NaN or infinity),
   * 0x00000000 otherwise.
   */
  const int32_t inf_nan_mask =
      ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
  /*
   * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
   * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
   * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
   * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h),
   * 0x00000000 otherwise.
   */
  const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
  /*
   * 1. Shift nonsign left by renorm_shift to normalize it (if the input
   *    was denormal).
   * 2. Shift nonsign right by 3 so the exponent (5 bits originally)
   *    becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
   *    bits of the 23-bit mantissa of IEEE single-precision number.
   * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the
   *    difference in exponent bias (0x7F for single-precision number less 0xF
   *    for half-precision number).
   * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
   *    account for renormalization. As renorm_shift is less than 0x70, this
   *    can be combined with step 3.
   * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
   *    input was NaN or infinity.
   * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
   *    into zero if the input was zero.
   * 7. Combine with the sign of the input number.
   */
  return sign |
      ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
        inf_nan_mask) &
       ~zero_mask);
}
194 | |
195 | /* |
196 | * Convert a 16-bit floating-point number in IEEE half-precision format, in bit |
197 | * representation, to a 32-bit floating-point number in IEEE single-precision |
198 | * format. |
199 | * |
200 | * @note The implementation relies on IEEE-like (no assumption about rounding |
201 | * mode and no operations on denormals) floating-point operations and bitcasts |
202 | * between integer and floating-point variables. |
203 | */ |
204 | inline float fp16_ieee_to_fp32_value(uint16_t h) { |
205 | /* |
206 | * Extend the half-precision floating-point number to 32 bits and shift to the |
207 | * upper part of the 32-bit word: |
208 | * +---+-----+------------+-------------------+ |
209 | * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| |
210 | * +---+-----+------------+-------------------+ |
211 | * Bits 31 26-30 16-25 0-15 |
212 | * |
213 | * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 |
214 | * - zero bits. |
215 | */ |
216 | const uint32_t w = (uint32_t)h << 16; |
217 | /* |
218 | * Extract the sign of the input number into the high bit of the 32-bit word: |
219 | * |
220 | * +---+----------------------------------+ |
221 | * | S |0000000 00000000 00000000 00000000| |
222 | * +---+----------------------------------+ |
223 | * Bits 31 0-31 |
224 | */ |
225 | const uint32_t sign = w & UINT32_C(0x80000000); |
226 | /* |
227 | * Extract mantissa and biased exponent of the input number into the high bits |
228 | * of the 32-bit word: |
229 | * |
230 | * +-----+------------+---------------------+ |
231 | * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| |
232 | * +-----+------------+---------------------+ |
233 | * Bits 27-31 17-26 0-16 |
234 | */ |
235 | const uint32_t two_w = w + w; |
236 | |
237 | /* |
238 | * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become |
239 | * mantissa and exponent of a single-precision floating-point number: |
240 | * |
241 | * S|Exponent | Mantissa |
242 | * +-+---+-----+------------+----------------+ |
243 | * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| |
244 | * +-+---+-----+------------+----------------+ |
245 | * Bits | 23-31 | 0-22 |
246 | * |
247 | * Next, there are some adjustments to the exponent: |
248 | * - The exponent needs to be corrected by the difference in exponent bias |
249 | * between single-precision and half-precision formats (0x7F - 0xF = 0x70) |
250 | * - Inf and NaN values in the inputs should become Inf and NaN values after |
251 | * conversion to the single-precision number. Therefore, if the biased |
252 | * exponent of the half-precision input was 0x1F (max possible value), the |
253 | * biased exponent of the single-precision output must be 0xFF (max possible |
254 | * value). We do this correction in two steps: |
255 | * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset |
256 | * below) rather than by 0x70 suggested by the difference in the exponent bias |
257 | * (see above). |
258 | * - Then we multiply the single-precision result of exponent adjustment by |
259 | * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the |
260 | * necessary exponent adjustment by 0x70 due to difference in exponent bias. |
261 | * The floating-point multiplication hardware would ensure than Inf and |
262 | * NaN would retain their value on at least partially IEEE754-compliant |
263 | * implementations. |
264 | * |
265 | * Note that the above operations do not handle denormal inputs (where biased |
266 | * exponent == 0). However, they also do not operate on denormal inputs, and |
267 | * do not produce denormal results. |
268 | */ |
269 | constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; |
270 | // const float exp_scale = 0x1.0p-112f; |
271 | constexpr uint32_t scale_bits = (uint32_t)15 << 23; |
272 | float exp_scale_val; |
273 | std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); |
274 | const float exp_scale = exp_scale_val; |
275 | const float normalized_value = |
276 | fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; |
277 | |
278 | /* |
279 | * Convert denormalized half-precision inputs into single-precision results |
280 | * (always normalized). Zero inputs are also handled here. |
281 | * |
282 | * In a denormalized number the biased exponent is zero, and mantissa has |
283 | * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word. |
284 | * |
285 | * zeros | mantissa |
286 | * +---------------------------+------------+ |
287 | * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| |
288 | * +---------------------------+------------+ |
289 | * Bits 10-31 0-9 |
290 | * |
291 | * Now, remember that denormalized half-precision numbers are represented as: |
292 | * FP16 = mantissa * 2**(-24). |
293 | * The trick is to construct a normalized single-precision number with the |
294 | * same mantissa and thehalf-precision input and with an exponent which would |
295 | * scale the corresponding mantissa bits to 2**(-24). A normalized |
296 | * single-precision floating-point number is represented as: FP32 = (1 + |
297 | * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased |
298 | * exponent is 126, a unit change in the mantissa of the input denormalized |
299 | * half-precision number causes a change of the constructud single-precision |
300 | * number by 2**(-24), i.e. the same amount. |
301 | * |
302 | * The last step is to adjust the bias of the constructed single-precision |
303 | * number. When the input half-precision number is zero, the constructed |
304 | * single-precision number has the value of FP32 = 1 * 2**(126 - 127) = |
305 | * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed |
306 | * single-precision number to get the numerical equivalent of the input |
307 | * half-precision number. |
308 | */ |
309 | constexpr uint32_t magic_mask = UINT32_C(126) << 23; |
310 | constexpr float magic_bias = 0.5f; |
311 | const float denormalized_value = |
312 | fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; |
313 | |
314 | /* |
315 | * - Choose either results of conversion of input as a normalized number, or |
316 | * as a denormalized number, depending on the input exponent. The variable |
317 | * two_w contains input exponent in bits 27-31, therefore if its smaller than |
318 | * 2**27, the input is either a denormal number, or zero. |
319 | * - Combine the result of conversion of exponent and mantissa with the sign |
320 | * of the input number. |
321 | */ |
322 | constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; |
323 | const uint32_t result = sign | |
324 | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) |
325 | : fp32_to_bits(normalized_value)); |
326 | return fp32_from_bits(result); |
327 | } |
328 | |
329 | /* |
330 | * Convert a 32-bit floating-point number in IEEE single-precision format to a |
331 | * 16-bit floating-point number in IEEE half-precision format, in bit |
332 | * representation. |
333 | * |
334 | * @note The implementation relies on IEEE-like (no assumption about rounding |
335 | * mode and no operations on denormals) floating-point operations and bitcasts |
336 | * between integer and floating-point variables. |
337 | */ |
338 | inline uint16_t fp16_ieee_from_fp32_value(float f) { |
339 | // const float scale_to_inf = 0x1.0p+112f; |
340 | // const float scale_to_zero = 0x1.0p-110f; |
341 | constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; |
342 | constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; |
343 | float scale_to_inf_val, scale_to_zero_val; |
344 | std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); |
345 | std::memcpy( |
346 | &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); |
347 | const float scale_to_inf = scale_to_inf_val; |
348 | const float scale_to_zero = scale_to_zero_val; |
349 | |
350 | #if defined(_MSC_VER) && _MSC_VER == 1916 |
351 | float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero; |
352 | #else |
353 | float base = (fabsf(f) * scale_to_inf) * scale_to_zero; |
354 | #endif |
355 | |
356 | const uint32_t w = fp32_to_bits(f); |
357 | const uint32_t shl1_w = w + w; |
358 | const uint32_t sign = w & UINT32_C(0x80000000); |
359 | uint32_t bias = shl1_w & UINT32_C(0xFF000000); |
360 | if (bias < UINT32_C(0x71000000)) { |
361 | bias = UINT32_C(0x71000000); |
362 | } |
363 | |
364 | base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; |
365 | const uint32_t bits = fp32_to_bits(base); |
366 | const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); |
367 | const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); |
368 | const uint32_t nonsign = exp_bits + mantissa_bits; |
369 | return static_cast<uint16_t>( |
370 | (sign >> 16) | |
371 | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); |
372 | } |
373 | |
374 | } // namespace detail |
375 | |
376 | struct alignas(2) Half { |
377 | unsigned short x; |
378 | |
379 | struct from_bits_t {}; |
380 | C10_HOST_DEVICE static constexpr from_bits_t from_bits() { |
381 | return from_bits_t(); |
382 | } |
383 | |
384 | // HIP wants __host__ __device__ tag, CUDA does not |
385 | #if defined(USE_ROCM) |
386 | C10_HOST_DEVICE Half() = default; |
387 | #else |
388 | Half() = default; |
389 | #endif |
390 | |
391 | constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits){}; |
392 | inline C10_HOST_DEVICE Half(float value); |
393 | inline C10_HOST_DEVICE operator float() const; |
394 | |
395 | #if defined(__CUDACC__) || defined(__HIPCC__) |
396 | inline C10_HOST_DEVICE Half(const __half& value); |
397 | inline C10_HOST_DEVICE operator __half() const; |
398 | #endif |
399 | #ifdef SYCL_LANGUAGE_VERSION |
400 | inline C10_HOST_DEVICE Half(const sycl::half& value); |
401 | inline C10_HOST_DEVICE operator sycl::half() const; |
402 | #endif |
403 | }; |
404 | |
405 | // TODO : move to complex.h |
406 | template <> |
407 | struct alignas(4) complex<Half> { |
408 | Half real_; |
409 | Half imag_; |
410 | |
411 | // Constructors |
412 | complex() = default; |
413 | // Half constructor is not constexpr so the following constructor can't |
414 | // be constexpr |
415 | C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) |
416 | : real_(real), imag_(imag) {} |
417 | C10_HOST_DEVICE inline complex(const c10::complex<float>& value) |
418 | : real_(value.real()), imag_(value.imag()) {} |
419 | |
420 | // Conversion operator |
421 | inline C10_HOST_DEVICE operator c10::complex<float>() const { |
422 | return {real_, imag_}; |
423 | } |
424 | |
425 | constexpr C10_HOST_DEVICE Half real() const { |
426 | return real_; |
427 | } |
428 | constexpr C10_HOST_DEVICE Half imag() const { |
429 | return imag_; |
430 | } |
431 | |
432 | C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) { |
433 | real_ = static_cast<float>(real_) + static_cast<float>(other.real_); |
434 | imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_); |
435 | return *this; |
436 | } |
437 | |
438 | C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) { |
439 | real_ = static_cast<float>(real_) - static_cast<float>(other.real_); |
440 | imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_); |
441 | return *this; |
442 | } |
443 | |
444 | C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) { |
445 | auto a = static_cast<float>(real_); |
446 | auto b = static_cast<float>(imag_); |
447 | auto c = static_cast<float>(other.real()); |
448 | auto d = static_cast<float>(other.imag()); |
449 | real_ = a * c - b * d; |
450 | imag_ = a * d + b * c; |
451 | return *this; |
452 | } |
453 | }; |
454 | |
// Some versions of MSVC fail to build the overflow checks below with the
// following diagnostics:
// C4146: unary minus operator applied to unsigned type, result still unsigned
// C4804: unsafe use of type 'bool' in operation
// C4018: signed/unsigned mismatch
// These warnings are therefore disabled for this section.
459 | #ifdef _MSC_VER |
460 | #pragma warning(push) |
461 | #pragma warning(disable : 4146) |
462 | #pragma warning(disable : 4804) |
463 | #pragma warning(disable : 4018) |
464 | #endif |
465 | |
466 | // The overflow checks may involve float to int conversion which may |
467 | // trigger precision loss warning. Re-enable the warning once the code |
468 | // is fixed. See T58053069. |
469 | #ifdef __clang__ |
470 | #pragma GCC diagnostic push |
471 | #pragma GCC diagnostic ignored "-Wunknown-warning-option" |
472 | #pragma GCC diagnostic ignored "-Wimplicit-int-float-conversion" |
473 | #endif |
474 | |
// Overload for bool sources: a bool converts to any destination type without
// overflowing, so this always reports false. It is kept separate so the
// generic integral check is never instantiated for bool; doing so produced
// `error: comparison of constant '255' with boolean expression is always false`
// for `f > limit::max()` (seen in pytorch_linux_trusty_py2_7_9_build).
template <typename To, typename From>
typename std::enable_if<std::is_same<From, bool>::value, bool>::type overflows(
    From /*f*/) {
  return false;
}
484 | |
485 | // skip isnan and isinf check for integral types |
486 | template <typename To, typename From> |
487 | typename std::enable_if< |
488 | std::is_integral<From>::value && !std::is_same<From, bool>::value, |
489 | bool>::type |
490 | overflows(From f) { |
491 | using limit = std::numeric_limits<typename scalar_value_type<To>::type>; |
492 | if (!limit::is_signed && std::numeric_limits<From>::is_signed) { |
493 | // allow for negative numbers to wrap using two's complement arithmetic. |
494 | // For example, with uint8, this allows for `a - b` to be treated as |
495 | // `a + 255 * b`. |
496 | return greater_than_max<To>(f) || |
497 | (c10::is_negative(f) && -static_cast<uint64_t>(f) > limit::max()); |
498 | } else { |
499 | return c10::less_than_lowest<To>(f) || greater_than_max<To>(f); |
500 | } |
501 | } |
502 | |
503 | template <typename To, typename From> |
504 | typename std::enable_if<std::is_floating_point<From>::value, bool>::type |
505 | overflows(From f) { |
506 | using limit = std::numeric_limits<typename scalar_value_type<To>::type>; |
507 | if (limit::has_infinity && std::isinf(static_cast<double>(f))) { |
508 | return false; |
509 | } |
510 | if (!limit::has_quiet_NaN && (f != f)) { |
511 | return true; |
512 | } |
513 | return f < limit::lowest() || f > limit::max(); |
514 | } |
515 | |
516 | #ifdef __clang__ |
517 | #pragma GCC diagnostic pop |
518 | #endif |
519 | |
520 | #ifdef _MSC_VER |
521 | #pragma warning(pop) |
522 | #endif |
523 | |
524 | template <typename To, typename From> |
525 | typename std::enable_if<is_complex<From>::value, bool>::type overflows(From f) { |
526 | // casts from complex to real are considered to overflow if the |
527 | // imaginary component is non-zero |
528 | if (!is_complex<To>::value && f.imag() != 0) { |
529 | return true; |
530 | } |
531 | // Check for overflow componentwise |
532 | // (Technically, the imag overflow check is guaranteed to be false |
533 | // when !is_complex<To>, but any optimizer worth its salt will be |
534 | // able to figure it out.) |
535 | return overflows< |
536 | typename scalar_value_type<To>::type, |
537 | typename From::value_type>(f.real()) || |
538 | overflows< |
539 | typename scalar_value_type<To>::type, |
540 | typename From::value_type>(f.imag()); |
541 | } |
542 | |
543 | C10_API std::ostream& operator<<(std::ostream& out, const Half& value); |
544 | |
545 | } // namespace c10 |
546 | |
547 | #include <c10/util/Half-inl.h> // IWYU pragma: keep |
548 | |