/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/float16.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_avx512_core_fp16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

#define GET_OFF(field) offsetof(f16_support::jit_call_t, field)
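
// Overview: this generator emits an AVX-512 kernel that adds two f32 arrays
// elementwise and converts the sums to f16. The semantics are equivalent to
// the scalar loop below (a sketch for documentation, not part of the build;
// float16_t conversion follows common/float16.hpp):
//
//     for (size_t i = 0; i < p->nelems; ++i)
//         p->out[i] = float16_t(p->inp[i] + p->add[i]);
//
// where `p` is the f16_support::jit_call_t passed to the kernel.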

void jit_avx512_core_fp16_add_cvt_ps_to_f16_t::generate() {
    preamble();

    // Load simd_w_ f32 elements from each input (masked, zeroing inactive
    // lanes), add them, convert the f32 sums to f16 with rounding taken from
    // MXCSR (_op_mxcsr sets imm8 bit 2 of vcvtps2ph), and store the f16
    // results under the same mask.
    auto add_cvt = [&](size_t idx, Xbyak::Opmask ktail_mask) {
        vmovups(fp32_inp | ktail_mask | T_z,
                ptr[reg_inp + sizeof(float) * (idx)]);
        vaddps(fp32_inp | ktail_mask | T_z, fp32_inp,
                ptr[reg_add + sizeof(float) * (idx)]);

        vcvtps2ph(f16_out, fp32_inp, _op_mxcsr);

        vmovdqu16(yword[reg_out + sizeof(float16_t) * (idx)] | ktail_mask,
                f16_out);
    };

    // Load the kernel arguments from the jit_call_t passed in abi_param1.
    mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
    mov(reg_add, ptr[abi_param1 + GET_OFF(add)]);
    mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
    mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]);

    // Start with a full mask: all 16 (simd_w_) f32 lanes of a zmm active.
    mov(reg32_tail, 0xffff);
    kmovw(ktail_mask, reg32_tail);

    constexpr int n_unroll = 2; // unroll by powers of 2, from 2^n_unroll down to 2^0
    Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
    for (int i = n_unroll; i >= 0; i--) {
        const int unroll = 1 << i; // unroll factor: 4, 2, then 1
        L(l_simd_loop[i + 1]);
        {
            // While at least simd_w_ * unroll elements remain, process them
            // in full-width chunks; otherwise fall through to the loop with
            // the next smaller unroll factor.
            cmp(reg_nelems, simd_w_ * unroll);
            jl(l_simd_loop[i], T_NEAR);
            for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
                add_cvt(j, ktail_mask);
            }
            add(reg_inp, simd_w_ * unroll * sizeof(float));
            add(reg_add, simd_w_ * unroll * sizeof(float));
            add(reg_out, simd_w_ * unroll * sizeof(float16_t));

            sub(reg_nelems, simd_w_ * unroll);
            jmp(l_simd_loop[i + 1], T_NEAR);
        }
    }
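    // Tail: at this point 0 <= reg_nelems < simd_w_, so at most one masked
    // vector iteration remains.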
    L(l_simd_loop[0]);
    test(reg_nelems, reg_nelems);
    jz(l_simd_notail);
    // JIT of `tail_mask_ = (1 << (nelems_ % simd_w_)) - 1;`; the modulo is
    // implicit here since reg_nelems is already < simd_w_.
    mov(reg32_mask, 1);
    // Copy the remaining element count into reg64_tail so that its low byte
    // (reg8_mask_shift) can serve as the variable shift count.
    mov(reg64_tail, reg_nelems);
    shl(reg32_mask, reg8_mask_shift);
    sub(reg32_mask, 1);
    kmovd(ktail_mask, reg32_mask);
    add_cvt(0, ktail_mask);
    L(l_simd_notail);

    postamble();
}
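
// Usage sketch (hypothetical, for illustration only): following the pattern
// of the analogous bf16 cvt kernels, a caller would fill an
// f16_support::jit_call_t and invoke the generated code, roughly:
//
//     jit_avx512_core_fp16_add_cvt_ps_to_f16_t cvt;
//     cvt.create_kernel(); // standard jit_generator entry point
//     f16_support::jit_call_t p;
//     p.inp = src;   // const float *
//     p.add = acc;   // const float *
//     p.out = dst;   // float16_t *
//     p.nelems = n;
//     cvt(&p);       // assumes an operator()(jit_call_t *) wrapper
//
// The exact invocation API is declared in jit_avx512_core_fp16cvt.hpp.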
#undef GET_OFF

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl