/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <bitset>

#include "common/dnnl_thread.hpp"

#include "cpu/cpu_pooling_pd.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_uni_pool_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;
using namespace alg_kind;

#define GET_OFF(field) offsetof(jit_pool_call_s, field)

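// GET_OFF resolves a field of the jit_pool_call_s runtime argument block to
// its byte offset, so generated code can address it relative to reg_param.
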
static bcast_set_t get_supported_bcast_strategies() {
    return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
            broadcasting_strategy_t::no_broadcast};
}

template <cpu_isa_t isa>
jit_uni_pool_kernel<isa>::~jit_uni_pool_kernel() = default;

template <cpu_isa_t isa>
jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
        const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
    , jpp(ajpp)
    , bf16_emu_(nullptr) {
    if (use_bf16_emulation())
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_reserv_4, bf16_emu_reserv_5);

    if (jpp.with_postops) {
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        static constexpr int sse41_single_block_size
                = cpu_isa_traits<sse41>::vlen / sizeof(float);
        size_t postop_tail = static_cast<size_t>(jpp.c_tail);
        const bool high_half_block_empty = isa == sse41
                && static_cast<size_t>(jpp.c_tail) > sse41_single_block_size;
        if (high_half_block_empty) postop_tail -= sse41_single_block_size;

        const binary_injector::rhs_arg_static_params_t rhs_sp {
                static_cast<std::size_t>(this->xmm4.getIdx()), this->r14,
                this->r15, this->r13, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(jpp.tag_kind == jit_memory_tag_kind_t::ncsp
                                ? jpp.tmp_md
                                : *dst_md),
                postop_tail, k_c_tail_mask, use_exact_tail_scalar_bcast};

        const binary_injector::static_params_t bsp {
                reg_param, get_supported_bcast_strategies(), rhs_sp};

        postops_injector_
                = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
                        this, jpp.post_ops, bsp);
    }
}

static status_t set_binary_postops_formats(
        post_ops_t &post_ops, const memory_desc_t *dst_md) {
    for (int idx = 0; idx < post_ops.len(); ++idx) {
        if (!post_ops.contain(primitive_kind::binary, idx)) continue;

        auto &src1_md = post_ops.entry_[idx].binary.src1_desc;
        const memory_desc_wrapper src1_mdw(src1_md);
        if (!src1_mdw.format_any()) {
            if (src1_mdw.is_blocking_desc())
                continue;
            else
                return status::unimplemented;
        }

        const memory_desc_wrapper dst_mdw(dst_md);
        assert(!dst_mdw.format_any());

        CHECK(memory_desc_init_by_blocking_desc(
                src1_md, dst_mdw.blocking_desc()));
    }

    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
        memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr,
        const pooling_pd_t *ppd) {

    const auto &pd = *ppd->desc();
    const memory_desc_wrapper src_d(
            ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md());
    const memory_desc_wrapper dst_d(
            ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md());

    const int ndims = src_d.ndims();

    jpp.nthr = dnnl_get_max_threads();
    jpp.is_training = pd.prop_kind == prop_kind::forward_training;
    jpp.is_backward = pd.prop_kind == prop_kind::backward_data;

    jpp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jpp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jpp.iw = src_d.dims()[ndims - 1];
    jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jpp.ow = dst_d.dims()[ndims - 1];
    jpp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];

    const bool is_avx512 = is_superset(isa, avx512_core);
    jpp.ndims = ndims;
    jpp.mb = src_d.dims()[0];
    jpp.c_without_padding = src_d.dims()[1];
    jpp.c_block = is_avx512 ? 16 : 8;

    jpp.alg = pd.alg_kind;
    jpp.tmp_md = memory_desc_t();

    jpp.is_bf16 = (src_d.data_type() == data_type::bf16
            && dst_d.data_type() == data_type::bf16);
    jpp.is_f16 = (src_d.data_type() == data_type::f16
            && dst_d.data_type() == data_type::f16);

    using namespace format_tag;

    const auto blocked_fmt_tag = is_avx512
            ? utils::pick(ndims - 3, nCw16c, nChw16c, nCdhw16c)
            : utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);

    // src_d.data_type() is equal to dst_d.data_type(). This is checked in init
    auto ncsp_fmt_tag = format_tag::undef;

    const unsigned int L3_cache_size_per_core
            = platform::get_per_core_cache_size(3);
    const size_t block_size
            = ((size_t)jpp.id * jpp.ih * jpp.iw + jpp.od * jpp.oh * jpp.ow)
            * jpp.c_block * types::data_type_size(src_d.data_type());

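    // Heuristic for plain (ncsp) layouts: the kernel works on blocked data,
    // so ncsp is only taken when a per-block slice of src + dst fits the
    // per-core L3 (the copy cost is then amortized), or when the data is
    // bf16/f16 and a conversion pass is needed anyway.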
    const bool forward_ncsp_allowed = !jpp.is_backward
            && jpp.c_without_padding > 3
            && ((jpp.ih > 1 && jpp.iw > 1
                        && block_size <= L3_cache_size_per_core)
                    || utils::one_of(src_d.data_type(), data_type::bf16,
                            data_type::f16));

    const bool backward_ncsp_allowed = jpp.is_backward
            && ((jpp.ih > 1 && jpp.iw > 1 && jpp.c_without_padding > 1
                        && block_size <= L3_cache_size_per_core)
                    || (utils::one_of(src_d.data_type(), data_type::bf16,
                                data_type::f16)
                            && !(jpp.alg == pooling_max
                                    && block_size > L3_cache_size_per_core)));

    ncsp_fmt_tag = ((forward_ncsp_allowed || backward_ncsp_allowed) && is_avx512
                           && ndims <= 5)
            ? utils::pick(ndims - 3, ncw, nchw, ncdhw)
            : format_tag::undef;

    const auto nspc_fmt_tag = (ndims <= 5)
            ? utils::pick(ndims - 3, nwc, nhwc, ndhwc)
            : format_tag::undef;

    const auto fmt_tag = src_d.matches_one_of_tag(
            blocked_fmt_tag, ncsp_fmt_tag, nspc_fmt_tag);

    if (!dst_d.matches_tag(fmt_tag)) return status::unimplemented;

    if (!post_ops_ok(jpp, attr, dst_d)) return status::unimplemented;

    if (fmt_tag == ncsp_fmt_tag) {
        // transform input to blocked f32, call f32 jit, transform result to
        // plain output
        jpp.is_bf16 = false;
        jpp.is_f16 = false;
        jpp.dt_size = types::data_type_size(data_type::f32);
        jpp.tag_kind = jit_memory_tag_kind_t::ncsp;

        // used to initialize binary post-ops
        if (ppd->is_fwd() && jpp.with_binary) {
            CHECK(memory_desc_init_by_tag(jpp.tmp_md, ndims, dst_d.md_->dims,
                    data_type::f32, blocked_fmt_tag));
        }
    } else {
        jpp.is_bf16 = (src_d.data_type() == data_type::bf16
                && dst_d.data_type() == data_type::bf16);
        jpp.is_f16 = (src_d.data_type() == data_type::f16
                && dst_d.data_type() == data_type::f16);
        jpp.dt_size = types::data_type_size(src_d.data_type());
        jpp.tag_kind = (fmt_tag == nspc_fmt_tag)
                ? jit_memory_tag_kind_t::nspc
                : jit_memory_tag_kind_t::blocked;
    }

    if (ppd->is_fwd() && jpp.with_binary) {
        CHECK(set_binary_postops_formats(attr.post_ops_,
                jpp.tag_kind == jit_memory_tag_kind_t::ncsp ? &jpp.tmp_md
                                                            : dst_d.md_));
    }

    jpp.isa = (jpp.is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16
                                                         : isa;

    const bool args_ok = mayiuse(isa) && (fmt_tag != format_tag::undef)
            && IMPLICATION(jpp.is_bf16,
                    utils::one_of(jpp.isa, avx512_core_bf16, avx512_core,
                            avx2_vnni_2))
            && IMPLICATION(jpp.is_f16,
                    utils::one_of(jpp.isa, avx512_core_fp16, avx2_vnni_2))
            && utils::one_of(pd.alg_kind, pooling_max,
                    pooling_avg_include_padding, pooling_avg_exclude_padding);
    if (!args_ok) return status::unimplemented;

    const bool is_xf16_avx2_vnni_2
            = (jpp.is_bf16 || jpp.is_f16) && isa == avx2_vnni_2;
    // note: avx2_vnni_2 only supports nxc format
    if (!IMPLICATION(is_xf16_avx2_vnni_2,
                jpp.tag_kind == jit_memory_tag_kind_t::nspc))
        return status::unimplemented;

    // note: avx2_vnni_2 only supports FWD direction
    if (!IMPLICATION(is_xf16_avx2_vnni_2, !jpp.is_backward))
        return status::unimplemented;

    jpp.c = jpp.tag_kind == jit_memory_tag_kind_t::blocked
            ? utils::rnd_up(jpp.c_without_padding, jpp.c_block)
            : jpp.c_without_padding;
    if (jpp.tag_kind == jit_memory_tag_kind_t::blocked)
        assert(src_d.padded_dims()[1] == jpp.c);
    jpp.nb_c = utils::div_up(jpp.c, jpp.c_block);
    jpp.c_tail = jpp.c_without_padding % jpp.c_block;
    jpp.is_c_padded = jpp.tag_kind == jit_memory_tag_kind_t::blocked
            && src_d.padded_dims()[1] != jpp.c_without_padding;

    jpp.stride_d = (ndims == 5) ? pd.strides[0] : 1;
    jpp.stride_h = (ndims == 3) ? 1 : pd.strides[ndims - 4];
    jpp.stride_w = pd.strides[ndims - 3];
    jpp.kd = (ndims == 5) ? pd.kernel[0] : 1;
    jpp.kh = (ndims == 3) ? 1 : pd.kernel[ndims - 4];
    jpp.kw = pd.kernel[ndims - 3];

    jpp.f_pad = (ndims == 5) ? pd.padding[0][0] : 0;
    jpp.t_pad = (ndims == 3) ? 0 : pd.padding[0][ndims - 4];
    jpp.l_pad = pd.padding[0][ndims - 3];

    const int back_pad = calculate_end_padding(
            jpp.f_pad, jpp.od, jpp.id, jpp.stride_d, jpp.kd);
    const int bottom_pad = calculate_end_padding(
            jpp.t_pad, jpp.oh, jpp.ih, jpp.stride_h, jpp.kh);
    const int right_pad = calculate_end_padding(
            jpp.l_pad, jpp.ow, jpp.iw, jpp.stride_w, jpp.kw);

    if (jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw
            || back_pad >= jpp.kd || bottom_pad >= jpp.kh
            || right_pad >= jpp.kw)
        return status::unimplemented;

    jpp.ind_dt = ppd->workspace_md() ? ppd->workspace_md()->data_type
                                     : data_type::undef;

    jpp.simple_alg = jpp.is_training
            || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d);

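    // ur is the number of output-width positions unrolled per step. The
    // values below are tuned per ISA, algorithm and direction so that the
    // accumulators, indices and temporaries all fit in the register file.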
    jpp.ur = 0;
    if (jpp.alg == pooling_max) {
        jpp.ur = is_avx512 ? 16 : 4;

        if (utils::one_of(isa, avx, avx2, avx2_vnni_2) && jpp.c_tail > 0)
            // Additional register needed for tail mask
            jpp.ur -= 1;

        if (jpp.is_training)
            jpp.ur = is_avx512 ? 9 : 3;
        else if (jpp.is_backward)
            jpp.ur = is_avx512 ? 6 : 3;
    } else {
        if (jpp.is_backward)
            jpp.ur = is_avx512 ? 12 : 6;
        else
            jpp.ur = is_avx512 ? 24 : 12;
    }
    if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) {
        jpp.ur = (!isa_has_bf16(jpp.isa))
                ? jpp.ur - 4 // Free registers needed for the bf16 emulation
                : jpp.ur - 1; // Free register for cvt from bf16/f16 to f32
    }

    // select jpp.ur_bc
    if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
        auto min_ur_w = nstl::max(1, utils::div_up(jpp.l_pad, jpp.stride_w));
        int min_ur_w1 = utils::div_up(right_pad, jpp.stride_w);
        if (min_ur_w < min_ur_w1) { min_ur_w = min_ur_w1; }
        jpp.ur_bc = nstl::min(jpp.nb_c, nstl::max(1, jpp.ur / min_ur_w));
        // Take threading into account - make sure there is enough work to
        // parallelize across all threads
        float best_eff = 0;
        for (int ur_bc = jpp.ur_bc; ur_bc > 0; ur_bc--) {

            const auto nb2_c = utils::div_up(jpp.nb_c, ur_bc);
            auto work = jpp.is_backward
                    ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1)
                    : (ndims == 5 ? jpp.od : jpp.oh);
            work *= jpp.mb * nb2_c;
            auto eff = (float)work / utils::rnd_up(work, jpp.nthr);
            if (eff > best_eff) {

                best_eff = eff;
                jpp.ur_bc = ur_bc;
            }
            if (eff > 0.9f) break; // Heuristic threshold
        }

        // Take cache reuse after zeroing on backward into account
        if (jpp.is_backward && ndims < 5) {
            const int L2 = platform::get_per_core_cache_size(2)
                    / jpp.dt_size; // L2 capacity in data-type elements
            int ur_bc = nstl::max(1, L2 / (jpp.kh * jpp.iw * jpp.c_block));
            jpp.ur_bc = nstl::min(jpp.ur_bc, ur_bc);
        }

        jpp.ur_bc_tail = jpp.nb_c % jpp.ur_bc;
    } else {
        jpp.ur_bc = 1;
        jpp.ur_bc_tail = 0;
    }

    // scratchpad for c_block slice of input and/or output
    using namespace memory_tracking::names;
    const int nscr = nstl::min(dnnl_get_max_threads(), jpp.mb * jpp.nb_c);
    if (jpp.tag_kind == jit_memory_tag_kind_t::ncsp) {
        scratchpad.book(key_pool_src_plain2blocked_cvt,
                static_cast<size_t>(jpp.c_block) * jpp.id * jpp.ih * jpp.iw
                        * nscr,
                jpp.dt_size);
        scratchpad.book(key_pool_dst_plain2blocked_cvt,
                static_cast<size_t>(jpp.c_block) * jpp.od * jpp.oh * jpp.ow
                        * nscr,
                jpp.dt_size);
        scratchpad.book<uint32_t>(key_pool_ind_plain2blocked_cvt,
                static_cast<size_t>(jpp.c_block) * jpp.od * jpp.oh * jpp.ow
                        * nscr);
    }

    jpp.post_ops = attr.post_ops_;

    return status::success;
}

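// Vector registers are addressed as a flat 3D table: 'shift' selects the
// role of the register (accumulator, input, index, ...), then 'bc' the
// channel block and 'j' the output-width position. E.g. with ur_bc = 2 and
// ur_w = 3, the accumulator (shift 0) of block 1 at position 2 lives at
// index 0 * 6 + 1 * 3 + 2 = 5.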
static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept {
    return shift * ur_bc * ur_w + bc * ur_w + j;
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::prepare_tail_mask() {
    if (is_superset(isa, avx512_core)) {
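        // Build an opmask with the low jpp.c_tail bits set, e.g. a tail of
        // 3 channels yields 0b111.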
        size_t c_tail_mask = (1ULL << jpp.c_tail) - 1ULL;
        mov(tmp_gpr.cvt32(), c_tail_mask);
        kmovw(k_c_tail_mask, tmp_gpr.cvt32());
    } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
        constexpr int max_words_in_ymm = 8;

        // For 'avx2_vnni_2' the mask operates on pairs of xf16 elements, so
        // when 'c_tail % 2 != 0' an additional word is loaded/stored
        // separately for the remaining element.
        auto dt_elem_div = isa == avx2_vnni_2 ? 2 : 1;
        auto mask_offset = max_words_in_ymm - (jpp.c_tail / dt_elem_div);
        auto mask_register
                = isa == avx2_vnni_2 ? xmm_c_tail_mask : vmm_c_tail_mask;
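        // Slide a window over this ones-then-zeros table: loading 8 dwords
        // starting at 'mask_offset' yields exactly 'c_tail / dt_elem_div' set
        // lanes followed by cleared ones.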
        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
                0, 0, 0, 0, 0, 0, 0};
        mov(tmp_gpr, reinterpret_cast<size_t>(&mask[mask_offset]));
        vmovups(mask_register, ptr[tmp_gpr]);
    }
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::put_one_in_vmm() {
    mov(tmp_gpr, 1);
    uni_broadcast_reg_val(tmp_gpr.getIdx(), vmm_one.getIdx());
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::uni_broadcast_reg_val(
        const int reg_idx, const int vmm_idx) {
    uni_vmovq(Xmm(vmm_idx), reg64_t(reg_idx));
    uni_vpbroadcastd(Vmm(vmm_idx), Xmm(vmm_idx));
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::push_vmm_val(const int idx) {
    Vmm val_to_store(idx);
    sub(rsp, val_to_store.getBit());
    uni_vmovups(ptr[rsp], val_to_store);
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::pop_vmm_val(const int idx) {
    Vmm val_to_load(idx);
    uni_vmovups(val_to_load, ptr[rsp]);
    add(rsp, val_to_load.getBit());
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::load(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (jpp.is_bf16) {
        /*TODO: maybe use vpmovzxwd + vpslld,
         * in order to free up vmm_idx() register */
        if (is_c_tail_proccessing && !jpp.is_c_padded) {
            Vmm vmm_to_load = Vmm(idx) | k_c_tail_mask | T_z;
            vpmovzxwd(vmm_to_load, ptr[reg_ptr + offset]);
            vpslld(vmm_to_load, vmm_to_load, 16);
        } else {
            vmovups(Ymm(idx), ptr[reg_ptr + offset]);
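            // vpermw with the idx_table pattern (0,0,1,1,...) duplicates each
            // bf16 word into both halves of a dword; k_mask_cvt (0xAAAAAAAA)
            // then keeps only the odd words, leaving each bf16 payload in the
            // upper 16 bits of its dword, i.e. a valid f32.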
            vpermw(Vmm(idx) | k_mask_cvt | T_z, vmm_idx(), Vmm(idx));
        }
    } else if (jpp.is_f16) {
        Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
                ? Vmm(idx) | k_c_tail_mask | T_z
                : Vmm(idx);
        vcvtph2psx(vmm_to_load, ptr[reg_ptr + offset]);
    } else {
        if (is_c_tail_proccessing && !jpp.is_c_padded) {
            if (isa == avx || isa == avx2) {
                vmaskmovps(Vmm(idx), vmm_c_tail_mask, ptr[reg_ptr + offset]);
            } else {
                vmovups(Zmm(idx) | k_c_tail_mask | T_z, ptr[reg_ptr + offset]);
            }
        } else {
            uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]);
        }
    }
}

template <>
inline void jit_uni_pool_kernel<avx2_vnni_2>::load(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (is_c_tail_proccessing) {
        vmaskmovps(Xmm(idx), xmm_c_tail_mask, ptr[reg_ptr + offset]);
        if (jpp.c_tail % 2 != 0) {
            const int tail_pos = jpp.c_tail - 1;
            auto word_addr
                    = ptr[reg_ptr + offset + tail_pos * sizeof(bfloat16_t)];
            vpinsrw(Xmm(idx), Xmm(idx), word_addr, tail_pos);
        }
    }
    if (jpp.is_bf16) {
        if (is_c_tail_proccessing)
            vpmovzxwd(Ymm(idx), Xmm(idx));
        else
            vpmovzxwd(Ymm(idx), ptr[reg_ptr + offset]);
        vpslld(Ymm(idx), Ymm(idx), 16);
    } else if (jpp.is_f16) {
        if (is_c_tail_proccessing)
            vcvtph2ps(Ymm(idx), Xmm(idx));
        else
            vcvtph2ps(Ymm(idx), ptr[reg_ptr + offset]);
    } else
        assert(!"invalid data type");
}

template <>
inline void jit_uni_pool_kernel<sse41>::load(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (is_c_tail_proccessing && !jpp.is_c_padded) {
        for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
            pinsrd(Xmm(idx), ptr[reg_ptr + offset + i * jpp.dt_size], i);
    } else
        uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]);
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::store(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (jpp.is_bf16 || jpp.is_f16) {
        if (is_c_tail_proccessing) {
            if (jpp.is_c_padded) {
                vmovdqu16(Ymm(idx) | k_c_tail_mask | T_z, Ymm(idx));
                vmovups(yword[reg_ptr + offset], Ymm(idx));
            } else
                vmovdqu16(ptr[reg_ptr + offset] | k_c_tail_mask, Ymm(idx));
        } else
            vmovups(yword[reg_ptr + offset], Ymm(idx));
    } else {
        if (is_c_tail_proccessing) {
            if (!jpp.is_c_padded) {
                if (isa == avx || isa == avx2)
                    vmaskmovps(
                            ptr[reg_ptr + offset], vmm_c_tail_mask, Vmm(idx));
                else
                    vmovups(ptr[reg_ptr + offset] | k_c_tail_mask, Zmm(idx));
            } else {
                if (jpp.with_postops) {
                    if (isa == avx || isa == avx2) {
                        uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
                        uni_vblendvps(
                                Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask);
                    } else
                        uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx));
                }
                uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
            }
        } else
            uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
    }
}

template <>
inline void jit_uni_pool_kernel<avx2_vnni_2>::store(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (jpp.is_bf16 || jpp.is_f16) {
        if (is_c_tail_proccessing) {
            vmaskmovps(ptr[reg_ptr + offset], xmm_c_tail_mask, Xmm(idx));
            if (jpp.c_tail % 2 != 0) {
                const int tail_pos = jpp.c_tail - 1;
                auto word_addr = ptr[reg_ptr + offset + tail_pos * 2];
                vpextrw(word_addr, Xmm(idx), tail_pos);
            }
        } else
            vmovups(xword[reg_ptr + offset], Xmm(idx));
    } else
        assert(!"data type not supported");
}

template <>
inline void jit_uni_pool_kernel<sse41>::store(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (is_c_tail_proccessing) {
        if (!jpp.is_c_padded) {
            for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
                pextrd(ptr[reg_ptr + offset + i * jpp.dt_size], Xmm(idx), i);
        } else {
            if (jpp.with_postops) {
                static constexpr auto xmm_half = 4;
                const auto tail_size = (jpp.c_without_padding > jpp.c_block)
                        ? jpp.c_without_padding % (jpp.c - jpp.c_block)
                        : jpp.c_without_padding;
                const auto tail_size_real = (tail_size >= xmm_half)
                        ? tail_size - xmm_half
                        : tail_size;
                uni_vxorps(xmm_tmp_1, xmm_tmp_1, xmm_tmp_1);
                if (tail_size <= xmm_half && sse_high_half) {
                    // just zero out upper half padding and don't write anything else
                    uni_vmovups(vmmword[reg_ptr + offset], xmm_tmp_1);
                    return;
                }

                if ((tail_size < xmm_half && !sse_high_half)
                        || (tail_size > xmm_half && sse_high_half)) {
                    std::bitset<8> tail_mask((1 << tail_size_real) - 1);
                    tail_mask.flip();
                    uni_vblendps(Vmm(idx), Vmm(idx), xmm_tmp_1,
                            tail_mask.to_ulong());
                }
            }
            uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
        }
    } else
        uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
}

template <cpu_isa_t isa>
bool jit_uni_pool_kernel<isa>::post_ops_ok(jit_pool_conf_t &jpp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
    const auto &post_ops = attr.post_ops_;
    const auto &entries = post_ops.entry_;
    jpp.with_postops = false;
    jpp.with_eltwise = false;
    jpp.with_binary = false;

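    // Post-ops are only applied on the forward path; on backward the loop
    // below is skipped and all with_* flags stay unset.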
    if (!jpp.is_backward) {
        for (const auto &entry : entries) {
            if (entry.is_eltwise()) {
                const auto alg = entry.eltwise.alg;
                jpp.with_eltwise = eltwise_injector::is_supported(isa, alg);
            } else if (entry.is_binary()) {
                const bool is_bf16_ok = IMPLICATION(
                        entry.binary.src1_desc.data_type == data_type::bf16,
                        utils::one_of(isa, avx512_core, avx2_vnni_2));
                const bool is_f16_ok = IMPLICATION(
                        entry.binary.src1_desc.data_type == data_type::f16,
                        utils::one_of(isa, avx512_core_fp16, avx2_vnni_2));
                if (!(is_bf16_ok && is_f16_ok)) return false;

                jpp.with_binary = true;
            } else
                return false;
        }

        jpp.with_postops = jpp.with_eltwise || jpp.with_binary;
    }

    return binary_injector::binary_args_broadcast_supported(
            post_ops, dst_d, get_supported_bcast_strategies());
}

template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::apply_postops(int ur_bc, int ur_w, int c_block,
        const std::function<bool(int, bool)> &is_tail_predicate) {
    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
    const int end_idx = vmm_idx_upper_bound() + 1;
    const int start_idx = end_idx - (ur_bc * ur_w);
    const bool sse41_postops_disabled
            = isa == sse41 && disable_postops_when_sse_high_half_processed_;

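    // For binary post-ops, each accumulator vmm is mapped to the address of
    // the output element it produces so the injector can fetch the matching
    // rhs operand (per_oc, scalar or no-broadcast strategy).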
    if (jpp.with_binary && !sse41_postops_disabled) {

        const int c_off = (jpp.tag_kind == jit_memory_tag_kind_t::nspc)
                ? jpp.c
                : c_block;

        if (jpp.tag_kind == jit_memory_tag_kind_t::ncsp) {
            mov(tmp_gpr, reg_output);
            sub(tmp_gpr, ptr[reg_param + GET_OFF(dst)]);
            add(tmp_gpr, ptr[reg_param + GET_OFF(dst_po_helper)]);
        }

        for (int jj = 0; jj < ur_w; jj++) {
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto vmm_idx
                        = vreg(reg_ind(0, bci, jj, ur_bc, ur_w)).getIdx();

                const size_t output_offset
                        = jpp.dt_size * (jj * c_off + bci * c_block);

                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx,
                        jpp.tag_kind == jit_memory_tag_kind_t::ncsp
                                ? tmp_gpr
                                : reg_output);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, output_offset);
                if (is_tail_predicate
                        && is_tail_predicate(
                                bci, true /*process_with_postops*/)) {
                    rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
                }
            }
        }
    }
    postops_injector_->compute_vector_range(start_idx, end_idx, rhs_arg_params);
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::maybe_recalculate_divisor(
        int jj, int ur_w, int pad_l, int pad_r, bool with_c_tail_proccessing) {
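    // For avg_exclude_padding the divisor is the number of kernel elements
    // that overlap the source image; it varies across output columns near
    // the borders, so it is re-broadcast only when the non-zero kernel width
    // actually changes.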
    if (jpp.alg == pooling_avg_exclude_padding) {
        int kw = jpp.kw;
        int stride_w = jpp.stride_w;

        int non_zero_kw = kw;
        non_zero_kw -= nstl::max(0, pad_l - jj * stride_w);
        non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj) * stride_w);

        if (non_zero_kw != prev_kw) {
            mov(tmp_gpr, float2int((float)non_zero_kw));
            uni_vmovq(xmm_tmp, tmp_gpr);
            uni_vbroadcastss(vmm_tmp, xmm_tmp);
            if (with_c_tail_proccessing
                    && (utils::one_of(isa, avx, avx2, avx2_vnni_2))) {
                push_vmm_val(vmm_c_tail_mask.getIdx());
                uni_broadcast_reg_val(
                        reg_ker_area_h.getIdx(), vmm_ker_area_h.getIdx());
            }
            uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h);
            if (with_c_tail_proccessing
                    && (utils::one_of(isa, avx, avx2, avx2_vnni_2))) {
                pop_vmm_val(vmm_c_tail_mask.getIdx());
            }
            prev_kw = non_zero_kw;
        }
    }
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
        int pad_r, bool with_c_tail_proccessing) {

    auto iw = jpp.iw;
    auto kw = jpp.kw;
    auto stride_w = jpp.stride_w;
    auto c_block = jpp.c_block;
    auto dt_size = jpp.dt_size;
    const int c_off
            = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block;
    Label kd_label, kh_label;

    const auto is_tail_processing = [&](int bc,
                                            bool process_with_postops = false) {
        if (isa == sse41 && (!jpp.is_c_padded || process_with_postops)) {
            return with_c_tail_proccessing && bc == (ur_bc - 1)
                    && ((jpp.c_tail > (jpp.c_block / 2) && sse_high_half)
                            || (jpp.c_tail < (jpp.c_block / 2)
                                    && !sse_high_half));
        } else
            return with_c_tail_proccessing && bc == (ur_bc - 1);
    };

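    // Backward: preload diff_dst and pre-divide it by the averaging window
    // size; forward: zero the accumulators before summation.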
    for (int jj = 0; jj < ur_w; jj++) {
        if (jpp.is_backward)
            maybe_recalculate_divisor(
                    jj, ur_w, pad_l, pad_r, with_c_tail_proccessing);
        for (int bci = 0; bci < ur_bc; bci++) {
            const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
            auto accvr = vreg(accr_i);
            if (jpp.is_backward) {
                auto output_offset = dt_size * (jj * c_off + bci * c_block);
                load(accvr.getIdx(), reg_output, output_offset,
                        is_tail_processing(bci));
                uni_vdivps(accvr, accvr, vmm_tmp);
            } else {
                uni_vpxor(accvr, accvr, accvr);
            }
        }
    }

    if (jpp.simple_alg && jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                    - utils::div_up(
                            nstl::max(0, ki + pad_r - (kw - 1)), stride_w);

            for_(int jj = jj_start; jj < jj_end; jj++)
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto accvr = vreg(reg_ind(0, bci, jj, ur_bc, ur_w));
                const auto inpr_i = reg_ind(1, bci, jj, ur_bc, ur_w);
                auto inpvr = vreg(inpr_i);
                int aux_input_offset
                        = (ki + jj * stride_w - pad_l) * c_off + bci * c_block;
                if (aux_input_offset >= iw * c_off) continue;
                int input_offset = dt_size * aux_input_offset;
                if (jpp.is_backward) {
                    auto inpyr = yreg(inpr_i);
                    load(reg_idx(inpr_i), aux_reg_input, input_offset,
                            is_tail_processing(bci));
                    uni_vaddps(inpvr, inpvr, accvr);
                    if (jpp.is_bf16) {
                        if (!isa_has_bf16(jpp.isa))
                            bf16_emu_->vcvtneps2bf16(inpyr, zreg(inpr_i));
                        else
                            vcvtneps2bf16(inpyr, inpvr);
                    } else if (jpp.is_f16) {
                        vcvtps2ph(inpyr, inpvr, _op_mxcsr);
                    }
                    store(reg_idx(inpr_i), aux_reg_input, input_offset,
                            is_tail_processing(bci));
                } else {
                    if (jpp.is_bf16 || jpp.is_f16 || is_tail_processing(bci)
                            || (isa == sse41
                                    && c_off % (jpp.c_block / 2) != 0)) {
                        load(vmm_tmp_1.getIdx(), aux_reg_input, input_offset,
                                is_tail_processing(bci));

                        uni_vaddps(accvr, accvr, vmm_tmp_1);
                    } else {
                        uni_vaddps(accvr, accvr,
                                ptr[aux_reg_input + input_offset]);
                    }
                }
            }
        }
        add(aux_reg_input, jpp.dt_size * iw * c_off);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.simple_alg && jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_off);
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        pop(reg_output);
        pop(reg_input);
    }

    if (!jpp.is_backward) {
        for (int jj = 0; jj < ur_w; jj++) {
            maybe_recalculate_divisor(
                    jj, ur_w, pad_l, pad_r, with_c_tail_proccessing);
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
                const auto accvr = vreg(accr_i);
                uni_vdivps(accvr, accvr, vmm_tmp);
            }
        }

        if (jpp.with_postops)
            apply_postops(ur_bc, ur_w, c_block, is_tail_processing);

        for (int jj = 0; jj < ur_w; jj++) {
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
                const auto accvr = vreg(accr_i);
                const auto output_offset
                        = dt_size * (jj * c_off + bci * c_block);
                const auto accyr = yreg(accr_i);
                if (jpp.is_bf16) {
                    if (isa == avx2_vnni_2) {
                        auto accxr = xreg(accr_i);
                        vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
                    } else {
                        const auto acczr = zreg(accr_i);
                        if (!isa_has_bf16(jpp.isa))
                            bf16_emu_->vcvtneps2bf16(accyr, acczr);
                        else
                            vcvtneps2bf16(accyr, accvr);
                    }
                } else if (jpp.is_f16) {
                    if (isa == avx2_vnni_2) {
                        auto accxr = xreg(accr_i);
                        vcvtps2ph(accxr, accyr, _op_mxcsr);
                    } else
                        vcvtps2ph(accyr, accvr, _op_mxcsr);
                }
                store(reg_idx(accr_i), reg_output, output_offset,
                        is_tail_processing(bci));
            }
        }
    }
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int ur_bc,
        int pad_l, int pad_r, bool with_c_tail_proccessing) {
    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    const int c_off
            = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block;
    Label kd_label, kh_label;

    auto is_tail_processing = [&](int bc, bool process_with_postops = false) {
        if (isa == sse41 && (!jpp.is_c_padded || process_with_postops)) {
            return with_c_tail_proccessing && bc == (ur_bc - 1)
                    && ((jpp.c_tail > (jpp.c_block / 2) && sse_high_half)
                            || (jpp.c_tail < (jpp.c_block / 2)
                                    && !sse_high_half));
        } else
            return with_c_tail_proccessing && bc == (ur_bc - 1);
    };

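    // Seed every max accumulator with the lowest representable float so the
    // first comparison always picks the input value.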
    mov(tmp_gpr, float2int(nstl::numeric_limits<float>::lowest()));
    uni_vmovq(xmm_tmp, tmp_gpr);
    uni_vbroadcastss(vmm_tmp, xmm_tmp);

    for_(int jj = 0; jj < ur_w; jj++)
    for (int bci = 0; bci < ur_bc; bci++) {
        const auto accvr = vreg(reg_ind(0, bci, jj, ur_bc, ur_w));
        uni_vmovups(accvr, vmm_tmp);
        if (jpp.is_training) {
            const auto indvr = vreg(reg_ind(2, bci, jj, ur_bc, ur_w));
            uni_vpxor(indvr, indvr, indvr);
        }
    }
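    // In training mode vmm_k_offset carries the running kernel-element
    // index; whenever an input wins the max comparison, the current index is
    // blended into the index vector and later stored to the workspace.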
    if (jpp.is_training) {
        uni_vmovq(xmm_tmp, reg_k_shift);
        uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
    }
    if (jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }
    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                    - utils::div_up(
                            nstl::max(0, ki + pad_r - (kw - 1)), stride_w);
            for_(int jj = jj_start; jj < jj_end; jj++)
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto accvr = vreg(reg_ind(0, bci, jj, ur_bc, ur_w));
                const auto inpr_i = reg_ind(1, bci, jj, ur_bc, ur_w);
                const auto inpvr = vreg(inpr_i);
                const auto indvr = vreg(reg_ind(2, bci, jj, ur_bc, ur_w));
                const auto cvtvr = vreg(reg_ind(3, bci, jj, ur_bc, ur_w));
                int aux_input_offset
                        = (ki + jj * stride_w - pad_l) * c_off + bci * c_block;
                if (aux_input_offset >= iw * c_off) continue;
                int input_offset = jpp.dt_size * aux_input_offset;
                load(reg_idx(inpr_i), aux_reg_input, input_offset,
                        is_tail_processing(bci));
                if (isa == sse41) {
                    movups(vmm_mask, accvr);
                    cmpps(vmm_mask, inpvr, _cmp_lt_os);
                    blendvps(accvr, inpvr);
                    if (jpp.is_training) blendvps(indvr, vmm_k_offset);
                } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
                    vcmpps(cvtvr, accvr, inpvr, _cmp_lt_os);
                    vblendvps(accvr, accvr, inpvr, cvtvr);
                    if (jpp.is_training)
                        vblendvps(indvr, indvr, vmm_k_offset, cvtvr);
                } else {
                    vcmpps(k_store_mask, accvr, inpvr, _cmp_lt_os);
                    vblendmps(accvr | k_store_mask, accvr, inpvr);
                    if (jpp.is_training)
                        vblendmps(indvr | k_store_mask, indvr, vmm_k_offset);
                }
            }
            if (jpp.is_training) {
                if (with_c_tail_proccessing
                        && (utils::one_of(isa, avx, avx2, avx2_vnni_2))) {
                    push_vmm_val(vmm_c_tail_mask.getIdx());
                    put_one_in_vmm();
                }

                if (isa == avx && !mayiuse(avx2))
                    avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
                else
                    uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);

                if (with_c_tail_proccessing
                        && (utils::one_of(isa, avx, avx2, avx2_vnni_2)))
                    pop_vmm_val(vmm_c_tail_mask.getIdx());
            }
        }
        add(aux_reg_input, jpp.dt_size * iw * c_off);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_off);
        if (jpp.is_training) {
            mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]);
            uni_vmovq(xmm_tmp, tmp_gpr);
            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
            if (isa == avx && !mayiuse(avx2)) {
                Xmm t(vmm_mask.getIdx());
                avx_vpadd1(vmm_k_offset, xmm_tmp, t);
            } else {
                uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
            }
        }

        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        pop(reg_output);
        pop(reg_input);
    }

    if (with_c_tail_proccessing && jpp.is_c_padded && isa == sse41)
        mov(tmp_gpr, 0); // zero used to fill the padded tail

    if (jpp.with_postops)
        apply_postops(ur_bc, ur_w, c_block, is_tail_processing);

    for_(int jj = 0; jj < ur_w; jj++)
    for (int bci = 0; bci < ur_bc; bci++) {
        const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
        const auto accvr = vreg(accr_i);
        const auto output_offset = jpp.dt_size * (jj * c_off + bci * c_block);
        auto accyr = yreg(accr_i);
        if (jpp.is_bf16) {
            if (isa == avx2_vnni_2) {
                auto accxr = xreg(accr_i);
                vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
            } else {
                auto acczr = zreg(accr_i);
                if (!isa_has_bf16(jpp.isa))
                    bf16_emu_->vcvtneps2bf16(accyr, acczr);
                else
                    vcvtneps2bf16(accyr, accvr);
            }
        } else if (jpp.is_f16) {
            if (isa == avx2_vnni_2) {
                auto accxr = xreg(accr_i);
                vcvtps2ph(accxr, accyr, _op_mxcsr);
            } else
                vcvtps2ph(accyr, accvr, _op_mxcsr);
        }
        store(reg_idx(accr_i), reg_output, output_offset,
                is_tail_processing(bci));

        if (jpp.is_training) {
            const size_t step_index = (jj * c_off + bci * c_block)
                    * types::data_type_size(jpp.ind_dt);

            const auto indr_i = reg_ind(2, bci, jj, ur_bc, ur_w);
            auto vr = vreg(indr_i);
            if (jpp.ind_dt == data_type::u8) {
                auto xr = xreg(indr_i);
                if (isa == sse41) {
                    for (int i = 0; i < (jpp.c_block / 2); ++i) {
                        if (is_tail_processing(bci)
                                && i + (sse_high_half ? (jpp.c_block / 2) : 0)
                                        >= jpp.c_tail) {
                            if (jpp.is_c_padded)
                                mov(ptr[reg_index + step_index + i],
                                        tmp_gpr.cvt8()); // fill padded tail
                                                         // with zeros
                            else
                                break; // tail end
                        } else {
                            // The valid byte lives in the least significant
                            // 8 bits of each 32-bit lane of the xmm, so store
                            // bytes 0, 4, 8 and 12 of the register
                            pextrb(ptr[reg_index + step_index + i], xr, 4 * i);
                        }
                    }
                } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
                    auto yr = yreg(indr_i);
                    if (is_tail_processing(bci) && !jpp.is_c_padded) {
                        const int max_nr_of_vals
                                = jpp.c_tail > (jpp.c_block / 2)
                                ? (jpp.c_block / 2)
                                : jpp.c_tail;
                        for (int i = 0; i < max_nr_of_vals; ++i) {
                            // The valid byte lives in the least significant
                            // 8 bits of each 32-bit lane of the xmm, so store
                            // bytes 0, 4, 8 and 12 of the register
                            vpextrb(ptr[reg_index + step_index + i], xr, 4 * i);
                        }

                        if (jpp.c_tail > (jpp.c_block / 2)) {
                            Xmm higher_128bits(vmm_mask.getIdx());
                            vextractf128(higher_128bits, yr, 1);
                            for (int i = 0; i < jpp.c_tail - (jpp.c_block / 2);
                                    ++i) {
                                // The valid byte lives in the least
                                // significant 8 bits of each 32-bit lane, so
                                // store bytes 0, 4, 8 and 12 of the extracted
                                // xmm
                                vpextrb(ptr[reg_index + step_index
                                                + (jpp.c_block / 2) + i],
                                        higher_128bits, 4 * i);
                            }
                        }
                    } else {
                        if (is_tail_processing(bci)) {
                            assert(jpp.is_c_padded);
                            vandps(yr, yr, vmm_c_tail_mask);
                        }
                        if (jj == 0) {
                            vmovd(xmm_tmp, reg_shuf_mask);
                            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
                        }
                        if (mayiuse(avx2)) {
                            vpshufb(yr, yr, vmm_tmp);
                            vmovd(ptr[reg_index + step_index], xr);
                            vperm2i128(yr, yr, yr, 0x1u);
                            vmovd(ptr[reg_index + step_index
                                          + (jpp.c_block / 2)],
                                    xr);
                        } else {
                            Xmm t(vmm_mask.getIdx());
                            vextractf128(t, yr, 0);
                            vpshufb(t, t, xmm_tmp);
                            vmovd(ptr[reg_index + step_index], t);
                            vextractf128(t, yr, 1);
                            // the shuffle mask was broadcast to all lanes, so
                            // xmm_tmp matches both halves of vmm_tmp
                            vpshufb(t, t, xmm_tmp);
                            vmovd(ptr[reg_index + step_index
                                          + (jpp.c_block / 2)],
                                    t);
                        }
                    }
                } else {
                    if (is_tail_processing(bci)) {
                        if (jpp.is_c_padded) {
                            knotw(k_c_tail_mask, k_c_tail_mask);
                            vpxord(vr | k_c_tail_mask, vr, vr);
                            knotw(k_c_tail_mask, k_c_tail_mask);
                            vpmovusdb(ptr[reg_index + step_index], vr);
                        } else
                            vpmovusdb(ptr[reg_index + step_index],
                                    vr | k_c_tail_mask);
                    } else {
                        vpmovusdb(ptr[reg_index + step_index], vr);
                    }
                }
            } else {
                store(vr.getIdx(), reg_index, step_index,
                        is_tail_processing(bci));
            }
        }
    }
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
        int pad_l, int pad_r, bool with_c_tail_proccessing) {

    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    const int c_off
            = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block;
    Label kd_label, kh_label;

    const auto is_tail_processing = [&](int bc) {
        if (isa == sse41) {
            return with_c_tail_proccessing && bc == (ur_bc - 1)
                    && ((jpp.c_tail > (jpp.c_block / 2) && sse_high_half)
                            || (jpp.c_tail < (jpp.c_block / 2)
                                    && !sse_high_half)
                            || (jpp.c_tail == (jpp.c_block / 2) && sse_high_half
                                    && jpp.is_c_padded));
        } else
            return with_c_tail_proccessing && bc == (ur_bc - 1);
    };

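    // Preload diff_dst and the saved workspace indices; a diff_dst value is
    // later accumulated only into the input position whose kernel-element
    // index matches the stored one.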
    for_(int jj = 0; jj < ur_w; jj++)
    for (int bci = 0; bci < ur_bc; bci++) {
        const auto outr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
        auto out_offset = jpp.dt_size * (jj * c_off + bci * c_block);
        load(reg_idx(outr_i), reg_output, out_offset, is_tail_processing(bci));
        const size_t step_index = (jj * c_off + bci * c_block)
                * types::data_type_size(jpp.ind_dt);

        const auto indr_i = reg_ind(1, bci, jj, ur_bc, ur_w);
        auto indvr = vreg(indr_i);
        if (jpp.ind_dt == data_type::u8) {
            auto indxr = xreg(indr_i);
            if (isa == sse41) {
                if (is_tail_processing(bci) && !jpp.is_c_padded) {
                    for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
                        pinsrb(indxr, ptr[reg_index + step_index + i], i);
                } else {
                    movd(indxr, ptr[reg_index + step_index]);
                }
                pmovzxbd(indvr, indxr);
            } else if (isa == avx || isa == avx2) {
                if (is_tail_processing(bci) && !jpp.is_c_padded) {
                    for (int i = 0; i < jpp.c_tail; i++)
                        vpinsrb(indxr, indxr, ptr[reg_index + step_index + i],
                                i);
                } else {
                    vmovq(indxr, ptr[reg_index + step_index]);
                }
                if (!mayiuse(avx2)) {
                    avx_pmovzxbd(indvr, indxr, xmm_tmp);
                } else {
                    vpmovzxbd(indvr, indxr);
                }
            } else {
                if (is_tail_processing(bci) && !jpp.is_c_padded) {
                    vpmovzxbd(indvr | k_c_tail_mask | T_z,
                            ptr[reg_index + step_index]);
                } else {
                    vpmovzxbd(indvr, ptr[reg_index + step_index]);
                }
            }
        } else {
            load(indvr.getIdx(), reg_index, step_index,
                    is_tail_processing(bci));
        }
    }
    uni_vmovq(xmm_tmp, reg_k_shift);
    uni_vpbroadcastd(vmm_k_offset, xmm_tmp);

    if (jpp.simple_alg && jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        if (isa == sse41) {
            // Save rdi since it is used in maskmovdqu
            assert(dst_ptr == rdi);
            push(dst_ptr);
        }
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                    - utils::div_up(
                            nstl::max(0, ki + pad_r - (kw - 1)), stride_w);
            for_(int jj = jj_start; jj < jj_end; jj++)
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto outvr = vreg(reg_ind(0, bci, jj, ur_bc, ur_w));
                const auto indvr = vreg(reg_ind(1, bci, jj, ur_bc, ur_w));
                const auto inpr_i = reg_ind(2, bci, jj, ur_bc, ur_w);
                const auto inpvr = vreg(inpr_i);
                const auto cvtvr = vreg(reg_ind(3, bci, jj, ur_bc, ur_w));
                int aux_inp_offset
                        = (ki + jj * stride_w - pad_l) * c_off + bci * c_block;
                if (aux_inp_offset >= iw * c_off) continue;
                int inp_offset = jpp.dt_size * aux_inp_offset;
                load(reg_idx(inpr_i), aux_reg_input, inp_offset,
                        is_tail_processing(bci));
                if (isa == sse41) {
                    mov(dst_ptr, aux_reg_input);
                    add(dst_ptr, inp_offset);

                    movups(cvtvr, indvr);
                    pcmpeqd(cvtvr, vmm_k_offset);
                    addps(inpvr, outvr);
                    if (is_tail_processing(bci)) {
                        Label end_cond_move[4];
                        for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2);
                                i++) {
                            pextrd(tmp_gpr.cvt32(), cvtvr, i);
                            cmp(tmp_gpr, 0);
                            je(end_cond_move[i], T_NEAR);
                            pextrd(ptr[dst_ptr + i * jpp.dt_size], inpvr, i);
                            L(end_cond_move[i]);
                        }
                    } else
                        maskmovdqu(inpvr, cvtvr);
                } else if (isa == avx || isa == avx2) {
                    if (mayiuse(avx2)) {
                        vpcmpeqd(cvtvr, indvr, vmm_k_offset);
                    } else {
                        avx_pcmpeqd(cvtvr, indvr, vmm_k_offset, xmm_tmp);
                    }
                    vaddps(inpvr, inpvr, outvr);
                    if (is_tail_processing(bci)) {
                        vandps(cvtvr, cvtvr, vmm_c_tail_mask);
                    }
                    vmaskmovps(
                            vmmword[aux_reg_input + inp_offset], cvtvr, inpvr);
                } else {
                    auto indzr = zreg(inpr_i);
                    auto indyr = yreg(inpr_i);
                    vpcmpeqd(k_store_mask, indvr, vmm_k_offset);
                    vblendmps(vmm_tmp | k_store_mask | T_z, outvr, outvr);
                    vaddps(inpvr, inpvr, vmm_tmp);
                    if (jpp.is_bf16) {
                        if (!isa_has_bf16(jpp.isa))
                            bf16_emu_->vcvtneps2bf16(indyr, indzr);
                        else
                            vcvtneps2bf16(indyr, inpvr);
                    } else if (jpp.is_f16) {
                        vcvtps2ph(indyr, inpvr, _op_mxcsr);
                    }
                    store(inpvr.getIdx(), aux_reg_input, inp_offset,
                            is_tail_processing(bci));
                }
            }

            if (with_c_tail_proccessing && (isa == avx || isa == avx2)) {
                push_vmm_val(vmm_c_tail_mask.getIdx());
                put_one_in_vmm();
            }

            if (isa == avx && !mayiuse(avx2)) {
                avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
            } else {
                uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
            }

            if (with_c_tail_proccessing && (isa == avx || isa == avx2))
                pop_vmm_val(vmm_c_tail_mask.getIdx());
        }
        add(aux_reg_input, jpp.dt_size * iw * c_off);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }
    if (jpp.simple_alg && jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_off);

        mov(tmp_gpr, reg_kd_pad_shift);
        uni_vmovq(xmm_tmp, tmp_gpr);
        uni_vpbroadcastd(vmm_tmp, xmm_tmp);
        if (isa == avx && !mayiuse(avx2)) {
            Xmm t(vmm_mask.getIdx());
            avx_vpadd1(vmm_k_offset, vmm_tmp, t);
        } else {
            uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
        }

        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        if (isa == sse41) {
            // Restore rdi that was saved above (maskmovdqu uses it)
            assert(dst_ptr == rdi);
            pop(dst_ptr);
        }
        pop(reg_output);
        pop(reg_input);
    }
}

template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::zero_diff_src(
        int ur_bc, bool with_c_tail_proccessing) {
    const int c_off = (jpp.tag_kind == jit_memory_tag_kind_t::nspc)
            ? jpp.c
            : jpp.c_block;

    Label l_skip, l_ih_loop, l_id_loop;

    auto is_tail_processing = [&](int bc) {
        return with_c_tail_proccessing && bc == (ur_bc - 1);
    };

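    // diff_src must be zeroed before the backward accumulation; the caller
    // passes the number of depth rows (zero_id) and height rows (zero_ih)
    // this invocation is responsible for clearing.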
    mov(reg_zero_id, ptr[reg_param + GET_OFF(zero_id)]);
    cmp(reg_zero_id, 0);
    jz(l_skip, T_NEAR);

    mov(reg_zero_ih, ptr[reg_param + GET_OFF(zero_ih)]);
    cmp(reg_zero_ih, 0);
    jz(l_skip, T_NEAR);

    mov(reg_zero_ptr, ptr[reg_param + GET_OFF(zero_ptr)]);

    Vmm vzero = vmm_tmp;
    uni_vpxor(vzero, vzero, vzero);

    const int width_size = jpp.iw * c_off * jpp.dt_size;

    auto aux_reg_zero_ptr = tmp_gpr;

    L(l_id_loop);
    {
        mov(aux_reg_zero_ptr, reg_zero_ptr);
        mov(aux_reg_zero_ih, reg_zero_ih);
        L(l_ih_loop);
        {
            const auto vlen = cpu_isa_traits<isa>::vlen;
            const int step = c_off * jpp.dt_size;

            // TODO: this loop is fully unrolled and may generate a lot of code
            for_(int i = 0; i < width_size; i += step)
            for (int bci = 0; bci < ur_bc; bci++) {
                const int offs = i + bci * jpp.c_block * jpp.dt_size;
                if (isa == sse41) {
                    bool is_needed_c_tail_processing = false;
                    if (is_tail_processing(bci)
                            && jpp.c_tail < (jpp.c_block / 2))
                        is_needed_c_tail_processing = true;
                    store(vzero.getIdx(), reg_zero_ptr, offs,
                            is_needed_c_tail_processing);
                    if (!is_tail_processing(bci)
                            || (is_tail_processing(bci)
                                    && (jpp.is_c_padded
                                            || jpp.c_tail
                                                    > (jpp.c_block / 2)))) {
                        store(vzero.getIdx(), reg_zero_ptr, offs + vlen,
                                is_tail_processing(bci));
                    }

                } else {
                    store(vzero.getIdx(), reg_zero_ptr, offs,
                            is_tail_processing(bci));
                }
            }
            add(reg_zero_ptr, width_size);
            dec(aux_reg_zero_ih);
            jnz(l_ih_loop, T_NEAR);
        }
        mov(reg_zero_ptr, aux_reg_zero_ptr);
        add(reg_zero_ptr, width_size * jpp.ih);
        dec(reg_zero_id);
        jnz(l_id_loop, T_NEAR);
    }

    L(l_skip);
}

template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::generate() {

    this->preamble();

    Label idx_table;

    int ow = jpp.ow;
    int iw = jpp.iw;
    int kw = jpp.kw;
    int kh = jpp.kh;
    int c_block = jpp.c_block;
    int stride_w = jpp.stride_w;
    int l_pad = jpp.l_pad;
    const int c_off
            = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block;

    int vlen = cpu_isa_traits<isa>::vlen;

#if defined(_WIN32)
    // Always mimic the Unix ABI (see the note about maskmovdqu in the header
    // file).
    xor_(rdi, rcx);
    xor_(rcx, rdi);
    xor_(rdi, rcx);
#endif
    if (use_bf16_emulation()) bf16_emu_->init_vcvtneps2bf16();

    mov(reg_input, ptr[reg_param + GET_OFF(src)]);
    mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
    if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
        mov(reg_index, ptr[reg_param + GET_OFF(indices)]);
    mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]);
    mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]);
    mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]);
    mov(reg_nbc, ptr[reg_param + GET_OFF(ur_bc)]);

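    // k_mask_cvt selects every odd 16-bit word; together with the idx_table
    // emitted at the end of generate() it lets load() widen bf16 data with a
    // single masked vpermw.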
    if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) {
        mov(tmp_gpr.cvt32(), 0xAAAAAAAA);
        kmovd(k_mask_cvt, tmp_gpr.cvt32());

        mov(tmp_gpr, idx_table);
        vmovups(vmm_idx(), ptr[tmp_gpr]);
    }

    auto process_oi = [&](int ur_w, int ur_bc, int lpad, int rpad,
                              bool with_c_tail_proccessing,
                              bool inc_reg = true) {
        step(ur_w, ur_bc, lpad, rpad, with_c_tail_proccessing);

        if (isa == sse41) {
            if (with_c_tail_proccessing && jpp.c_tail <= (jpp.c_block / 2)) {

                // In the nspc format, when the channel tail is at most half
                // a block (<= jpp.c_block / 2), the high half-block does not
                // exist, so it must not be processed
                if (!jpp.is_c_padded) ur_bc -= 1;
                /*
                 * When the channel tail is at most half a block, applying
                 * postops on the high half never makes sense. For the blocked
                 * format it can overwrite the zero padding or segfault,
                 * because the element corresponding to the zero-padded piece
                 * has no counterpart in the binary postops arg1 tensor
                 * (nchw format) under the per_oc broadcast strategy.
                 */
                disable_postops_when_sse_high_half_processed_
                        = jpp.tag_kind == jit_memory_tag_kind_t::blocked;
            }
            sse_high_half = true;
            step_high_half(ur_w, ur_bc, lpad, rpad, with_c_tail_proccessing);
            sse_high_half = false;
            disable_postops_when_sse_high_half_processed_ = false;
        }

        if (!inc_reg) return;

        auto dt_size = jpp.dt_size;
        auto shift = (isa == sse41) ? vlen : 0;
        add(reg_input,
                dt_size * nstl::max(0, ur_w * stride_w - lpad) * c_off - shift);
        add(reg_output, dt_size * ur_w * c_off - shift);
        if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
            auto ishift = (isa == sse41) ? jpp.c_block / 2 : 0;
            auto ind_dt_size = types::data_type_size(jpp.ind_dt);
            add(reg_index, (ur_w * c_off - ishift) * ind_dt_size);
        }
    };

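    // perform_ker emits the full output-width sweep for a given channel-block
    // count: peeled iterations that still see left padding, a runtime loop
    // over the padding-free body, then peeled iterations with right padding.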
    auto perform_ker = [&](int ur_bc, bool with_c_tail_processing) {
        prev_kw = 0; // re-initialize this value for avg steps

        if (jpp.is_backward && jpp.simple_alg)
            zero_diff_src(ur_bc, with_c_tail_processing);

        if (jpp.alg == pooling_avg_exclude_padding
                && (!with_c_tail_processing
                        || (!utils::one_of(isa, avx, avx2, avx2_vnni_2)))) {
            // vmm_ker_area_h and vmm_c_tail_mask share one register, so when
            // vmm_c_tail_mask is in use, vmm_ker_area_h has to be loaded
            // exactly where it is needed, after first saving the
            // vmm_c_tail_mask contents
            uni_broadcast_reg_val(
                    reg_ker_area_h.getIdx(), vmm_ker_area_h.getIdx());
        }

        if (jpp.alg == pooling_avg_include_padding) {
            mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd)));
            uni_vmovq(xmm_tmp, tmp_gpr);
            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
        }

        if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
            if (!with_c_tail_processing
                    || (!utils::one_of(isa, avx, avx2, avx2_vnni_2))) {
                // The same situation as above (vmm_ker_area_h).
                put_one_in_vmm();
            }

            if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
                mov(reg_shuf_mask, 0x0c080400);
            }
        }

        const int ur_w = nstl::min(jpp.ow, jpp.ur / jpp.ur_bc);
        const int n_oi_iterations = utils::div_up(ow, ur_w);
        const int ur_stride_w = ur_w * stride_w;
        const int l_pad_iterations
                = nstl::min(n_oi_iterations, utils::div_up(l_pad, ur_stride_w));

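        // Peel the leading iterations that still overlap the left padding;
        // each step reduces the remaining left pad by ur_w * stride_w.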
        for (int i = 0; i < l_pad_iterations; ++i) {
            const int ow_s = i * ur_w;
            const int ow_e = nstl::min(ow, ow_s + ur_w);
            const int cur_l_pad = l_pad - i * ur_stride_w;
            const int cur_r_pad = nstl::max(
                    0, calculate_end_padding(l_pad, ow_e, iw, stride_w, kw));
            const int cur_ur_w = ow_e - ow_s;
            process_oi(cur_ur_w, ur_bc, cur_l_pad, cur_r_pad,
                    with_c_tail_processing);
        }

        const int rem_n_oi_iters = n_oi_iterations - l_pad_iterations;
        const int cur_iw = l_pad_iterations * ur_stride_w - l_pad;
        const int cur_iw_rightmost_idx = cur_iw + kw - 1;
        const int no_pad_full_n_oi_iters = utils::saturate<int>(
                0, rem_n_oi_iters, (iw - cur_iw_rightmost_idx) / ur_stride_w);

        if (no_pad_full_n_oi_iters > 0) {
            Label ow_loop;
            if (no_pad_full_n_oi_iters > 1) xor_(oi_iter, oi_iter);
            L(ow_loop);
            {
                process_oi(ur_w, ur_bc, 0, 0, with_c_tail_processing);
                if (no_pad_full_n_oi_iters > 1) {
                    inc(oi_iter);
                    cmp(oi_iter, no_pad_full_n_oi_iters);
                    jl(ow_loop, T_NEAR);
                }
            }
        }

        for (int i = l_pad_iterations + no_pad_full_n_oi_iters;
                i < n_oi_iterations; ++i) {
            const int ow_s = i * ur_w;
            const int ow_e = nstl::min(ow, ow_s + ur_w);
            const int cur_r_pad = nstl::max(
                    0, calculate_end_padding(l_pad, ow_e, iw, stride_w, kw));
            const int cur_ur_w = ow_e - ow_s;
            process_oi(cur_ur_w, ur_bc, 0, cur_r_pad, with_c_tail_processing);
        }
    };
    Label ur_bc_tail_label, c_tail_processing_label, finish_label;

    if (jpp.ur_bc_tail > 0) {
        cmp(reg_nbc, jpp.ur_bc);
        jne(ur_bc_tail_label, T_NEAR);
    } else if (jpp.c_tail != 0) {
        // ur_bc holds the number of channel blocks to process and b_c the
        // number of channel blocks already processed. If b_c + ur_bc
        // == jpp.nb_c then this is the last set of blocks and channel tail
        // processing may be needed.
        mov(tmp_gpr, ptr[reg_param + GET_OFF(b_c)]);
        add(tmp_gpr, reg_nbc);
        cmp(tmp_gpr, jpp.nb_c);
        je(c_tail_processing_label, T_NEAR);
    }

    perform_ker(jpp.ur_bc, false);

    if (jpp.ur_bc_tail > 0) {
        jmp(finish_label, T_NEAR);

        // If ur_bc_tail exists then this is the last set of blocks to
        // process, and channel tail processing is needed when the number of
        // channels is not divisible by the number of channels in a block
        L(ur_bc_tail_label);
        if (jpp.c_tail != 0) prepare_tail_mask();
        perform_ker(jpp.ur_bc_tail, jpp.c_tail != 0);

        L(finish_label);
    } else if (jpp.c_tail != 0) {
        jmp(finish_label, T_NEAR);

        L(c_tail_processing_label);
        prepare_tail_mask();
        perform_ker(jpp.ur_bc, true);

        L(finish_label);
    }

    this->postamble();

    if (jpp.with_eltwise && postops_injector_)
        postops_injector_->prepare_table();

    if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) {
        align(64);
        L(idx_table);
        const uint16_t _idx[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7,
                8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15};
        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
            dw(_idx[i]);
    }
}

template struct jit_uni_pool_kernel<sse41>;
template struct jit_uni_pool_kernel<avx>;
template struct jit_uni_pool_kernel<avx2>;
template struct jit_uni_pool_kernel<avx2_vnni_2>;
template struct jit_uni_pool_kernel<avx512_core>;
template struct jit_uni_pool_kernel<avx512_core_fp16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s