1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_FP16CVT_HPP
18#define CPU_X64_JIT_AVX512_CORE_FP16CVT_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/float16.hpp"
24#include "common/nstl.hpp"
25#include "common/type_helpers.hpp"
26
27#include "cpu/x64/cpu_isa_traits.hpp"
28#include "cpu/x64/jit_generator.hpp"
29
30#include "oneapi/dnnl/dnnl_debug.h"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace x64 {
36
37namespace f16_support {
// Argument pack handed to the f16 conversion kernels declared in this file.
// The field order is part of the kernel calling contract: keep it unchanged.
struct jit_call_t {
    // First source array (fp32 for the add+convert kernel below).
    void *inp;
    // Destination array (float16 for the add+convert kernel below).
    void *out;
    // Second source array that is summed with inp (fp32 for the kernel below).
    void *add;
    // Number of elements to process.
    size_t nelems;
};
44} // namespace f16_support
45
46// performs element-by-element sum of inp and add float arrays and stores
47// result to float16 out array with downconversion
48struct jit_avx512_core_fp16_add_cvt_ps_to_f16_t : public jit_generator {
49 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_add_cvt_ps_to_f16)
50
51 jit_avx512_core_fp16_add_cvt_ps_to_f16_t()
52 : jit_generator(jit_name()), simd_w_(16) {
53 create_kernel();
54 }
55
56 void generate() override;
57
58 void operator()(f16_support::jit_call_t *params) const {
59 jit_generator::operator()(params);
60 msan_unpoison(params->out, params->nelems * sizeof(float16_t));
61 }
62
63private:
64 int simd_w_;
65
66 Xbyak::Opmask ktail_mask = k2;
67 Xbyak::Zmm fp32_inp = Xbyak::Zmm(0);
68 Xbyak::Zmm fp32_tmp = Xbyak::Zmm(1);
69
70 Xbyak::Zmm one = Xbyak::Zmm(2);
71 Xbyak::Zmm even = Xbyak::Zmm(3);
72 Xbyak::Zmm selector = Xbyak::Zmm(4);
73 Xbyak::Reg64 scratch = r15;
74
75 Xbyak::Ymm f16_out = Xbyak::Ymm(5);
76
77 Xbyak::Reg64 reg_inp = rax;
78 Xbyak::Reg64 reg_out = rbx;
79 Xbyak::Reg64 reg_add = r11;
80 Xbyak::Reg64 reg_nelems = rdx;
81
82 Xbyak::Reg64 reg64_tail = rcx;
83 Xbyak::Reg32 reg32_tail = ecx;
84 Xbyak::Reg8 reg8_mask_shift = cl;
85 Xbyak::Reg32 reg32_mask = r8d;
86};
87
88} // namespace x64
89} // namespace cpu
90} // namespace impl
91} // namespace dnnl
92
93#endif
94