1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_BFLOAT16_HPP
18#define COMMON_BFLOAT16_HPP
19
20#include <array>
21#include <cmath>
22#include <cstddef>
23#include <cstdint>
24#include <cstdlib>
25#include <limits>
26#include <type_traits>
27
28#include "common/bit_cast.hpp"
29
30#include "oneapi/dnnl/dnnl.h"
31
32namespace dnnl {
33namespace impl {
34
// Declared only when a CPU runtime is configured (DNNL_CPU_RUNTIME is not
// DNNL_RUNTIME_NONE). bfloat16_t is forward-declared here because its full
// definition follows below.
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
struct bfloat16_t;
// NOTE(review): the definition is not visible in this header; presumably
// converts the float at inp into *out and returns whether the conversion was
// carried out -- confirm against the implementation.
bool try_cvt_float_to_bfloat16(bfloat16_t *out, const float *inp);
#endif
39
40struct bfloat16_t {
41 uint16_t raw_bits_;
42 bfloat16_t() = default;
43 constexpr bfloat16_t(uint16_t r, bool) : raw_bits_(r) {}
44 bfloat16_t(float f) { (*this) = f; }
45
46 template <typename IntegerType,
47 typename SFINAE = typename std::enable_if<
48 std::is_integral<IntegerType>::value>::type>
49 bfloat16_t(const IntegerType i)
50 : raw_bits_ {convert_bits_of_normal_or_zero(
51 utils::bit_cast<uint32_t>(static_cast<float>(i)))} {}
52
53 bfloat16_t DNNL_API &operator=(float f);
54
55 template <typename IntegerType,
56 typename SFINAE = typename std::enable_if<
57 std::is_integral<IntegerType>::value>::type>
58 bfloat16_t &operator=(const IntegerType i) {
59 // Call the converting constructor that is optimized for integer types,
60 // followed by the fast defaulted move-assignment operator.
61 return (*this) = bfloat16_t {i};
62 }
63
64 DNNL_API operator float() const;
65
66 bfloat16_t &operator+=(const float a) {
67 (*this) = float {*this} + a;
68 return *this;
69 }
70
71private:
72 // Converts the 32 bits of a normal float or zero to the bits of a bfloat16.
73 static constexpr uint16_t convert_bits_of_normal_or_zero(
74 const uint32_t bits) {
75 return uint32_t {
76 bits + uint32_t {0x7FFFU + (uint32_t {bits >> 16} & 1U)}}
77 >> 16;
78 }
79};
80
// Compile-time guard: the struct must occupy exactly two bytes, i.e. nothing
// beyond the uint16_t payload.
static_assert(sizeof(bfloat16_t) == 2, "bfloat16_t must be 2 bytes");
82
// Array conversion routines between float and bfloat16_t; only declared here.
// NOTE(review): judging by the signatures, each presumably converts nelems
// consecutive elements from inp into out -- confirm against the definitions.
void cvt_float_to_bfloat16(bfloat16_t *out, const float *inp, size_t nelems);
void cvt_bfloat16_to_float(float *out, const bfloat16_t *inp, size_t nelems);
85
// Performs an element-by-element sum of the inp0 and inp1 float arrays and
// stores the result, down-converted to bfloat16, into the out array:
// out[:] = (bfloat16_t)(inp0[:] + inp1[:])
// NOTE(review): nelems is presumably the number of elements processed in each
// array -- confirm against the definition.
void add_floats_and_cvt_to_bfloat16(
        bfloat16_t *out, const float *inp0, const float *inp1, size_t nelems);
91
92} // namespace impl
93} // namespace dnnl
94
95#endif
96