1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_FLOAT16_HPP |
18 | #define COMMON_FLOAT16_HPP |
19 | |
20 | #include <cmath> |
21 | #include <cstdint> |
22 | #include <limits> |
23 | #include <type_traits> |
24 | |
25 | #include "bit_cast.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | |
// IEEE-754 binary16 (half-precision) value stored as its raw bit pattern.
//
// Conversions to and from float (operator= and operator float) are defined
// out of class later in this header. The defaulted constructor leaves `raw`
// uninitialized, which keeps the type trivially default-constructible.
struct float16_t {
    uint16_t raw; // sign (1 bit) | exponent (5 bits) | mantissa (10 bits)

    // Constructs directly from a raw bit pattern. The unnamed bool is ignored;
    // it appears to exist only to distinguish this overload from the
    // converting float constructor.
    constexpr float16_t(uint16_t raw, bool) : raw(raw) {}

    float16_t() = default;

    // Converting constructor: rounds `f` to half precision via operator=.
    float16_t(float f) { *this = f; }

    // Rounds `f` to half precision and stores it in `raw`.
    float16_t &operator=(float f);

    // Widens the stored half value to float.
    operator float() const;

    // Returns the value widened to float. const-qualified so it is usable on
    // const objects as well.
    float f() const { return static_cast<float>(*this); }

    // Adds `a` in single precision and rounds the sum back to half.
    float16_t &operator+=(float16_t a) {
        *this = f() + a.f();
        return *this;
    }
};

static_assert(sizeof(float16_t) == 2, "float16_t must be 2 bytes");
50 | |
51 | inline float16_t &float16_t::operator=(float f) { |
52 | uint32_t i = utils::bit_cast<uint32_t>(f); |
53 | uint32_t s = i >> 31; |
54 | uint32_t e = (i >> 23) & 0xFF; |
55 | uint32_t m = i & 0x7FFFFF; |
56 | |
57 | uint32_t ss = s; |
58 | uint32_t mm = m >> 13; |
59 | uint32_t r = m & 0x1FFF; |
60 | uint32_t ee = 0; |
61 | int32_t eee = (e - 127) + 15; |
62 | |
63 | if (e == 0) { |
64 | // Denormal/zero floats all become zero. |
65 | ee = 0; |
66 | mm = 0; |
67 | } else if (e == 0xFF) { |
68 | // Preserve inf/nan. |
69 | ee = 0x1F; |
70 | if (m != 0 && mm == 0) mm = 1; |
71 | } else if (eee > 0 && eee < 0x1F) { |
72 | // Normal range. Perform round to even on mantissa. |
73 | ee = eee; |
74 | if (r > (0x1000 - (mm & 1))) { |
75 | // Round up. |
76 | mm++; |
77 | if (mm == 0x400) { |
78 | // Rounds up to next dyad (or inf). |
79 | mm = 0; |
80 | ee++; |
81 | } |
82 | } |
83 | } else if (eee >= 0x1F) { |
84 | // Overflow. |
85 | ee = 0x1F; |
86 | mm = 0; |
87 | } else { |
88 | // Underflow. |
89 | float ff = fabsf(f) + 0.5; |
90 | uint32_t ii = utils::bit_cast<uint32_t>(ff); |
91 | ee = 0; |
92 | mm = ii & 0x7FF; |
93 | } |
94 | |
95 | this->raw = (ss << 15) | (ee << 10) | mm; |
96 | return *this; |
97 | } |
98 | |
99 | inline float16_t::operator float() const { |
100 | uint32_t ss = raw >> 15; |
101 | uint32_t ee = (raw >> 10) & 0x1F; |
102 | uint32_t mm = raw & 0x3FF; |
103 | |
104 | uint32_t s = ss; |
105 | uint32_t eee = ee - 15 + 127; |
106 | uint32_t m = mm << 13; |
107 | uint32_t e; |
108 | |
109 | if (ee == 0) { |
110 | if (mm == 0) |
111 | e = 0; |
112 | else { |
113 | // Half denormal -> float normal |
114 | return (ss ? -1 : 1) * std::scalbn((float)mm, -24); |
115 | } |
116 | } else if (ee == 0x1F) { |
117 | // inf/nan |
118 | e = 0xFF; |
119 | } else |
120 | e = eee; |
121 | |
122 | uint32_t f = (s << 31) | (e << 23) | m; |
123 | |
124 | return utils::bit_cast<float>(f); |
125 | } |
126 | |
127 | void cvt_float_to_float16(float16_t *out, const float *inp, size_t nelems); |
128 | void cvt_float16_to_float(float *out, const float16_t *inp, size_t nelems); |
129 | |
130 | // performs element-by-element sum of inp and add float arrays and stores |
131 | // result to float16 out array with downconversion |
132 | // out[:] = (float16_t)(inp0[:] + inp1[:]) |
133 | void add_floats_and_cvt_to_float16( |
134 | float16_t *out, const float *inp0, const float *inp1, size_t nelems); |
135 | |
136 | } // namespace impl |
137 | } // namespace dnnl |
138 | |
139 | #endif |
140 | |