1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_UTILS_JIT_IO_HELPER_HPP |
18 | #define CPU_X64_UTILS_JIT_IO_HELPER_HPP |
19 | |
#include <cstddef>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "common/optional.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_generator.hpp"
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | namespace x64 { |
33 | |
// Forward declaration only: within this header the type is used solely as
// the pointee of std::unique_ptr (jit_io_helper_t::bf16_emu_).
struct bf16_emulation_t;
35 | |
36 | namespace io { |
37 | |
// General I/O settings consumed by the helpers declared in this header.
// Carries a single flag, nt_stores_enabled_, which is false unless changed.
// NOTE(review): "nt" presumably stands for non-temporal stores -- confirm
// against the implementation file.
class io_conf_t {
public:
    io_conf_t() = default;
    io_conf_t(const io_conf_t &other) = default;
    // Defined out of line. Deliberately left non-explicit so that existing
    // callers relying on the bool -> io_conf_t conversion keep compiling.
    io_conf_t(bool nt_stores_enabled);

    io_conf_t &operator=(const io_conf_t &other) = default;

    bool nt_stores_enabled_ {false};
};
48 | |
49 | class io_tail_conf_t { |
50 | public: |
51 | io_tail_conf_t(const std::size_t simd_w, const std::size_t tail_size, |
52 | const Xbyak::Opmask &tail_opmask, const int tail_vmm_mask_idx, |
53 | const Xbyak::Reg64 ®_tmp); |
54 | io_tail_conf_t(const std::size_t simd_w, const std::size_t tail_size, |
55 | int tail_opmask_idx, const int tail_vmm_mask_idx, |
56 | const Xbyak::Reg64 ®_tmp); |
57 | io_tail_conf_t(const io_tail_conf_t &other) = default; |
58 | |
59 | io_tail_conf_t &operator=(const io_tail_conf_t &other) = default; |
60 | |
61 | std::size_t simd_w_ = 0; |
62 | std::size_t tail_size_ = 0; |
63 | Xbyak::Opmask tail_opmask_ = Xbyak::Opmask(); |
64 | int tail_vmm_mask_idx_ = 0; |
65 | Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64(); |
66 | }; |
67 | |
68 | class io_emu_bf16_conf_t { |
69 | public: |
70 | io_emu_bf16_conf_t() = default; |
71 | io_emu_bf16_conf_t(const Xbyak::Zmm &bf16_emu_reserv_1, |
72 | const Xbyak::Zmm &bf16_emu_reserv_2, |
73 | const Xbyak::Zmm &bf16_emu_reserv_3, const Xbyak::Reg64 ®_tmp, |
74 | const Xbyak::Zmm &bf16_emu_reserv_4); |
75 | io_emu_bf16_conf_t(int bf16_emu_reserv_1_idx, int bf16_emu_reserv_2_idx, |
76 | int bf16_emu_reserv_3_idx, const Xbyak::Reg64 ®_tmp, |
77 | int bf16_emu_reserv_4_idx); |
78 | io_emu_bf16_conf_t(const io_emu_bf16_conf_t &other) = default; |
79 | |
80 | io_emu_bf16_conf_t &operator=(const io_emu_bf16_conf_t &other) = default; |
81 | |
82 | Xbyak::Zmm bf16_emu_reserv_1_ = Xbyak::Zmm(28); |
83 | Xbyak::Zmm bf16_emu_reserv_2_ = Xbyak::Zmm(29); |
84 | Xbyak::Zmm bf16_emu_reserv_3_ = Xbyak::Zmm(30); |
85 | Xbyak::Reg64 reg_tmp_ = Xbyak::util::rax; |
86 | Xbyak::Zmm bf16_emu_reserv_4_ = Xbyak::Zmm(31); |
87 | }; |
88 | |
89 | class io_saturation_conf_t { |
90 | public: |
91 | io_saturation_conf_t(const int vreg_zero_saturation_idx, |
92 | const int vreg_saturation_ubound_idx, const Xbyak::Reg64 ®_tmp); |
93 | io_saturation_conf_t(const io_saturation_conf_t &other) = default; |
94 | |
95 | io_saturation_conf_t &operator=(const io_saturation_conf_t &other) |
96 | = default; |
97 | |
98 | int vreg_zero_saturation_idx_ = 0; |
99 | int vreg_saturation_ubound_idx_ = 0; |
100 | Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64(); |
101 | }; |
102 | |
103 | class io_gather_conf_t { |
104 | public: |
105 | io_gather_conf_t(const std::size_t simd_w, const Xbyak::Opmask &full_opmask, |
106 | const int full_vmm_mask_idx, const Xbyak::Reg64 ®_tmp, |
107 | const Xbyak::Reg64 ®_tmp1, |
108 | const utils::optional_t<int> &vmm_tmp_idx = utils::nullopt); |
109 | io_gather_conf_t(const io_gather_conf_t &other) = default; |
110 | |
111 | io_gather_conf_t &operator=(const io_gather_conf_t &other) = default; |
112 | |
113 | std::size_t simd_w_ = 0; |
114 | Xbyak::Opmask full_opmask_ = Xbyak::Opmask(); |
115 | int full_vmm_mask_idx_ = 0; |
116 | Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64(); |
117 | Xbyak::Reg64 reg_tmp1_ = Xbyak::Reg64(); |
118 | // It is needed, when io_helper use emulation for gather |
119 | // and it is not needed for sse. |
120 | utils::optional_t<int> vmm_tmp_idx_ = utils::nullopt; |
121 | }; |
122 | |
123 | template <typename Vmm> |
124 | class jit_io_multi_dt_helper_t; |
125 | |
126 | template <typename Vmm> |
127 | class jit_io_helper_t { |
128 | public: |
129 | friend class jit_io_multi_dt_helper_t<Vmm>; |
130 | |
131 | jit_io_helper_t() = default; |
132 | jit_io_helper_t(jit_generator *host, const cpu_isa_t &isa, |
133 | const data_type_t &data_type, const io_conf_t &io_conf, |
134 | const utils::optional_t<io_tail_conf_t> &tail_conf = utils::nullopt, |
135 | const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf |
136 | = utils::nullopt, |
137 | const utils::optional_t<io_saturation_conf_t> &saturation_conf |
138 | = utils::nullopt, |
139 | const utils::optional_t<io_gather_conf_t> &gather_conf |
140 | = utils::nullopt); |
141 | jit_io_helper_t(jit_io_helper_t &&) = default; |
142 | jit_io_helper_t &operator=(jit_io_helper_t &&) = default; |
143 | |
144 | ~jit_io_helper_t(); |
145 | void prepare_tail_mask(); |
146 | void prepare_full_mask(); |
147 | /* |
148 | * Sometimes the values in the register can be nan at the |
149 | * beginning of the kernel, then using vcmpps(vmm, vmm, vmm) |
150 | * will not set all bits to 1, instead of that this instruction will |
151 | * return zero. At the beginning, it is worth to zeroing |
152 | * full mask vmm to be sure, that vcmpps work properly. |
153 | */ |
154 | void init_full_mask(); |
155 | void init_saturate_f32() const; |
156 | void init_bf16(); |
157 | void gather(const Xbyak::Reg64 &src_reg, const Vmm &indices_vmm, |
158 | const Vmm &dst_vmm, const bool tail); |
159 | void broadcast(const Xbyak::Address &src_addr, const Vmm &dst_vmm); |
160 | void load(const Xbyak::Address &src_addr, const Vmm &dst_vmm, |
161 | const bool tail); |
162 | void store(const Vmm &src_vmm, const Xbyak::Address &dst_addr, |
163 | const bool tail); |
164 | // Load odd and even data into two registers with ne_convert instructions |
165 | // Then we can call merge_interleaved_to_plain() to merge them into |
166 | // plain layouts when needed |
167 | void load_two_simdw_xf16(const Xbyak::Address &src_addr, |
168 | const Vmm &dst_even_vmm, const Vmm &dst_odd_vmm); |
169 | void merge_interleaved_to_plain( |
170 | const Vmm &vmm_even, const Vmm &vmm_odd, const Vmm &vmm_aux0); |
171 | |
172 | private: |
173 | bool is_data_type_supported(const data_type_t dt); |
174 | void prepare_opmask(const std::size_t how_many_bits_to_set, |
175 | const Xbyak::Reg64 ®_tmp, const Xbyak::Opmask &mask); |
176 | void prepare_vmm_mask(const std::size_t how_many_bits_to_set, |
177 | const std::size_t simd_w, const Xbyak::Reg64 ®_tmp, |
178 | const Vmm &mask); |
179 | void prepare_i8_data_to_store(const Vmm &i8_vmm); |
180 | void prepare_xf16_data_to_store(const Vmm &vmm); |
181 | // Emulates the behavior of vgatherdps for architectures |
182 | // that do not support this instruction. |
183 | void emu_gather(const Xbyak::Reg64 &src_reg, const Vmm &indices_vmm, |
184 | const Vmm &dst_vmm, const bool tail); |
185 | void load_byte_by_byte(const Xbyak::Address &src_addr, const Vmm &dst_vmm, |
186 | const int load_size); |
187 | void load_f32(const Xbyak::Address &src_addr, const Vmm &dst_vmm, |
188 | const bool tail); |
189 | void load_s32(const Xbyak::Address &src_addr, const Vmm &dst_vmm, |
190 | const bool tail); |
191 | void load_bf16(const Xbyak::Address &src_addr, const Vmm &dst_vmm); |
192 | void load_f16(const Xbyak::Address &src_addr, const Vmm &dst_vmm); |
193 | void load_i8(const Xbyak::Address &src_addr, const Vmm &dst_vmm); |
194 | void saturate(const Vmm &vmm); |
195 | void store_byte_by_byte(const Vmm &src_vmm, const Xbyak::Address &dst_addr, |
196 | const int store_size); |
197 | void store_f32(const Vmm &src_vmm, const Xbyak::Address &dst_addr, |
198 | const bool tail); |
199 | void store_bf16(const Vmm &src_vmm, const Xbyak::Address &dst_addr); |
200 | void store_f16(const Vmm &src_vmm, const Xbyak::Address &dst_addr); |
201 | void store_i8(const Vmm &src_vmm, const Xbyak::Address &dst_addr); |
202 | void convert_to_f32(const Vmm &dst_vmm, const Xbyak::Xmm &src_vmm, |
203 | const data_type_t src_data_type); |
204 | |
205 | jit_generator *host_; |
206 | const cpu_isa_t isa_; |
207 | const data_type_t data_type_; |
208 | const bool bf16_supported_; |
209 | const bool f16_supported_; |
210 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
211 | const io_conf_t io_conf_; |
212 | const utils::optional_t<io_tail_conf_t> tail_conf_; |
213 | const utils::optional_t<io_emu_bf16_conf_t> bf16_conf_; |
214 | const utils::optional_t<io_saturation_conf_t> saturation_conf_; |
215 | const utils::optional_t<io_gather_conf_t> gather_conf_; |
216 | }; |
217 | |
218 | template <typename Vmm> |
219 | class jit_io_multi_dt_helper_t { |
220 | public: |
221 | using data_types_t = std::unordered_set<data_type_t, std::hash<int>>; |
222 | using saturation_map_t = std::map<data_type_t, io_saturation_conf_t>; |
223 | |
224 | jit_io_multi_dt_helper_t() = default; |
225 | jit_io_multi_dt_helper_t(jit_generator *host, const cpu_isa_t &isa, |
226 | const data_types_t &data_types, const io_conf_t &io_conf, |
227 | const utils::optional_t<io_tail_conf_t> &tail_conf = utils::nullopt, |
228 | const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf |
229 | = utils::nullopt, |
230 | const saturation_map_t &saturation_confs = saturation_map_t {}, |
231 | const utils::optional_t<io_gather_conf_t> &gather_conf |
232 | = utils::nullopt); |
233 | ~jit_io_multi_dt_helper_t(); |
234 | void prepare_tail_mask(); |
235 | void prepare_full_mask(); |
236 | void init_saturate_f32(const data_types_t &store_data_types); |
237 | void init_full_mask(); |
238 | void init_bf16(); |
239 | |
240 | std::shared_ptr<jit_io_helper_t<Vmm>> at(const data_type_t dt) const; |
241 | std::shared_ptr<jit_io_helper_t<Vmm>> operator[]( |
242 | const data_type_t dt) const; |
243 | |
244 | private: |
245 | std::unordered_map<data_type_t, std::shared_ptr<jit_io_helper_t<Vmm>>, |
246 | std::hash<int>> |
247 | storage_; |
248 | }; |
249 | |
250 | } // namespace io |
251 | } // namespace x64 |
252 | } // namespace cpu |
253 | } // namespace impl |
254 | } // namespace dnnl |
255 | |
256 | #endif |
257 | |