1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <array> |
18 | #include <memory> |
19 | |
20 | #include "common/bfloat16.hpp" |
21 | #include "common/bit_cast.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | |
24 | #include "cpu/platform.hpp" |
25 | |
26 | #if DNNL_X64 |
27 | #include "cpu/x64/cpu_isa_traits.hpp" |
28 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
29 | #include "cpu/x64/jit_uni_convert_xf16.hpp" |
30 | #endif |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | |
35 | bool try_cvt_float_to_bfloat16(bfloat16_t *out, const float *inp) { |
36 | |
37 | #if DNNL_X64 |
38 | using namespace cpu::x64; |
39 | if (mayiuse(cpu_isa_t::avx512_core) || mayiuse(avx2_vnni_2)) { |
40 | cpu::x64::cvt_xf16_support::jit_call_t p_; |
41 | p_.inp = (void *)inp; |
42 | p_.out = (void *)out; |
43 | static const cpu::x64::jit_cvt_ps_to_xf16_t cvt_one_ps_to_bf16( |
44 | data_type::bf16, 1); |
45 | cvt_one_ps_to_bf16(&p_); |
46 | return true; |
47 | } |
48 | #endif |
49 | return false; |
50 | } |
51 | |
52 | void cvt_float_to_bfloat16(bfloat16_t *out, const float *inp, size_t nelems) { |
53 | #if DNNL_X64 |
54 | using namespace cpu::x64; |
55 | if (mayiuse(cpu_isa_t::avx512_core) || mayiuse(avx2_vnni_2)) { |
56 | cpu::x64::cvt_xf16_support::jit_call_t p_; |
57 | p_.inp = (void *)inp; |
58 | p_.out = (void *)out; |
59 | p_.nelems = nelems; |
60 | static const cpu::x64::jit_cvt_ps_to_xf16_t cvt_ps_to_bf16( |
61 | data_type::bf16); |
62 | cvt_ps_to_bf16(&p_); |
63 | return; |
64 | } |
65 | #endif |
66 | |
67 | PRAGMA_OMP_SIMD() |
68 | for (size_t i = 0; i < nelems; ++i) |
69 | out[i] = inp[i]; |
70 | } |
71 | |
72 | void cvt_bfloat16_to_float(float *out, const bfloat16_t *inp, size_t nelems) { |
73 | #if DNNL_X64 |
74 | using namespace cpu::x64; |
75 | if (mayiuse(cpu_isa_t::avx512_core) || mayiuse(avx2_vnni_2)) { |
76 | static const cpu::x64::jit_cvt_xf16_to_ps_t kernel( |
77 | data_type::bf16, false); |
78 | return kernel(out, inp, nelems); |
79 | } |
80 | #endif |
81 | |
82 | PRAGMA_OMP_SIMD() |
83 | for (size_t i = 0; i < nelems; ++i) |
84 | out[i] = inp[i]; |
85 | } |
86 | |
87 | void add_floats_and_cvt_to_bfloat16( |
88 | bfloat16_t *out, const float *inp0, const float *inp1, size_t nelems) { |
89 | #if DNNL_X64 |
90 | if (cpu::x64::mayiuse(cpu::x64::cpu_isa_t::avx512_core)) { |
91 | cpu::x64::bf16_support::jit_call_t p_; |
92 | p_.inp = (void *)inp0; |
93 | p_.add = (void *)inp1; |
94 | p_.out = (void *)out; |
95 | p_.nelems = nelems; |
96 | static const cpu::x64::jit_avx512_core_add_cvt_ps_to_bf16_t |
97 | add_cvt_ps_to_bf16; |
98 | add_cvt_ps_to_bf16(&p_); |
99 | return; |
100 | } |
101 | #endif |
102 | |
103 | PRAGMA_OMP_SIMD() |
104 | for (size_t i = 0; i < nelems; ++i) |
105 | out[i] = inp0[i] + inp1[i]; |
106 | } |
107 | |
108 | } // namespace impl |
109 | } // namespace dnnl |
110 | |