1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_UTILS_JIT_IO_HELPER_HPP
18#define CPU_X64_UTILS_JIT_IO_HELPER_HPP
19
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "common/optional.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_generator.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32namespace x64 {
33
34struct bf16_emulation_t;
35
36namespace io {
37
// General I/O options handed to jit_io_helper_t (see its constructor, which
// takes a `const io_conf_t &`).
class io_conf_t {
public:
    // Leaves nt_stores_enabled_ at its in-class default (false).
    io_conf_t() = default;
    // Defined out of line; presumably stores the argument in
    // nt_stores_enabled_ - confirm against the implementation.
    // NOTE(review): not `explicit`, so a bool converts implicitly to
    // io_conf_t. Left untouched because callers may rely on it.
    io_conf_t(bool nt_stores_enabled);
    io_conf_t(const io_conf_t &) = default;

    io_conf_t &operator=(const io_conf_t &) = default;

    // Judging by the name, requests non-temporal stores; how the flag is
    // consumed is not visible in this header.
    bool nt_stores_enabled_ = false;
};
48
49class io_tail_conf_t {
50public:
51 io_tail_conf_t(const std::size_t simd_w, const std::size_t tail_size,
52 const Xbyak::Opmask &tail_opmask, const int tail_vmm_mask_idx,
53 const Xbyak::Reg64 &reg_tmp);
54 io_tail_conf_t(const std::size_t simd_w, const std::size_t tail_size,
55 int tail_opmask_idx, const int tail_vmm_mask_idx,
56 const Xbyak::Reg64 &reg_tmp);
57 io_tail_conf_t(const io_tail_conf_t &other) = default;
58
59 io_tail_conf_t &operator=(const io_tail_conf_t &other) = default;
60
61 std::size_t simd_w_ = 0;
62 std::size_t tail_size_ = 0;
63 Xbyak::Opmask tail_opmask_ = Xbyak::Opmask();
64 int tail_vmm_mask_idx_ = 0;
65 Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64();
66};
67
68class io_emu_bf16_conf_t {
69public:
70 io_emu_bf16_conf_t() = default;
71 io_emu_bf16_conf_t(const Xbyak::Zmm &bf16_emu_reserv_1,
72 const Xbyak::Zmm &bf16_emu_reserv_2,
73 const Xbyak::Zmm &bf16_emu_reserv_3, const Xbyak::Reg64 &reg_tmp,
74 const Xbyak::Zmm &bf16_emu_reserv_4);
75 io_emu_bf16_conf_t(int bf16_emu_reserv_1_idx, int bf16_emu_reserv_2_idx,
76 int bf16_emu_reserv_3_idx, const Xbyak::Reg64 &reg_tmp,
77 int bf16_emu_reserv_4_idx);
78 io_emu_bf16_conf_t(const io_emu_bf16_conf_t &other) = default;
79
80 io_emu_bf16_conf_t &operator=(const io_emu_bf16_conf_t &other) = default;
81
82 Xbyak::Zmm bf16_emu_reserv_1_ = Xbyak::Zmm(28);
83 Xbyak::Zmm bf16_emu_reserv_2_ = Xbyak::Zmm(29);
84 Xbyak::Zmm bf16_emu_reserv_3_ = Xbyak::Zmm(30);
85 Xbyak::Reg64 reg_tmp_ = Xbyak::util::rax;
86 Xbyak::Zmm bf16_emu_reserv_4_ = Xbyak::Zmm(31);
87};
88
89class io_saturation_conf_t {
90public:
91 io_saturation_conf_t(const int vreg_zero_saturation_idx,
92 const int vreg_saturation_ubound_idx, const Xbyak::Reg64 &reg_tmp);
93 io_saturation_conf_t(const io_saturation_conf_t &other) = default;
94
95 io_saturation_conf_t &operator=(const io_saturation_conf_t &other)
96 = default;
97
98 int vreg_zero_saturation_idx_ = 0;
99 int vreg_saturation_ubound_idx_ = 0;
100 Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64();
101};
102
103class io_gather_conf_t {
104public:
105 io_gather_conf_t(const std::size_t simd_w, const Xbyak::Opmask &full_opmask,
106 const int full_vmm_mask_idx, const Xbyak::Reg64 &reg_tmp,
107 const Xbyak::Reg64 &reg_tmp1,
108 const utils::optional_t<int> &vmm_tmp_idx = utils::nullopt);
109 io_gather_conf_t(const io_gather_conf_t &other) = default;
110
111 io_gather_conf_t &operator=(const io_gather_conf_t &other) = default;
112
113 std::size_t simd_w_ = 0;
114 Xbyak::Opmask full_opmask_ = Xbyak::Opmask();
115 int full_vmm_mask_idx_ = 0;
116 Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64();
117 Xbyak::Reg64 reg_tmp1_ = Xbyak::Reg64();
118 // It is needed, when io_helper use emulation for gather
119 // and it is not needed for sse.
120 utils::optional_t<int> vmm_tmp_idx_ = utils::nullopt;
121};
122
123template <typename Vmm>
124class jit_io_multi_dt_helper_t;
125
126template <typename Vmm>
127class jit_io_helper_t {
128public:
129 friend class jit_io_multi_dt_helper_t<Vmm>;
130
131 jit_io_helper_t() = default;
132 jit_io_helper_t(jit_generator *host, const cpu_isa_t &isa,
133 const data_type_t &data_type, const io_conf_t &io_conf,
134 const utils::optional_t<io_tail_conf_t> &tail_conf = utils::nullopt,
135 const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf
136 = utils::nullopt,
137 const utils::optional_t<io_saturation_conf_t> &saturation_conf
138 = utils::nullopt,
139 const utils::optional_t<io_gather_conf_t> &gather_conf
140 = utils::nullopt);
141 jit_io_helper_t(jit_io_helper_t &&) = default;
142 jit_io_helper_t &operator=(jit_io_helper_t &&) = default;
143
144 ~jit_io_helper_t();
145 void prepare_tail_mask();
146 void prepare_full_mask();
147 /*
148 * Sometimes the values in the register can be nan at the
149 * beginning of the kernel, then using vcmpps(vmm, vmm, vmm)
150 * will not set all bits to 1, instead of that this instruction will
151 * return zero. At the beginning, it is worth to zeroing
152 * full mask vmm to be sure, that vcmpps work properly.
153 */
154 void init_full_mask();
155 void init_saturate_f32() const;
156 void init_bf16();
157 void gather(const Xbyak::Reg64 &src_reg, const Vmm &indices_vmm,
158 const Vmm &dst_vmm, const bool tail);
159 void broadcast(const Xbyak::Address &src_addr, const Vmm &dst_vmm);
160 void load(const Xbyak::Address &src_addr, const Vmm &dst_vmm,
161 const bool tail);
162 void store(const Vmm &src_vmm, const Xbyak::Address &dst_addr,
163 const bool tail);
164 // Load odd and even data into two registers with ne_convert instructions
165 // Then we can call merge_interleaved_to_plain() to merge them into
166 // plain layouts when needed
167 void load_two_simdw_xf16(const Xbyak::Address &src_addr,
168 const Vmm &dst_even_vmm, const Vmm &dst_odd_vmm);
169 void merge_interleaved_to_plain(
170 const Vmm &vmm_even, const Vmm &vmm_odd, const Vmm &vmm_aux0);
171
172private:
173 bool is_data_type_supported(const data_type_t dt);
174 void prepare_opmask(const std::size_t how_many_bits_to_set,
175 const Xbyak::Reg64 &reg_tmp, const Xbyak::Opmask &mask);
176 void prepare_vmm_mask(const std::size_t how_many_bits_to_set,
177 const std::size_t simd_w, const Xbyak::Reg64 &reg_tmp,
178 const Vmm &mask);
179 void prepare_i8_data_to_store(const Vmm &i8_vmm);
180 void prepare_xf16_data_to_store(const Vmm &vmm);
181 // Emulates the behavior of vgatherdps for architectures
182 // that do not support this instruction.
183 void emu_gather(const Xbyak::Reg64 &src_reg, const Vmm &indices_vmm,
184 const Vmm &dst_vmm, const bool tail);
185 void load_byte_by_byte(const Xbyak::Address &src_addr, const Vmm &dst_vmm,
186 const int load_size);
187 void load_f32(const Xbyak::Address &src_addr, const Vmm &dst_vmm,
188 const bool tail);
189 void load_s32(const Xbyak::Address &src_addr, const Vmm &dst_vmm,
190 const bool tail);
191 void load_bf16(const Xbyak::Address &src_addr, const Vmm &dst_vmm);
192 void load_f16(const Xbyak::Address &src_addr, const Vmm &dst_vmm);
193 void load_i8(const Xbyak::Address &src_addr, const Vmm &dst_vmm);
194 void saturate(const Vmm &vmm);
195 void store_byte_by_byte(const Vmm &src_vmm, const Xbyak::Address &dst_addr,
196 const int store_size);
197 void store_f32(const Vmm &src_vmm, const Xbyak::Address &dst_addr,
198 const bool tail);
199 void store_bf16(const Vmm &src_vmm, const Xbyak::Address &dst_addr);
200 void store_f16(const Vmm &src_vmm, const Xbyak::Address &dst_addr);
201 void store_i8(const Vmm &src_vmm, const Xbyak::Address &dst_addr);
202 void convert_to_f32(const Vmm &dst_vmm, const Xbyak::Xmm &src_vmm,
203 const data_type_t src_data_type);
204
205 jit_generator *host_;
206 const cpu_isa_t isa_;
207 const data_type_t data_type_;
208 const bool bf16_supported_;
209 const bool f16_supported_;
210 std::unique_ptr<bf16_emulation_t> bf16_emu_;
211 const io_conf_t io_conf_;
212 const utils::optional_t<io_tail_conf_t> tail_conf_;
213 const utils::optional_t<io_emu_bf16_conf_t> bf16_conf_;
214 const utils::optional_t<io_saturation_conf_t> saturation_conf_;
215 const utils::optional_t<io_gather_conf_t> gather_conf_;
216};
217
218template <typename Vmm>
219class jit_io_multi_dt_helper_t {
220public:
221 using data_types_t = std::unordered_set<data_type_t, std::hash<int>>;
222 using saturation_map_t = std::map<data_type_t, io_saturation_conf_t>;
223
224 jit_io_multi_dt_helper_t() = default;
225 jit_io_multi_dt_helper_t(jit_generator *host, const cpu_isa_t &isa,
226 const data_types_t &data_types, const io_conf_t &io_conf,
227 const utils::optional_t<io_tail_conf_t> &tail_conf = utils::nullopt,
228 const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf
229 = utils::nullopt,
230 const saturation_map_t &saturation_confs = saturation_map_t {},
231 const utils::optional_t<io_gather_conf_t> &gather_conf
232 = utils::nullopt);
233 ~jit_io_multi_dt_helper_t();
234 void prepare_tail_mask();
235 void prepare_full_mask();
236 void init_saturate_f32(const data_types_t &store_data_types);
237 void init_full_mask();
238 void init_bf16();
239
240 std::shared_ptr<jit_io_helper_t<Vmm>> at(const data_type_t dt) const;
241 std::shared_ptr<jit_io_helper_t<Vmm>> operator[](
242 const data_type_t dt) const;
243
244private:
245 std::unordered_map<data_type_t, std::shared_ptr<jit_io_helper_t<Vmm>>,
246 std::hash<int>>
247 storage_;
248};
249
250} // namespace io
251} // namespace x64
252} // namespace cpu
253} // namespace impl
254} // namespace dnnl
255
256#endif
257