1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_SHUFFLE_KERNEL_HPP
18#define CPU_X64_JIT_UNI_SHUFFLE_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/type_helpers.hpp"
22#include "common/utils.hpp"
23
24#include "cpu/cpu_shuffle_pd.hpp"
25
26#include "cpu/x64/cpu_isa_traits.hpp"
27#include "cpu/x64/jit_generator.hpp"
28#include "cpu/x64/jit_primitive_conf.hpp"
29#include "cpu/x64/shuffle/jit_uni_shuffle.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35
36using namespace Xbyak;
37
38template <cpu_isa_t isa>
39struct jit_uni_shuffle_kernel_t : public jit_generator {
40 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_shuffle_kernel_t)
41
42 jit_uni_shuffle_kernel_t(const jit_shuffle_conf_t conf);
43
44 using Vmm = typename cpu_isa_traits<isa>::Vmm;
45
46 constexpr int vmm_idx(int idx) const {
47 return (cpu_isa_traits<isa>::n_vregs - 1) - idx;
48 }
49
50 /*
51 * Prepare the mask to be used during tail processing.
52 * vmm_tail_mask_ is filled if it is avx and
53 * if it is avx512_core at least then k_tail_mask_ is filled.
54 */
55 void prepare_mask();
56
57 /*
58 * Emulates the behavior of vgatherdps for architectures
59 * that do not support this instruction.
60 */
61 void emu_gather_data(const Reg64 &reg_src_addr, const int indices_idx,
62 const int data_idx, const bool is_tail = false);
63
64 void gather_data(const Reg64 &reg_src_addr, const int indices_idx,
65 const int data_idx, const bool is_tail = false);
66
67 void store_data(const int data_idx, const Reg64 &reg_dst_addr,
68 const int offset = 0, const bool is_tail = false);
69
70 void shuffle_blocked_format();
71
72 void append_zero_padding(
73 const Reg64 &reg_dst_addr, const bool zero_extend_write);
74
75 void generate() override;
76
77 const Vmm vmm_tail_mask_ = Vmm(0);
78 // Used only for avx
79 // Vgatherdps always gets data using a conditional mask
80 // This register contains all bits set to 1, allowing
81 // to get the maximum number of values available to the register
82 const Vmm vmm_full_mask_ = Vmm(1);
83 const Vmm vmm_src_ = Vmm(2);
84 const Vmm vmm_tmp_ = Vmm(3);
85 const Vmm vmm_indices_ = Vmm(4);
86 const Vmm vmm_zero_ = Vmm(11);
87
88 const Opmask k_tail_mask_ = k1;
89 const Opmask k_full_mask_ = k2;
90
91 const Reg64 &reg_tmp_ = rax;
92 const Reg64 &reg_dst_ = rbx;
93 const Reg64 &reg_indices_ = rcx;
94 const Reg64 &reg_work_ = rdx;
95 // Always mimic the Unix ABI
96 const Reg64 &reg_param = rdi;
97 const Reg64 &reg_src_ = rsi;
98 const Reg64 &reg_tmp1_ = r8;
99 const Reg64 &reg_tmp2_ = r9;
100 const Reg64 &reg_tmp3_ = r10;
101 const Reg64 &reg_tmp4_ = r11;
102 const Reg64 &reg_tmp5_ = r12;
103 const Reg64 &reg_tmp6_ = r13;
104 const Reg8 &reg_padded_block = r14b;
105
106 const jit_shuffle_conf_t conf_;
107 const size_t padding_size_;
108};
109
110} // namespace x64
111} // namespace cpu
112} // namespace impl
113} // namespace dnnl
114
115#endif
116
117// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
118