1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#pragma once
8
9#include <cstdint>
10#include <cstdlib>
11#include <cstring>
12
namespace fbgemm {

// Raw IEEE 754 binary16 bit pattern (sign:1, exponent:5, mantissa:10).
using float16 = std::uint16_t;
// Raw bfloat16 bit pattern: the upper 16 bits of an IEEE 754 binary32 value.
using bfloat16 = std::uint16_t;

/// Converts a single-precision float to half precision, rounding to nearest
/// with ties to even.
///
/// @param f  Input value.
/// @return   binary16 bit pattern. Magnitudes that round above 65504 (the
///           largest finite half), including +/-Inf, become +/-Inf.
///           Magnitudes <= 2^-25 become +/-0. Every NaN input maps to 0x7fff
///           (input sign and payload are not preserved).
static inline float16 cpu_float2half_rn(float f) {
  static_assert(
      sizeof(std::uint32_t) == sizeof(float),
      "Programming error sizeof(uint32_t) != sizeof(float)");

  // Copy the bits rather than dereferencing a reinterpret_cast'ed pointer:
  // reading a float object through an unsigned* breaks strict aliasing (UB).
  std::uint32_t x;
  std::memcpy(&x, &f, sizeof(x));
  const std::uint32_t u = x & 0x7fffffffU; // magnitude bits only

  // NaN (any payload, either sign).
  if (u > 0x7f800000U) {
    return static_cast<float16>(0x7fffU);
  }

  const std::uint32_t sign = (x >> 16) & 0x8000U;

  // 0x477fefff is the largest float that still rounds down to 65504; anything
  // above it (up to and including Inf) rounds to infinity.
  if (u > 0x477fefffU) {
    return static_cast<float16>(sign | 0x7c00U);
  }
  // Magnitudes <= 2^-25 (half the smallest subnormal half) round to zero; the
  // exact tie at 2^-25 goes to the even neighbour, which is 0.
  if (u < 0x33000001U) {
    return static_cast<float16>(sign);
  }

  std::uint32_t exponent = (u >> 23) & 0xffU;
  std::uint32_t mantissa = u & 0x7fffffU;
  std::uint32_t shift;

  if (exponent > 0x70U) {
    // Normal half: rebias the exponent (127 -> 15) and drop 13 mantissa bits.
    shift = 13;
    exponent -= 0x70U;
  } else {
    // Subnormal half: restore the implicit leading 1 and shift so the result
    // is an integer count of 2^-24 units.
    shift = 0x7eU - exponent;
    exponent = 0;
    mantissa |= 0x800000U;
  }

  const std::uint32_t lsb = 1U << shift;
  const std::uint32_t half_lsb = lsb >> 1;
  const std::uint32_t remainder = mantissa & (lsb - 1U);
  mantissa >>= shift;

  // Round to nearest, ties to even.
  if (remainder > half_lsb || (remainder == half_lsb && (mantissa & 1U))) {
    ++mantissa;
    // A carry out of the 10-bit mantissa bumps the exponent; this also turns
    // the largest subnormal into the smallest normal (0x0400).
    if ((mantissa & 0x3ffU) == 0) {
      ++exponent;
      mantissa = 0;
    }
  }

  return static_cast<float16>(sign | (exponent << 10) | mantissa);
}

/// Converts a single-precision float to half precision, rounding toward zero
/// (truncation).
///
/// @param f  Input value.
/// @return   binary16 bit pattern. +/-Inf map to +/-Inf. Finite magnitudes
///           above 65504 saturate to +/-65504 (0x7bff / 0xfbff), as IEEE 754
///           roundTowardZero requires for overflow. Magnitudes below 2^-24
///           become +/-0. Every NaN input maps to 0x7fff.
static inline float16 cpu_float2half_rz(float f) {
  static_assert(
      sizeof(std::uint32_t) == sizeof(float),
      "Programming error sizeof(uint32_t) != sizeof(float)");

  // memcpy avoids the strict-aliasing UB of reading through an unsigned*.
  std::uint32_t x;
  std::memcpy(&x, &f, sizeof(x));
  const std::uint32_t u = x & 0x7fffffffU;

  // NaN (any payload, either sign).
  if (u > 0x7f800000U) {
    return static_cast<float16>(0x7fffU);
  }

  const std::uint32_t sign = (x >> 16) & 0x8000U;

  // Only a true infinity stays infinite under round-toward-zero.
  if (u == 0x7f800000U) {
    return static_cast<float16>(sign | 0x7c00U);
  }
  // Finite magnitudes >= 2^16 would need half exponent 31 (the Inf/NaN
  // encoding); truncation saturates them to the largest finite half. Values
  // in [65504, 65536) reach the same 0x7bff through the normal path below.
  if (u >= 0x47800000U) {
    return static_cast<float16>(sign | 0x7bffU);
  }
  // Anything at or below 2^-25 truncates to zero.
  if (u < 0x33000001U) {
    return static_cast<float16>(sign);
  }

  std::uint32_t exponent = (u >> 23) & 0xffU;
  std::uint32_t mantissa = u & 0x7fffffU;
  std::uint32_t shift;

  if (exponent > 0x70U) {
    // Normal half: rebias the exponent and drop 13 mantissa bits.
    shift = 13;
    exponent -= 0x70U;
  } else {
    // Subnormal half: restore the implicit 1 and express in 2^-24 units.
    shift = 0x7eU - exponent;
    exponent = 0;
    mantissa |= 0x800000U;
  }

  // Round toward zero: discard the shifted-out bits.
  mantissa >>= shift;

  return static_cast<float16>(sign | (exponent << 10) | mantissa);
}

/// Converts a half-precision bit pattern to single precision (always exact).
///
/// @param h  binary16 bit pattern.
/// @return   The equivalent float. Subnormal halves are renormalized. Every
///           half NaN maps to the positive NaN with bit pattern 0x7fffffff
///           (input sign and payload are dropped).
static inline float cpu_half2float(float16 h) {
  std::uint32_t sign = (h >> 15) & 1U;
  std::uint32_t exponent = (h >> 10) & 0x1fU;
  // Align the 10-bit half mantissa with float32's 23-bit mantissa field.
  std::uint32_t mantissa = static_cast<std::uint32_t>(h & 0x3ffU) << 13;

  if (exponent == 0x1fU) {
    // Inf keeps its sign; NaN becomes a canonical positive NaN.
    if (mantissa != 0) {
      sign = 0;
      mantissa = 0x7fffffU;
    }
    exponent = 0xffU;
  } else if (exponent == 0) {
    // Zero stays zero; a subnormal is shifted until its leading 1 reaches the
    // implicit-bit position (bit 23), lowering the exponent once per shift.
    if (mantissa != 0) {
      exponent = 0x71U;
      while ((mantissa & 0x800000U) == 0) {
        mantissa <<= 1;
        --exponent;
      }
      mantissa &= 0x7fffffU; // leading 1 becomes implicit
    }
  } else {
    // Normal: rebias the exponent (15 -> 127).
    exponent += 0x70U;
  }

  const std::uint32_t bits = (sign << 31) | (exponent << 23) | mantissa;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

/// Widens a bfloat16 bit pattern to float (exact: bfloat16 is the upper half
/// of a float32).
static inline float cpu_bf162float(bfloat16 src) {
  const std::uint32_t bits = static_cast<std::uint32_t>(src) << 16;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

/// Narrows a float to bfloat16.
///
/// Non-NaN inputs are rounded to nearest by adding 0x8000 to the bit pattern
/// before truncating. Ties therefore round to the larger magnitude, not to
/// even. This is the long-standing behavior of this function and is kept
/// as-is. Finite values that round past the largest bfloat16 become +/-Inf.
///
/// NaN inputs are truncated with the quiet bit forced. Rounding a NaN's bit
/// pattern could otherwise carry into the exponent/sign and yield Inf, -0
/// or +0.
static inline bfloat16 cpu_float2bfloat16(float src) {
  std::uint32_t bits;
  std::memcpy(&bits, &src, sizeof(bits));
  if ((bits & 0x7fffffffU) > 0x7f800000U) {
    return static_cast<bfloat16>((bits >> 16) | 0x0040U);
  }
  // No uint32 overflow here: the largest non-NaN pattern is 0xff800000.
  return static_cast<bfloat16>((bits + 0x8000U) >> 16);
}

} // namespace fbgemm
175