/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_UNI_POOL_KERNEL_HPP
#define CPU_X64_JIT_UNI_POOL_KERNEL_HPP

#include <cfloat>
#include <functional>
#include <memory>

#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

struct bf16_emulation_t;

template <cpu_isa_t isa>
struct jit_uni_pool_kernel : public jit_generator {

    jit_uni_pool_kernel(
            const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md);
    ~jit_uni_pool_kernel();

    jit_pool_conf_t jpp;

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel)

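    // Presumably fills the kernel configuration jpp from the pooling
    // primitive descriptor and registers the scratchpad the kernel needs;
    // see the definition for the exact supported-case checks.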
    static status_t init_conf(jit_pool_conf_t &jpp,
            memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr,
            const pooling_pd_t *ppd);

private:
    using Xmm = Xbyak::Xmm;
    using Ymm = Xbyak::Ymm;
    using Zmm = Xbyak::Zmm;
    using Opmask = Xbyak::Opmask;
    using Reg32 = Xbyak::Reg32;
    using Reg64 = Xbyak::Reg64;

    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    int vmm_idx_upper_bound() const noexcept {
        return is_superset(isa, avx512_core) ? 31 : 15;
    }

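    // Vector registers are handed out from the top of the register file down:
    // on avx512_core reg_idx(0) == 31, reg_idx(1) == 30, and so on, while on
    // the 16-register ISAs reg_idx(0) == 15. This leaves the low indices free
    // for the fixed-role registers declared below.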
    int reg_idx(int idx) const noexcept { return vmm_idx_upper_bound() - idx; }

    Xmm xreg(int idx) const noexcept { return Xmm(reg_idx(idx)); }
    Ymm yreg(int idx) const noexcept { return Ymm(reg_idx(idx)); }
    Zmm zreg(int idx) const noexcept { return Zmm(reg_idx(idx)); }
    Vmm vreg(int idx) const noexcept { return Vmm(reg_idx(idx)); }

    const Xbyak::AddressFrame &vmmword = (isa == sse41)
            ? xword
            : utils::one_of(isa, avx, avx2, avx2_vnni_2) ? yword : zword;
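    // vmmword is the address-frame matching the vector width of isa, so
    // memory operands can be spelled uniformly across ISAs; an illustrative
    // (hypothetical) use: uni_vmovups(vreg(0), vmmword[reg_input + offset]);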

    Xmm vmm_mask = Xmm(0);
    Xmm xmm_tmp_1 = Xmm(0);
    Ymm ymm_tmp_1 = Ymm(0);
    Vmm vmm_tmp_1 = Vmm(0);

    // Used only for avx and if c tail is present
    Vmm vmm_c_tail_mask = Vmm(2);
    Xmm xmm_c_tail_mask = Xmm(2);

    Xmm xmm_tmp = Xmm(3);

    Vmm vmm_ker_area_h = Vmm(2);
    Vmm vmm_one = Vmm(2);
    Vmm vmm_tmp = Vmm(3);
    Ymm ymm_tmp = Ymm(3);

    Vmm vmm_k_offset = Vmm(1);
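
    // Note: several of the names above alias one physical register (Vmm(0)
    // for vmm_mask/xmm_tmp_1/vmm_tmp_1; Vmm(2) for the c-tail mask,
    // vmm_ker_area_h and vmm_one); the aliased values are presumably never
    // live at the same time.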

    // Used only for avx512 when bf16 is present
    inline Vmm vmm_idx() {
        if (!jpp.is_backward) {
            return (jpp.is_training) ? Vmm(4) : Vmm(1);
        } else
            return Vmm(4);
    }

    Zmm bf16_emu_reserv_1 = Zmm(5);
    Zmm bf16_emu_reserv_2 = Zmm(6);
    Zmm bf16_emu_reserv_3 = Zmm(7);
    Reg64 bf16_emu_reserv_4 = r11;
    Zmm bf16_emu_reserv_5 = Zmm(8);

    Opmask k_c_tail_mask = Opmask(4);
    Opmask k_mask_cvt = Opmask(5);
    Opmask k_store_mask = Opmask(6);

    // Here be some (tame) dragons. This kernel does not follow the regular
    // OS-agnostic ABI pattern because when isa is sse41 it uses the
    // maskmovdqu instruction, which has its destination hardcoded in rdi.
    // Therefore:
    // - all registers are hardcoded
    // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI
    //
    // While this is only required by the backward pass, the quirk above is
    // applied to the forward pass as well to keep things simpler.
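    //
    // A minimal sketch of the entry sequence this implies on Windows (an
    // illustrative assumption, not a quote of generate()):
    //
    //     #ifdef _WIN32
    //     xchg(rdi, rcx); // 1st argument arrives in rcx; move it into rdi
    //     #endif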

    using reg64_t = const Reg64;
    reg64_t reg_param = rdi; // Always mimic the Unix ABI
    reg64_t reg_input = r8;
    reg64_t aux_reg_input = r9;
    reg64_t reg_index = r10;
    reg64_t reg_output = r12;
    reg64_t reg_kd_pad_shift = r13;
    reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu

    reg64_t kj = r14;
    reg64_t oi_iter = r15;
    reg64_t reg_kh = rax;
    reg64_t reg_k_shift = rbx;
    reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above
    reg64_t reg_ker_area_h = rdx;
    reg64_t reg_nbc = rsi;

    reg64_t reg_zero_ptr = r9;
    reg64_t reg_zero_id = r13;
    reg64_t reg_zero_ih = r14;
    reg64_t aux_reg_zero_ih = r15;
    reg64_t ki = r12;
    reg64_t aux_reg_input_d = r8;

    Reg32 reg_shuf_mask = esi;

    bool sse_high_half = false;
    bool disable_postops_when_sse_high_half_processed_ = false;

    int prev_kw;

    void prepare_tail_mask();
    void put_one_in_vmm();
    void uni_broadcast_reg_val(const int reg_idx, const int vmm_idx);
    void push_vmm_val(const int idx);
    void pop_vmm_val(const int idx);
    void load(const int idx, const reg64_t &reg_ptr, const int offset,
            const bool is_c_tail_processing);
    void store(const int idx, const reg64_t &reg_ptr, const int offset,
            const bool is_c_tail_processing);

    void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r,
            bool with_c_tail_processing);
    void avg_step(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_processing);
    void max_step_fwd(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_processing);
    void max_step_bwd(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_processing);

    void zero_diff_src(int ur_bc, bool with_c_tail_processing);

    void step(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_processing) {
        if (jpp.alg == alg_kind::pooling_max) {
            if (jpp.is_backward)
                max_step_bwd(
                        ur_w, ur_bc, pad_l, pad_r, with_c_tail_processing);
            else
                max_step_fwd(
                        ur_w, ur_bc, pad_l, pad_r, with_c_tail_processing);
        } else
            avg_step(ur_w, ur_bc, pad_l, pad_r, with_c_tail_processing);
    }

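    // sse41 only: advance the data pointers by one xmm (4 floats) and rerun
    // the regular step over the second (high) half of the channel block.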
    void step_high_half(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_processing) {
        add(reg_input, sizeof(float) * 4);
        add(reg_output, sizeof(float) * 4);
        if (jpp.alg == alg_kind::pooling_max
                && (jpp.is_training || jpp.is_backward))
            add(reg_index, types::data_type_size(jpp.ind_dt) * 4);

        step(ur_w, ur_bc, pad_l, pad_r, with_c_tail_processing);
    }

    void generate() override;

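    // AVX1 has no 256-bit integer vpaddd, so emulate y0 += x1 by adding x1 to
    // each 128-bit half of y0 separately.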
    void avx_vpadd1(const Ymm &y0, const Xmm &x1, const Xmm &xtmp) {
        assert(y0.getIdx() != x1.getIdx());
        vextractf128(xtmp, y0, 0);
        vpaddd(xtmp, xtmp, x1);
        vinsertf128(y0, y0, xtmp, 0);
        vextractf128(xtmp, y0, 1);
        vpaddd(xtmp, xtmp, x1);
        vinsertf128(y0, y0, xtmp, 1);
    }

    void avx_vpadd1(const Xmm &x0, const Xmm &x1, const Xmm &) {
        assert(false /*function should not be used*/);
        paddd(x0, x1);
    }

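    // AVX1 counterpart of a 256-bit pmovzxbd: zero-extends the low 8 bytes of
    // x1 into the 8 dwords of y0, one 128-bit half at a time.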
    void avx_pmovzxbd(const Ymm &y0, const Xmm &x1, const Xmm &xtmp) {
        Xmm x0(y0.getIdx());
        pshufd(xmm_tmp, x1, 1);
        pmovzxbd(x0, x1);
        pmovzxbd(xmm_tmp, xmm_tmp);
        vinsertf128(y0, y0, xmm_tmp, 1);
    }

    void avx_pmovzxbd(const Xmm &x0, const Xmm &x1, const Xmm &) {
        assert(false /*function should not be used*/);
        pmovzxbd(x0, x1);
    }

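    // AVX1 counterpart of a 256-bit pcmpeqd: compares y1 and y2 dword-wise
    // into y0, one 128-bit half at a time.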
    void avx_pcmpeqd(
            const Ymm &y0, const Ymm &y1, const Ymm &y2, const Xmm &xtmp) {
        assert(y0.getIdx() != y1.getIdx());
        assert(y0.getIdx() != y2.getIdx());
        Xmm x0(y0.getIdx());
        Xmm x2(y2.getIdx());
        vextractf128(x0, y1, 1);
        vextractf128(xtmp, y2, 1);
        pcmpeqd(xtmp, x0);
        vextractf128(x0, y1, 0);
        pcmpeqd(x0, x2);
        vinsertf128(y0, y0, xtmp, 1);
    }

    void avx_pcmpeqd(const Xmm &x0, const Xmm &x1, const Xmm &, const Xmm &) {
        assert(false /*function should not be used*/);
        pcmpeqd(x0, x1);
    }

    void apply_postops(int ur_bc, int ur_w, int c_block,
            const std::function<bool(int, bool)> &is_tail_predicate);

    static bool post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr,
            const memory_desc_wrapper &dst_d);

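    // bf16 emulation is required when the data type is bf16 but the ISA lacks
    // native bf16 instructions; avx2_vnni_2 is excluded, presumably because
    // it carries its own bf16 up-conversion instructions.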
    inline bool use_bf16_emulation() const {
        return jpp.is_bf16 && !isa_has_bf16(jpp.isa) && isa != avx2_vnni_2;
    }

    std::unique_ptr<bf16_emulation_t> bf16_emu_;
    std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
            postops_injector_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s