1// clang-format off
2#include <c10/util/BFloat16.h>
3#include <c10/util/BFloat16-math.h>
4#include <c10/util/irange.h>
5// clang-format on
6#include <gtest/gtest.h>
7
8namespace {
/// Builds a float from a 32-bit pattern assembled out of three bit fields.
///
/// The pattern is `(((sign << 8) | exponent) << 23) | fraction`, i.e. one sign
/// bit, then 8 exponent bits, then 23 fraction bits. No field is masked, so a
/// value wider than its nominal field simply ORs into the neighbouring bits.
///
/// @param sign     value placed above the exponent field
/// @param exponent value placed above the fraction field
/// @param fraction value ORed into the low bits of the pattern
/// @return the float whose object representation equals the assembled pattern
float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
  const uint32_t sign_and_exponent = (sign << 8) | exponent;
  const uint32_t bits = (sign_and_exponent << 23) | fraction;

  // Copy the raw bytes across rather than converting the integer's value.
  float value = 0.0f;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}
24
25TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
26 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
27 float in[100];
28 for (const auto i : c10::irange(100)) {
29 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
30 in[i] = i + 1.25;
31 }
32
33 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
34 c10::BFloat16 bfloats[100];
35 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
36 float out[100];
37
38 for (const auto i : c10::irange(100)) {
39 bfloats[i].x = c10::detail::bits_from_f32(in[i]);
40 out[i] = c10::detail::f32_from_bits(bfloats[i].x);
41
42 // The relative error should be less than 1/(2^7) since BFloat16
43 // has 7 bits mantissa.
44 EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
45 }
46}
47
48TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
49 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
50 float in[100];
51 for (const auto i : c10::irange(100)) {
52 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
53 in[i] = i + 1.25;
54 }
55
56 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
57 c10::BFloat16 bfloats[100];
58 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
59 float out[100];
60
61 for (const auto i : c10::irange(100)) {
62 bfloats[i].x = c10::detail::round_to_nearest_even(in[i]);
63 out[i] = c10::detail::f32_from_bits(bfloats[i].x);
64
65 // The relative error should be less than 1/(2^7) since BFloat16
66 // has 7 bits mantissa.
67 EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
68 }
69}
70
71TEST(BFloat16Conversion, NaN) {
72 float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF);
73 EXPECT_TRUE(std::isnan(inNaN));
74
75 c10::BFloat16 a = c10::BFloat16(inNaN);
76 float out = c10::detail::f32_from_bits(a.x);
77
78 EXPECT_TRUE(std::isnan(out));
79}
80
81TEST(BFloat16Conversion, Inf) {
82 float inInf = float_from_bytes(0, 0xFF, 0);
83 EXPECT_TRUE(std::isinf(inInf));
84
85 c10::BFloat16 a = c10::BFloat16(inInf);
86 float out = c10::detail::f32_from_bits(a.x);
87
88 EXPECT_TRUE(std::isinf(out));
89}
90
91TEST(BFloat16Conversion, SmallestDenormal) {
92 float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero
93 // subnormal number
94 c10::BFloat16 a = c10::BFloat16(in);
95 float out = c10::detail::f32_from_bits(a.x);
96
97 EXPECT_FLOAT_EQ(in, out);
98}
99
100TEST(BFloat16Math, Addition) {
101 // This test verifies that if only first 7 bits of float's mantissa are
102 // changed after addition, we should have no loss in precision.
103
104 // input bits
105 // S | Exponent | Mantissa
106 // 0 | 10000000 | 10010000000000000000000 = 3.125
107 float input = float_from_bytes(0, 0, 0x40480000);
108
109 // expected bits
110 // S | Exponent | Mantissa
111 // 0 | 10000001 | 10010000000000000000000 = 6.25
112 float expected = float_from_bytes(0, 0, 0x40c80000);
113
114 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
115 c10::BFloat16 b;
116 b.x = c10::detail::bits_from_f32(input);
117 b = b + b;
118
119 float res = c10::detail::f32_from_bits(b.x);
120 EXPECT_EQ(res, expected);
121}
122
123TEST(BFloat16Math, Subtraction) {
124 // This test verifies that if only first 7 bits of float's mantissa are
125 // changed after subtraction, we should have no loss in precision.
126
127 // input bits
128 // S | Exponent | Mantissa
129 // 0 | 10000001 | 11101000000000000000000 = 7.625
130 float input = float_from_bytes(0, 0, 0x40f40000);
131
132 // expected bits
133 // S | Exponent | Mantissa
134 // 0 | 10000000 | 01010000000000000000000 = 2.625
135 float expected = float_from_bytes(0, 0, 0x40280000);
136
137 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
138 c10::BFloat16 b;
139 b.x = c10::detail::bits_from_f32(input);
140 b = b - 5;
141
142 float res = c10::detail::f32_from_bits(b.x);
143 EXPECT_EQ(res, expected);
144}
145
146// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
147TEST(BFloat16Math, NextAfterZero) {
148 const c10::BFloat16 zero{0};
149
150 auto check_nextafter =
151 [](c10::BFloat16 from, c10::BFloat16 to, c10::BFloat16 expected) {
152 c10::BFloat16 actual = std::nextafter(from, to);
153 // Check for bitwise equality!
154 ASSERT_EQ(actual.x ^ expected.x, uint16_t{0});
155 };
156 check_nextafter(zero, zero, /*expected=*/zero);
157 check_nextafter(zero, -zero, /*expected=*/-zero);
158 check_nextafter(-zero, zero, /*expected=*/zero);
159 check_nextafter(-zero, -zero, /*expected=*/-zero);
160}
161
/// Reinterprets a 32-bit pattern as a float.
///
/// @param bytes the raw bit pattern
/// @return the float whose object representation equals `bytes`
float BinaryToFloat(uint32_t bytes) {
  // Copy the raw bytes across rather than converting the integer's value.
  float value = 0.0f;
  std::memcpy(&value, &bytes, sizeof(value));
  return value;
}
168
/// One case for the parameterized rounding test: a float given as its raw
/// 32-bit pattern, paired with the 16-bit value the rounding is expected to
/// produce for it.
struct BFloat16TestParam {
  /// Bit pattern that is reinterpreted as the float under test.
  uint32_t input;
  /// Expected 16-bit result for `input`.
  uint16_t rne;
};
173
174class BFloat16Test : public ::testing::Test,
175 public ::testing::WithParamInterface<BFloat16TestParam> {};
176
177TEST_P(BFloat16Test, BFloat16RNETest) {
178 float value = BinaryToFloat(GetParam().input);
179 uint16_t rounded = c10::detail::round_to_nearest_even(value);
180 EXPECT_EQ(GetParam().rne, rounded);
181}
182
183INSTANTIATE_TEST_CASE_P(
184 BFloat16Test_Instantiation,
185 BFloat16Test,
186 ::testing::Values(
187 BFloat16TestParam{0x3F848000, 0x3F84},
188 BFloat16TestParam{0x3F848010, 0x3F85},
189 BFloat16TestParam{0x3F850000, 0x3F85},
190 BFloat16TestParam{0x3F858000, 0x3F86},
191 BFloat16TestParam{0x3FFF8000, 0x4000}));
192
193} // namespace
194