#include <cstdint>
#include <cstring>
#include <vector>

#include <c10/util/Half.h>
#include <gtest/gtest.h>
5 | |
6 | namespace { |
7 | namespace half_legacy_impl { |
// Reference (legacy) decoder: expands the raw bits of an IEEE-754 binary16
// value into the equivalent float. The conversion is exact for every
// non-NaN input. Every NaN input, regardless of sign or payload, decodes to
// the float bit pattern 0x7fffffff.
float halfbits2float(unsigned short h) {
  unsigned sign_bit = (h >> 15) & 1;
  unsigned biased_exp = (h >> 10) & 0x1f;
  // Left-align the 10-bit half fraction inside the 23-bit float fraction.
  unsigned fraction = (h & 0x3ff) << 13;

  if (biased_exp == 0x1f) {
    // Inf keeps its sign; NaN is replaced by a canonical positive NaN.
    if (fraction != 0) {
      sign_bit = 0;
      fraction = 0x7fffff;
    }
    biased_exp = 0xff;
  } else if (biased_exp != 0) {
    // Normal number: rebias the exponent from 15 to 127.
    biased_exp += 0x70;
  } else if (fraction != 0) {
    // Subnormal half: shift the fraction left until its leading 1 moves past
    // bit 22, lowering the exponent once per shift.
    biased_exp = 0x71;
    bool found_leading_one = false;
    while (!found_leading_one) {
      found_leading_one = (fraction & 0x400000) != 0;
      fraction <<= 1;
      --biased_exp;
    }
    fraction &= 0x7fffff;  // drop the now-implicit leading 1
  }
  // Remaining case: signed zero, where exponent and fraction are already 0.

  const unsigned float_bits = (sign_bit << 31) | (biased_exp << 23) | fraction;

  // Reinterpret the assembled bit pattern as a float.
  float value = 0.0f;
  std::memcpy(&value, &float_bits, sizeof(value));
  return value;
}
40 | |
// Reference (legacy) encoder: narrows a float to IEEE-754 binary16 bits using
// round-to-nearest, ties-to-even. Every NaN (any sign or payload) maps to
// 0x7fff. Magnitudes that round beyond the largest finite half (65504) map to
// a signed Inf. Magnitudes at or below 2^-25 flush to a signed zero.
unsigned short float2halfbits(float src) {
  // Type-pun the float into its raw IEEE-754 bit pattern.
  unsigned raw_bits = 0;
  std::memcpy(&raw_bits, &src, sizeof(raw_bits));

  const unsigned magnitude = raw_bits & 0x7fffffff;

  // Any NaN collapses to a single canonical half NaN.
  if (magnitude > 0x7f800000) {
    return 0x7fffU;
  }

  const unsigned half_sign = (raw_bits >> 16) & 0x8000;

  // 0x477fefff is the largest float below 65520, the midpoint between 65504
  // and 65536; anything larger (including Inf) rounds to Inf.
  if (magnitude > 0x477fefff) {
    return half_sign | 0x7c00U;
  }
  // 0x33000000 is 2^-25, half of the smallest subnormal half; anything at or
  // below it rounds (ties-to-even) to zero.
  if (magnitude <= 0x33000000) {
    return half_sign;
  }

  unsigned half_exp = (magnitude >> 23) & 0xff;
  unsigned fraction = magnitude & 0x7fffff;
  unsigned dropped_bits = 13;  // 23-bit float fraction -> 10-bit half fraction
  if (half_exp <= 0x70) {
    // The result is subnormal in half precision: make the implicit leading 1
    // explicit and shift it further right into the subnormal range.
    dropped_bits = 0x7e - half_exp;
    half_exp = 0;
    fraction |= 0x800000;
  } else {
    // Rebias the exponent from 127 to 15.
    half_exp -= 0x70;
  }

  // Round to nearest, ties to even.
  const unsigned unit = 1u << dropped_bits;
  const unsigned halfway = unit >> 1;
  const unsigned tail = fraction & (unit - 1);
  fraction >>= dropped_bits;

  const bool above_halfway = tail > halfway;
  const bool tie_and_odd = (tail == halfway) && (fraction & 0x1);
  if (above_halfway || tie_and_odd) {
    ++fraction;
    // A carry out of the 10-bit fraction moves the value into the next binade.
    if ((fraction & 0x3ff) == 0) {
      ++half_exp;
      fraction = 0;
    }
  }

  return half_sign | (half_exp << 10) | fraction;
}
95 | } // namespace half_legacy_impl |
96 | TEST(HalfDoubleConversionTest, Half2Double) { |
97 | std::vector<uint16_t> inputs = { |
98 | 0, |
99 | 0xfbff, // 1111 1011 1111 1111 |
100 | (1 << 15 | 1), |
101 | 0x7bff // 0111 1011 1111 1111 |
102 | }; |
103 | for (auto x : inputs) { |
104 | auto target = c10::detail::fp16_ieee_to_fp32_value(x); |
105 | EXPECT_EQ(half_legacy_impl::halfbits2float(x), target) |
106 | << "Test failed for uint16 to float " << x << "\n" ; |
107 | EXPECT_EQ( |
108 | half_legacy_impl::float2halfbits(target), |
109 | c10::detail::fp16_ieee_from_fp32_value(target)) |
110 | << "Test failed for float to uint16" << target << "\n" ; |
111 | } |
112 | } |
113 | } // namespace |
114 | |