1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * Copyright 2018 YANDEX LLC |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #ifndef CPU_X64_JIT_UNI_POOL_KERNEL_HPP |
19 | #define CPU_X64_JIT_UNI_POOL_KERNEL_HPP |
20 | |
21 | #include <cfloat> |
22 | #include <functional> |
23 | #include <memory> |
24 | |
25 | #include "common/memory_tracking.hpp" |
26 | |
27 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
28 | #include "cpu/x64/jit_generator.hpp" |
29 | #include "cpu/x64/jit_primitive_conf.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
// Forward declaration only: this header holds the type solely through a
// std::unique_ptr member of jit_uni_pool_kernel, so the full definition is
// not required here.
struct bf16_emulation_t;
37 | |
38 | template <cpu_isa_t isa> |
39 | struct jit_uni_pool_kernel : public jit_generator { |
40 | |
41 | jit_uni_pool_kernel( |
42 | const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md); |
43 | jit_pool_conf_t jpp; |
44 | ~jit_uni_pool_kernel(); |
45 | |
46 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel) |
47 | |
48 | static status_t init_conf(jit_pool_conf_t &jbp, |
49 | memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr, |
50 | const pooling_pd_t *ppd); |
51 | |
52 | private: |
53 | using Xmm = Xbyak::Xmm; |
54 | using Ymm = Xbyak::Ymm; |
55 | using Zmm = Xbyak::Zmm; |
56 | using Opmask = Xbyak::Opmask; |
57 | using Reg32 = Xbyak::Reg32; |
58 | using Reg64 = Xbyak::Reg64; |
59 | |
60 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
61 | |
62 | int vmm_idx_upper_bound() const noexcept { |
63 | return is_superset(isa, avx512_core) ? 31 : 15; |
64 | } |
65 | |
66 | int reg_idx(int idx) const noexcept { return vmm_idx_upper_bound() - idx; } |
67 | |
68 | Xmm xreg(int idx) const noexcept { return Xmm(reg_idx(idx)); } |
69 | Ymm yreg(int idx) const noexcept { return Ymm(reg_idx(idx)); } |
70 | Zmm zreg(int idx) const noexcept { return Zmm(reg_idx(idx)); } |
71 | Vmm vreg(int idx) const noexcept { return Vmm(reg_idx(idx)); } |
72 | |
73 | const Xbyak::AddressFrame &vmmword = (isa == sse41) |
74 | ? xword |
75 | : utils::one_of(isa, avx, avx2, avx2_vnni_2) ? yword : zword; |
76 | |
77 | Xmm vmm_mask = Xmm(0); |
78 | Xmm xmm_tmp_1 = Xmm(0); |
79 | Ymm ymm_tmp_1 = Ymm(0); |
80 | Vmm vmm_tmp_1 = Vmm(0); |
81 | |
82 | // Used only for avx and if c tail is present |
83 | Vmm vmm_c_tail_mask = Vmm(2); |
84 | Xmm xmm_c_tail_mask = Xmm(2); |
85 | |
86 | Xmm xmm_tmp = Xmm(3); |
87 | |
88 | Vmm vmm_ker_area_h = Vmm(2); |
89 | Vmm vmm_one = Vmm(2); |
90 | Vmm vmm_tmp = Vmm(3); |
91 | Ymm ymm_tmp = Ymm(3); |
92 | |
93 | Vmm vmm_k_offset = Vmm(1); |
94 | |
95 | // Used only for avx512 when bf16 is present |
96 | inline Vmm vmm_idx() { |
97 | if (!jpp.is_backward) { |
98 | return (jpp.is_training) ? Vmm(4) : Vmm(1); |
99 | } else |
100 | return Vmm(4); |
101 | } |
102 | |
103 | Zmm bf16_emu_reserv_1 = Zmm(5); |
104 | Zmm bf16_emu_reserv_2 = Zmm(6); |
105 | Zmm bf16_emu_reserv_3 = Zmm(7); |
106 | Reg64 bf16_emu_reserv_4 = r11; |
107 | Zmm bf16_emu_reserv_5 = Zmm(8); |
108 | |
109 | Opmask k_c_tail_mask = Opmask(4); |
110 | Opmask k_mask_cvt = Opmask(5); |
111 | Opmask k_store_mask = Opmask(6); |
112 | |
113 | // Here be some (tame) dragons. This kernel does not follow the regular |
114 | // OS-agnostic ABI pattern because when isa is sse41 it uses maskmovdqu |
115 | // instruction which has its destination hardcoded in rdi. Therefore: |
116 | // - all registers are hardcoded |
117 | // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI |
118 | // |
119 | // While this is only required by the backward pass, the quirk above |
120 | // is applied to the forward pass as well to keep things simpler. |
121 | |
122 | using reg64_t = const Reg64; |
123 | reg64_t reg_param = rdi; // Always mimic the Unix ABI |
124 | reg64_t reg_input = r8; |
125 | reg64_t aux_reg_input = r9; |
126 | reg64_t reg_index = r10; |
127 | reg64_t reg_output = r12; |
128 | reg64_t reg_kd_pad_shift = r13; |
129 | reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu |
130 | |
131 | reg64_t kj = r14; |
132 | reg64_t oi_iter = r15; |
133 | reg64_t reg_kh = rax; |
134 | reg64_t reg_k_shift = rbx; |
135 | reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above |
136 | reg64_t reg_ker_area_h = rdx; |
137 | reg64_t reg_nbc = rsi; |
138 | |
139 | reg64_t reg_zero_ptr = r9; |
140 | reg64_t reg_zero_id = r13; |
141 | reg64_t reg_zero_ih = r14; |
142 | reg64_t aux_reg_zero_ih = r15; |
143 | reg64_t ki = r12; |
144 | reg64_t aux_reg_input_d = r8; |
145 | |
146 | Reg32 reg_shuf_mask = esi; |
147 | |
148 | bool sse_high_half = false; |
149 | bool disable_postops_when_sse_high_half_processed_ = false; |
150 | |
151 | int prev_kw; |
152 | |
153 | void prepare_tail_mask(); |
154 | void put_one_in_vmm(); |
155 | void uni_broadcast_reg_val(const int reg_idx, const int vmm_idx); |
156 | void push_vmm_val(const int idx); |
157 | void pop_vmm_val(const int idx); |
158 | void load(const int idx, const reg64_t ®_ptr, const int offset, |
159 | const bool is_c_tail_proccessing); |
160 | void store(const int idx, const reg64_t ®_ptr, const int offset, |
161 | const bool is_c_tail_proccessing); |
162 | |
163 | void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r, |
164 | bool with_c_tail_proccessing); |
165 | void avg_step(int ur_w, int ur_bc, int pad_l, int pad_r, |
166 | bool with_c_tail_proccessing); |
167 | void max_step_fwd(int ur_w, int ur_bc, int pad_l, int pad_r, |
168 | bool with_c_tail_proccessing); |
169 | void max_step_bwd(int ur_w, int ur_bc, int pad_l, int pad_r, |
170 | bool with_c_tail_proccessing); |
171 | |
172 | void zero_diff_src(int ur_bc, bool with_c_tail_proccessing); |
173 | |
174 | void step(int ur_w, int ur_bc, int pad_l, int pad_r, |
175 | bool with_c_tail_proccessing) { |
176 | if (jpp.alg == alg_kind::pooling_max) { |
177 | if (jpp.is_backward) |
178 | max_step_bwd( |
179 | ur_w, ur_bc, pad_l, pad_r, with_c_tail_proccessing); |
180 | else |
181 | max_step_fwd( |
182 | ur_w, ur_bc, pad_l, pad_r, with_c_tail_proccessing); |
183 | } else |
184 | avg_step(ur_w, ur_bc, pad_l, pad_r, with_c_tail_proccessing); |
185 | } |
186 | |
187 | void step_high_half(int ur_w, int ur_bc, int pad_l, int pad_r, |
188 | bool with_c_tail_processing) { |
189 | add(reg_input, sizeof(float) * 4); |
190 | add(reg_output, sizeof(float) * 4); |
191 | if (jpp.alg == alg_kind::pooling_max |
192 | && (jpp.is_training || jpp.is_backward)) |
193 | add(reg_index, types::data_type_size(jpp.ind_dt) * 4); |
194 | |
195 | step(ur_w, ur_bc, pad_l, pad_r, with_c_tail_processing); |
196 | } |
197 | |
198 | void generate() override; |
199 | |
200 | void avx_vpadd1(const Ymm &y0, const Xmm &x1, const Xmm &xtmp) { |
201 | assert(y0.getIdx() != x1.getIdx()); |
202 | vextractf128(xtmp, y0, 0); |
203 | vpaddd(xtmp, xtmp, x1); |
204 | vinsertf128(y0, y0, xtmp, 0); |
205 | vextractf128(xtmp, y0, 1); |
206 | vpaddd(xtmp, xtmp, x1); |
207 | vinsertf128(y0, y0, xtmp, 1); |
208 | } |
209 | |
210 | void avx_vpadd1(const Xmm &x0, const Xmm &x1, const Xmm &) { |
211 | assert(false /*function should not be used*/); |
212 | paddd(x0, x1); |
213 | } |
214 | |
215 | void avx_pmovzxbd(const Ymm &y0, const Xmm &x1, const Xmm &xtmp) { |
216 | Xmm x0(y0.getIdx()); |
217 | pshufd(xmm_tmp, x1, 1); |
218 | pmovzxbd(x0, x1); |
219 | pmovzxbd(xmm_tmp, xmm_tmp); |
220 | vinsertf128(y0, y0, xmm_tmp, 1); |
221 | } |
222 | |
223 | void avx_pmovzxbd(const Xmm &x0, const Xmm &x1, const Xmm &) { |
224 | assert(false /*function should not be used*/); |
225 | pmovzxbd(x0, x1); |
226 | } |
227 | |
228 | void avx_pcmpeqd( |
229 | const Ymm &y0, const Ymm &y1, const Ymm &y2, const Xmm &xtmp) { |
230 | assert(y0.getIdx() != y1.getIdx()); |
231 | assert(y0.getIdx() != y2.getIdx()); |
232 | Xmm x0(y0.getIdx()); |
233 | Xmm x2(y2.getIdx()); |
234 | vextractf128(x0, y1, 1); |
235 | vextractf128(xtmp, y2, 1); |
236 | pcmpeqd(xtmp, x0); |
237 | vextractf128(x0, y1, 0); |
238 | pcmpeqd(x0, x2); |
239 | vinsertf128(y0, y0, xtmp, 1); |
240 | } |
241 | |
242 | void avx_pcmpeqd(const Xmm &x0, const Xmm &x1, const Xmm &, const Xmm &) { |
243 | assert(false /*function should not be used*/); |
244 | pcmpeqd(x0, x1); |
245 | } |
246 | |
247 | void apply_postops(int ur_bc, int ur_w, int c_block, |
248 | const std::function<bool(int, bool)> &is_tail_predicate); |
249 | |
250 | static bool post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr, |
251 | const memory_desc_wrapper &dst_d); |
252 | |
253 | inline bool use_bf16_emulation() const { |
254 | return jpp.is_bf16 && !isa_has_bf16(jpp.isa) && isa != avx2_vnni_2; |
255 | } |
256 | |
257 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
258 | std::unique_ptr<injector::jit_uni_postops_injector_t<isa>> |
259 | postops_injector_; |
260 | }; |
261 | |
262 | } // namespace x64 |
263 | } // namespace cpu |
264 | } // namespace impl |
265 | } // namespace dnnl |
266 | |
267 | #endif |
268 | |
269 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
270 | |