// clang-format off
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/irange.h>
// clang-format on
#include <gtest/gtest.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
7 | |
8 | namespace { |
// Builds a float from IEEE-754 binary32 fields: `sign` is placed at bit 31,
// `exponent` starting at bit 23, and `fraction` is OR-ed into the low bits.
// No field is masked, so a caller may also pass a complete 32-bit pattern via
// `fraction` with sign and exponent set to zero.
float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
  const uint32_t packed = (sign << 31) | (exponent << 23) | fraction;

  float decoded = 0.0f;
  std::memcpy(&decoded, &packed, sizeof(decoded));
  return decoded;
}
24 | |
25 | TEST(BFloat16Conversion, FloatToBFloat16AndBack) { |
26 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) |
27 | float in[100]; |
28 | for (const auto i : c10::irange(100)) { |
29 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) |
30 | in[i] = i + 1.25; |
31 | } |
32 | |
33 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) |
34 | c10::BFloat16 bfloats[100]; |
35 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) |
36 | float out[100]; |
37 | |
38 | for (const auto i : c10::irange(100)) { |
39 | bfloats[i].x = c10::detail::bits_from_f32(in[i]); |
40 | out[i] = c10::detail::f32_from_bits(bfloats[i].x); |
41 | |
42 | // The relative error should be less than 1/(2^7) since BFloat16 |
43 | // has 7 bits mantissa. |
44 | EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128); |
45 | } |
46 | } |
47 | |
48 | TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) { |
49 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) |
50 | float in[100]; |
51 | for (const auto i : c10::irange(100)) { |
52 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) |
53 | in[i] = i + 1.25; |
54 | } |
55 | |
56 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) |
57 | c10::BFloat16 bfloats[100]; |
58 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) |
59 | float out[100]; |
60 | |
61 | for (const auto i : c10::irange(100)) { |
62 | bfloats[i].x = c10::detail::round_to_nearest_even(in[i]); |
63 | out[i] = c10::detail::f32_from_bits(bfloats[i].x); |
64 | |
65 | // The relative error should be less than 1/(2^7) since BFloat16 |
66 | // has 7 bits mantissa. |
67 | EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128); |
68 | } |
69 | } |
70 | |
71 | TEST(BFloat16Conversion, NaN) { |
72 | float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF); |
73 | EXPECT_TRUE(std::isnan(inNaN)); |
74 | |
75 | c10::BFloat16 a = c10::BFloat16(inNaN); |
76 | float out = c10::detail::f32_from_bits(a.x); |
77 | |
78 | EXPECT_TRUE(std::isnan(out)); |
79 | } |
80 | |
81 | TEST(BFloat16Conversion, Inf) { |
82 | float inInf = float_from_bytes(0, 0xFF, 0); |
83 | EXPECT_TRUE(std::isinf(inInf)); |
84 | |
85 | c10::BFloat16 a = c10::BFloat16(inInf); |
86 | float out = c10::detail::f32_from_bits(a.x); |
87 | |
88 | EXPECT_TRUE(std::isinf(out)); |
89 | } |
90 | |
91 | TEST(BFloat16Conversion, SmallestDenormal) { |
92 | float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero |
93 | // subnormal number |
94 | c10::BFloat16 a = c10::BFloat16(in); |
95 | float out = c10::detail::f32_from_bits(a.x); |
96 | |
97 | EXPECT_FLOAT_EQ(in, out); |
98 | } |
99 | |
100 | TEST(BFloat16Math, Addition) { |
101 | // This test verifies that if only first 7 bits of float's mantissa are |
102 | // changed after addition, we should have no loss in precision. |
103 | |
104 | // input bits |
105 | // S | Exponent | Mantissa |
106 | // 0 | 10000000 | 10010000000000000000000 = 3.125 |
107 | float input = float_from_bytes(0, 0, 0x40480000); |
108 | |
109 | // expected bits |
110 | // S | Exponent | Mantissa |
111 | // 0 | 10000001 | 10010000000000000000000 = 6.25 |
112 | float expected = float_from_bytes(0, 0, 0x40c80000); |
113 | |
114 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
115 | c10::BFloat16 b; |
116 | b.x = c10::detail::bits_from_f32(input); |
117 | b = b + b; |
118 | |
119 | float res = c10::detail::f32_from_bits(b.x); |
120 | EXPECT_EQ(res, expected); |
121 | } |
122 | |
123 | TEST(BFloat16Math, Subtraction) { |
124 | // This test verifies that if only first 7 bits of float's mantissa are |
125 | // changed after subtraction, we should have no loss in precision. |
126 | |
127 | // input bits |
128 | // S | Exponent | Mantissa |
129 | // 0 | 10000001 | 11101000000000000000000 = 7.625 |
130 | float input = float_from_bytes(0, 0, 0x40f40000); |
131 | |
132 | // expected bits |
133 | // S | Exponent | Mantissa |
134 | // 0 | 10000000 | 01010000000000000000000 = 2.625 |
135 | float expected = float_from_bytes(0, 0, 0x40280000); |
136 | |
137 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
138 | c10::BFloat16 b; |
139 | b.x = c10::detail::bits_from_f32(input); |
140 | b = b - 5; |
141 | |
142 | float res = c10::detail::f32_from_bits(b.x); |
143 | EXPECT_EQ(res, expected); |
144 | } |
145 | |
146 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
147 | TEST(BFloat16Math, NextAfterZero) { |
148 | const c10::BFloat16 zero{0}; |
149 | |
150 | auto check_nextafter = |
151 | [](c10::BFloat16 from, c10::BFloat16 to, c10::BFloat16 expected) { |
152 | c10::BFloat16 actual = std::nextafter(from, to); |
153 | // Check for bitwise equality! |
154 | ASSERT_EQ(actual.x ^ expected.x, uint16_t{0}); |
155 | }; |
156 | check_nextafter(zero, zero, /*expected=*/zero); |
157 | check_nextafter(zero, -zero, /*expected=*/-zero); |
158 | check_nextafter(-zero, zero, /*expected=*/zero); |
159 | check_nextafter(-zero, -zero, /*expected=*/-zero); |
160 | } |
161 | |
// Reinterprets a raw 32-bit IEEE-754 pattern as the float it encodes, copying
// the bytes rather than casting pointers to avoid strict-aliasing issues.
float BinaryToFloat(uint32_t bytes) {
  float as_float = 0.0f;
  std::memcpy(&as_float, &bytes, sizeof as_float);
  return as_float;
}
168 | |
169 | struct BFloat16TestParam { |
170 | uint32_t input; |
171 | uint16_t rne; |
172 | }; |
173 | |
174 | class BFloat16Test : public ::testing::Test, |
175 | public ::testing::WithParamInterface<BFloat16TestParam> {}; |
176 | |
177 | TEST_P(BFloat16Test, BFloat16RNETest) { |
178 | float value = BinaryToFloat(GetParam().input); |
179 | uint16_t rounded = c10::detail::round_to_nearest_even(value); |
180 | EXPECT_EQ(GetParam().rne, rounded); |
181 | } |
182 | |
183 | INSTANTIATE_TEST_CASE_P( |
184 | BFloat16Test_Instantiation, |
185 | BFloat16Test, |
186 | ::testing::Values( |
187 | BFloat16TestParam{0x3F848000, 0x3F84}, |
188 | BFloat16TestParam{0x3F848010, 0x3F85}, |
189 | BFloat16TestParam{0x3F850000, 0x3F85}, |
190 | BFloat16TestParam{0x3F858000, 0x3F86}, |
191 | BFloat16TestParam{0x3FFF8000, 0x4000})); |
192 | |
193 | } // namespace |
194 | |