#include <cstdint>
#include <cstring>
#include <vector>

#include <c10/util/Half.h>
#include <gtest/gtest.h>
5
6namespace {
7namespace half_legacy_impl {
8float halfbits2float(unsigned short h) {
9 unsigned sign = ((h >> 15) & 1);
10 unsigned exponent = ((h >> 10) & 0x1f);
11 unsigned mantissa = ((h & 0x3ff) << 13);
12
13 if (exponent == 0x1f) { /* NaN or Inf */
14 mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
15 exponent = 0xff;
16 } else if (!exponent) { /* Denorm or Zero */
17 if (mantissa) {
18 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
19 unsigned int msb;
20 exponent = 0x71;
21 do {
22 msb = (mantissa & 0x400000);
23 mantissa <<= 1; /* normalize */
24 --exponent;
25 } while (!msb);
26 mantissa &= 0x7fffff; /* 1.mantissa is implicit */
27 }
28 } else {
29 exponent += 0x70;
30 }
31
32 unsigned result_bit = (sign << 31) | (exponent << 23) | mantissa;
33
34 // Reinterpret the result bit pattern as a float
35 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
36 float result_float;
37 std::memcpy(&result_float, &result_bit, sizeof(result_float));
38 return result_float;
39};
40
41unsigned short float2halfbits(float src) {
42 // Reinterpret the float as a bit pattern
43 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
44 unsigned x;
45 std::memcpy(&x, &src, sizeof(x));
46
47 // NOLINTNEXTLINE(cppcoreguidelines-init-variables,cppcoreguidelines-avoid-magic-numbers)
48 unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
49 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
50 unsigned sign, exponent, mantissa;
51
52 // Get rid of +NaN/-NaN case first.
53 if (u > 0x7f800000) {
54 return 0x7fffU;
55 }
56
57 sign = ((x >> 16) & 0x8000);
58
59 // Get rid of +Inf/-Inf, +0/-0.
60 if (u > 0x477fefff) {
61 return sign | 0x7c00U;
62 }
63 if (u < 0x33000001) {
64 return (sign | 0x0000);
65 }
66
67 exponent = ((u >> 23) & 0xff);
68 mantissa = (u & 0x7fffff);
69
70 if (exponent > 0x70) {
71 shift = 13;
72 exponent -= 0x70;
73 } else {
74 shift = 0x7e - exponent;
75 exponent = 0;
76 mantissa |= 0x800000;
77 }
78 lsb = (1 << shift);
79 lsb_s1 = (lsb >> 1);
80 lsb_m1 = (lsb - 1);
81
82 // Round to nearest even.
83 remainder = (mantissa & lsb_m1);
84 mantissa >>= shift;
85 if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
86 ++mantissa;
87 if (!(mantissa & 0x3ff)) {
88 ++exponent;
89 mantissa = 0;
90 }
91 }
92
93 return (sign | (exponent << 10) | mantissa);
94};
95} // namespace half_legacy_impl
96TEST(HalfDoubleConversionTest, Half2Double) {
97 std::vector<uint16_t> inputs = {
98 0,
99 0xfbff, // 1111 1011 1111 1111
100 (1 << 15 | 1),
101 0x7bff // 0111 1011 1111 1111
102 };
103 for (auto x : inputs) {
104 auto target = c10::detail::fp16_ieee_to_fp32_value(x);
105 EXPECT_EQ(half_legacy_impl::halfbits2float(x), target)
106 << "Test failed for uint16 to float " << x << "\n";
107 EXPECT_EQ(
108 half_legacy_impl::float2halfbits(target),
109 c10::detail::fp16_ieee_from_fp32_value(target))
110 << "Test failed for float to uint16" << target << "\n";
111 }
112}
113} // namespace
114