/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

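// f32 <-> bf16 conversion helpers. On x64 builds the JIT kernels below are
// dispatched when the required ISA is available; otherwise a plain scalar
// loop relying on bfloat16_t's conversions is used.
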
#include <array>
#include <memory>

#include "common/bfloat16.hpp"
#include "common/bit_cast.hpp"
#include "common/dnnl_thread.hpp"

#include "cpu/platform.hpp"

#if DNNL_X64
#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_uni_convert_xf16.hpp"
#endif

namespace dnnl {
namespace impl {

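// Converts a single f32 value to bf16 through the JIT kernel. Returns false
// when no suitable ISA (avx512_core or avx2_vnni_2) is available, in which
// case the caller must perform the conversion itself.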
bool try_cvt_float_to_bfloat16(bfloat16_t *out, const float *inp) {
#if DNNL_X64
    using namespace cpu::x64;
    if (mayiuse(cpu_isa_t::avx512_core) || mayiuse(avx2_vnni_2)) {
        cpu::x64::cvt_xf16_support::jit_call_t p_;
        p_.inp = (void *)inp;
        p_.out = (void *)out;
        static const cpu::x64::jit_cvt_ps_to_xf16_t cvt_one_ps_to_bf16(
                data_type::bf16, 1);
        cvt_one_ps_to_bf16(&p_);
        return true;
    }
#endif
    return false;
}

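// Converts nelems f32 values to bf16: vectorized JIT kernel on supported x64
// CPUs, scalar loop otherwise.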
void cvt_float_to_bfloat16(bfloat16_t *out, const float *inp, size_t nelems) {
#if DNNL_X64
    using namespace cpu::x64;
    if (mayiuse(cpu_isa_t::avx512_core) || mayiuse(avx2_vnni_2)) {
        cpu::x64::cvt_xf16_support::jit_call_t p_;
        p_.inp = (void *)inp;
        p_.out = (void *)out;
        p_.nelems = nelems;
        static const cpu::x64::jit_cvt_ps_to_xf16_t cvt_ps_to_bf16(
                data_type::bf16);
        cvt_ps_to_bf16(&p_);
        return;
    }
#endif

    PRAGMA_OMP_SIMD()
    for (size_t i = 0; i < nelems; ++i)
        out[i] = inp[i];
}

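// Converts nelems bf16 values back to f32, mirroring cvt_float_to_bfloat16.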
void cvt_bfloat16_to_float(float *out, const bfloat16_t *inp, size_t nelems) {
#if DNNL_X64
    using namespace cpu::x64;
    if (mayiuse(cpu_isa_t::avx512_core) || mayiuse(avx2_vnni_2)) {
        static const cpu::x64::jit_cvt_xf16_to_ps_t kernel(
                data_type::bf16, false);
        return kernel(out, inp, nelems);
    }
#endif

    PRAGMA_OMP_SIMD()
    for (size_t i = 0; i < nelems; ++i)
        out[i] = inp[i];
}

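// Computes out[i] = bf16(inp0[i] + inp1[i]). The fused JIT kernel requires
// avx512_core; all other builds take the scalar path.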
void add_floats_and_cvt_to_bfloat16(
        bfloat16_t *out, const float *inp0, const float *inp1, size_t nelems) {
#if DNNL_X64
    if (cpu::x64::mayiuse(cpu::x64::cpu_isa_t::avx512_core)) {
        cpu::x64::bf16_support::jit_call_t p_;
        p_.inp = (void *)inp0;
        p_.add = (void *)inp1;
        p_.out = (void *)out;
        p_.nelems = nelems;
        static const cpu::x64::jit_avx512_core_add_cvt_ps_to_bf16_t
                add_cvt_ps_to_bf16;
        add_cvt_ps_to_bf16(&p_);
        return;
    }
#endif

    PRAGMA_OMP_SIMD()
    for (size_t i = 0; i < nelems; ++i)
        out[i] = inp0[i] + inp1[i];
}

} // namespace impl
} // namespace dnnl