1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_FLOAT16_HPP
18#define COMMON_FLOAT16_HPP
19
20#include <cmath>
21#include <cstdint>
22#include <limits>
23#include <type_traits>
24
25#include "bit_cast.hpp"
26
27namespace dnnl {
28namespace impl {
29
// Software 16-bit floating-point value stored as a raw bit pattern. The
// conversions defined in this header place the sign in bit 15, the exponent
// in bits 14..10 and the mantissa in bits 9..0 of `raw`.
struct float16_t {
    // Bit pattern of the half-precision value.
    uint16_t raw;

    // Builds a value directly from its bit pattern; no float conversion is
    // performed. The unnamed bool is unused -- presumably it only keeps this
    // overload distinct from the float constructor.
    constexpr float16_t(uint16_t raw, bool) : raw(raw) {}

    // Defaulted constructor: default-initialization does not set `raw`.
    float16_t() = default;
    // Converting constructor: delegates to operator=(float).
    float16_t(float f) { (*this) = f; }

    // Stores `f` converted to half precision and returns *this.
    float16_t &operator=(float f);

    // Widens the stored half-precision value to float.
    operator float() const;
    // Named spelling of the float conversion. Const-qualified because it only
    // reads the object, so it is usable on const values and const references.
    float f() const { return static_cast<float>(*this); }

    // Adds `a` to this value. The sum is evaluated in float and then
    // converted back to half precision through operator=(float).
    float16_t &operator+=(float16_t a) {
        (*this) = float(f() + a.f());
        return *this;
    }
};
48
// Compile-time guard: the type must occupy exactly two bytes, i.e. nothing
// beyond its 16-bit `raw` member.
static_assert(sizeof(float16_t) == 2, "float16_t must be 2 bytes");
50
51inline float16_t &float16_t::operator=(float f) {
52 uint32_t i = utils::bit_cast<uint32_t>(f);
53 uint32_t s = i >> 31;
54 uint32_t e = (i >> 23) & 0xFF;
55 uint32_t m = i & 0x7FFFFF;
56
57 uint32_t ss = s;
58 uint32_t mm = m >> 13;
59 uint32_t r = m & 0x1FFF;
60 uint32_t ee = 0;
61 int32_t eee = (e - 127) + 15;
62
63 if (e == 0) {
64 // Denormal/zero floats all become zero.
65 ee = 0;
66 mm = 0;
67 } else if (e == 0xFF) {
68 // Preserve inf/nan.
69 ee = 0x1F;
70 if (m != 0 && mm == 0) mm = 1;
71 } else if (eee > 0 && eee < 0x1F) {
72 // Normal range. Perform round to even on mantissa.
73 ee = eee;
74 if (r > (0x1000 - (mm & 1))) {
75 // Round up.
76 mm++;
77 if (mm == 0x400) {
78 // Rounds up to next dyad (or inf).
79 mm = 0;
80 ee++;
81 }
82 }
83 } else if (eee >= 0x1F) {
84 // Overflow.
85 ee = 0x1F;
86 mm = 0;
87 } else {
88 // Underflow.
89 float ff = fabsf(f) + 0.5;
90 uint32_t ii = utils::bit_cast<uint32_t>(ff);
91 ee = 0;
92 mm = ii & 0x7FF;
93 }
94
95 this->raw = (ss << 15) | (ee << 10) | mm;
96 return *this;
97}
98
99inline float16_t::operator float() const {
100 uint32_t ss = raw >> 15;
101 uint32_t ee = (raw >> 10) & 0x1F;
102 uint32_t mm = raw & 0x3FF;
103
104 uint32_t s = ss;
105 uint32_t eee = ee - 15 + 127;
106 uint32_t m = mm << 13;
107 uint32_t e;
108
109 if (ee == 0) {
110 if (mm == 0)
111 e = 0;
112 else {
113 // Half denormal -> float normal
114 return (ss ? -1 : 1) * std::scalbn((float)mm, -24);
115 }
116 } else if (ee == 0x1F) {
117 // inf/nan
118 e = 0xFF;
119 } else
120 e = eee;
121
122 uint32_t f = (s << 31) | (e << 23) | m;
123
124 return utils::bit_cast<float>(f);
125}
126
// Array conversion routines between float and float16_t. Only the
// declarations are visible in this header.
// NOTE(review): judging by the names and signatures, each one presumably
// converts `nelems` elements read from `inp` and writes them to `out` --
// confirm against the definitions.
void cvt_float_to_float16(float16_t *out, const float *inp, size_t nelems);
void cvt_float16_to_float(float *out, const float16_t *inp, size_t nelems);
129
// Performs an element-by-element sum of the inp0 and inp1 float arrays and
// stores the result in the float16 out array with down-conversion:
// out[:] = (float16_t)(inp0[:] + inp1[:])
// NOTE(review): nelems is presumably the element count shared by all three
// arrays -- confirm against the definition.
void add_floats_and_cvt_to_float16(
        float16_t *out, const float *inp0, const float *inp1, size_t nelems);
135
136} // namespace impl
137} // namespace dnnl
138
139#endif
140