/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/float16.hpp"
#include "common/dnnl_thread.hpp"

#include "cpu/platform.hpp"
#if DNNL_X64
#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_avx512_core_fp16cvt.hpp"
#include "cpu/x64/jit_uni_convert_xf16.hpp"
#endif

namespace dnnl {
namespace impl {

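// Converts a single f32 value to f16 through the JIT kernel. Returns false
// when no suitable x64 ISA is available so the caller can take another path.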
bool try_cvt_float_to_float16(float16_t *out, const float *inp) {
#if DNNL_X64
    using namespace cpu::x64;
    if (mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2)) {
        cvt_xf16_support::jit_call_t p;
        p.inp = (void *)inp;
        p.out = (void *)out;
        // Function-local static: the single-element JIT kernel is generated
        // once and reused across calls.
        static const jit_cvt_ps_to_xf16_t cvt_one_ps_to_f16(data_type::f16, 1);
        cvt_one_ps_to_f16(&p);
        return true;
    }
#endif
    return false;
}
44 | |
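// Converts nelems f32 values to f16. Uses the JIT conversion kernel when a
// supporting x64 ISA is available, otherwise the scalar loop below.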
void cvt_float_to_float16(float16_t *out, const float *inp, size_t nelems) {
#if DNNL_X64
    using namespace cpu::x64;
    if (mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2)) {
        cvt_xf16_support::jit_call_t p_;
        p_.inp = (void *)inp;
        p_.out = (void *)out;
        p_.nelems = nelems;
        static const jit_cvt_ps_to_xf16_t cvt_ps_to_f16(data_type::f16);
        cvt_ps_to_f16(&p_);
        return;
    }
#endif

    // Scalar reference path for targets without a suitable JIT kernel.
    PRAGMA_OMP_SIMD()
    for (size_t i = 0; i < nelems; ++i)
        out[i] = static_cast<float16_t>(inp[i]);
}
63 | |
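// Converts nelems f16 values back to f32, again preferring the JIT kernel.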
void cvt_float16_to_float(float *out, const float16_t *inp, size_t nelems) {
#if DNNL_X64
    using namespace cpu::x64;
    if (mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2)) {
        static const jit_cvt_xf16_to_ps_t kernel(data_type::f16, false);
        return kernel(out, inp, nelems);
    }
#endif

    // Scalar reference path: float16_t converts to float implicitly.
    PRAGMA_OMP_SIMD()
    for (size_t i = 0; i < nelems; ++i)
        out[i] = inp[i];
}
77 | |
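// Computes out[i] = f16(inp0[i] + inp1[i]). The fused add-and-convert JIT
// kernel requires avx512_core_fp16; all other targets use the scalar loop.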
void add_floats_and_cvt_to_float16(
        float16_t *out, const float *inp0, const float *inp1, size_t nelems) {
#if DNNL_X64
    if (cpu::x64::mayiuse(cpu::x64::cpu_isa_t::avx512_core_fp16)) {
        cpu::x64::f16_support::jit_call_t p_;
        p_.inp = (void *)inp0;
        p_.add = (void *)inp1;
        p_.out = (void *)out;
        p_.nelems = nelems;
        static const cpu::x64::jit_avx512_core_fp16_add_cvt_ps_to_f16_t
                add_cvt_ps_to_f16;
        add_cvt_ps_to_f16(&p_);
        return;
    }
#endif

    // Scalar reference path: add in f32, then round to f16 on the store.
    PRAGMA_OMP_SIMD()
    for (size_t i = 0; i < nelems; ++i)
        out[i] = static_cast<float16_t>(inp0[i] + inp1[i]);
}

} // namespace impl
} // namespace dnnl