/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <bitset>

#include "common/dnnl_thread.hpp"

#include "cpu/cpu_pooling_pd.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_uni_pool_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;
using namespace alg_kind;

#define GET_OFF(field) offsetof(jit_pool_call_s, field)
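// Note: GET_OFF(field) yields the byte offset of 'field' inside the run-time
// call-argument structure (jit_pool_call_s), so the generated code can read
// kernel arguments via ptr[reg_param + GET_OFF(...)].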

static bcast_set_t get_supported_bcast_strategies() {
    return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
            broadcasting_strategy_t::no_broadcast};
}

template <cpu_isa_t isa>
jit_uni_pool_kernel<isa>::~jit_uni_pool_kernel() = default;

template <cpu_isa_t isa>
jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
        const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
    , jpp(ajpp)
    , bf16_emu_(nullptr) {
    if (use_bf16_emulation())
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_reserv_4, bf16_emu_reserv_5);

    if (jpp.with_postops) {
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        static constexpr int sse41_single_block_size
                = cpu_isa_traits<sse41>::vlen / sizeof(float);
        size_t postop_tail = static_cast<size_t>(jpp.c_tail);
        const bool high_half_block_empty = isa == sse41
                && static_cast<size_t>(jpp.c_tail) > sse41_single_block_size;
        if (high_half_block_empty) postop_tail -= sse41_single_block_size;

        const binary_injector::rhs_arg_static_params_t rhs_sp {
                static_cast<std::size_t>(this->xmm4.getIdx()), this->r14,
                this->r15, this->r13, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(jpp.tag_kind == jit_memory_tag_kind_t::ncsp
                                ? jpp.tmp_md
                                : *dst_md),
                postop_tail, k_c_tail_mask, use_exact_tail_scalar_bcast};

        const binary_injector::static_params_t bsp {
                reg_param, get_supported_bcast_strategies(), rhs_sp};

        postops_injector_
                = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
                        this, jpp.post_ops, bsp);
    }
}

static status_t set_binary_postops_formats(
        post_ops_t &post_ops, const memory_desc_t *dst_md) {
    for (int idx = 0; idx < post_ops.len(); ++idx) {
        if (!post_ops.contain(primitive_kind::binary, idx)) continue;

        auto &src1_md = post_ops.entry_[idx].binary.src1_desc;
        const memory_desc_wrapper src1_mdw(src1_md);
        if (!src1_mdw.format_any()) {
            if (src1_mdw.is_blocking_desc())
                continue;
            else
                return status::unimplemented;
        }

        const memory_desc_wrapper dst_mdw(dst_md);
        assert(!dst_mdw.format_any());

        CHECK(memory_desc_init_by_blocking_desc(
                src1_md, dst_mdw.blocking_desc()));
    }

    return status::success;
}

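// init_conf() fills jit_pool_conf_t from the pooling descriptor: it derives
// the spatial/channel geometry, picks the memory-format path (blocked, nspc,
// or ncsp handled via a blocked f32 conversion), chooses the unrolling
// factors (ur, ur_bc), and books the scratchpad needed for the ncsp path.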
template <cpu_isa_t isa>
status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
        memory_tracking::registrar_t &scratchpad, primitive_attr_t &attr,
        const pooling_pd_t *ppd) {

    const auto &pd = *ppd->desc();
    const memory_desc_wrapper src_d(
            ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md());
    const memory_desc_wrapper dst_d(
            ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md());

    const int ndims = src_d.ndims();

    jpp.nthr = dnnl_get_max_threads();
    jpp.is_training = pd.prop_kind == prop_kind::forward_training;
    jpp.is_backward = pd.prop_kind == prop_kind::backward_data;

    jpp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jpp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jpp.iw = src_d.dims()[ndims - 1];
    jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jpp.ow = dst_d.dims()[ndims - 1];
    jpp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];

    const bool is_avx512 = is_superset(isa, avx512_core);
    jpp.ndims = ndims;
    jpp.mb = src_d.dims()[0];
    jpp.c_without_padding = src_d.dims()[1];
    jpp.c_block = is_avx512 ? 16 : 8;

    jpp.alg = pd.alg_kind;
    jpp.tmp_md = memory_desc_t();

    jpp.is_bf16 = (src_d.data_type() == data_type::bf16
            && dst_d.data_type() == data_type::bf16);
    jpp.is_f16 = (src_d.data_type() == data_type::f16
            && dst_d.data_type() == data_type::f16);

    using namespace format_tag;

    const auto blocked_fmt_tag = is_avx512
            ? utils::pick(ndims - 3, nCw16c, nChw16c, nCdhw16c)
            : utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);

    // src_d.data_type() is equal to dst_d.data_type(). This is checked in init
    auto ncsp_fmt_tag = format_tag::undef;

    const unsigned int L3_cache_size_per_core
            = platform::get_per_core_cache_size(3);
    const size_t block_size
            = ((size_t)jpp.id * jpp.ih * jpp.iw + jpp.od * jpp.oh * jpp.ow)
            * jpp.c_block * types::data_type_size(src_d.data_type());

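    // Heuristic for taking the ncsp (plain) path: one c_block slice of input
    // plus output (block_size) should fit in the per-core L3 cache, with the
    // bf16/f16 cases admitted more liberally, as encoded in the conditions
    // below.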
    const bool forward_ncsp_allowed = !jpp.is_backward
            && jpp.c_without_padding > 3
            && ((jpp.ih > 1 && jpp.iw > 1
                        && block_size <= L3_cache_size_per_core)
                    || utils::one_of(src_d.data_type(), data_type::bf16,
                            data_type::f16));

    const bool backward_ncsp_allowed = jpp.is_backward
            && ((jpp.ih > 1 && jpp.iw > 1 && jpp.c_without_padding > 1
                        && block_size <= L3_cache_size_per_core)
                    || (utils::one_of(src_d.data_type(), data_type::bf16,
                                data_type::f16)
                            && !(jpp.alg == pooling_max
                                    && block_size > L3_cache_size_per_core)));

    ncsp_fmt_tag = ((forward_ncsp_allowed || backward_ncsp_allowed) && is_avx512
                           && ndims <= 5)
            ? utils::pick(ndims - 3, ncw, nchw, ncdhw)
            : format_tag::undef;

    const auto nspc_fmt_tag = (ndims <= 5)
            ? utils::pick(ndims - 3, nwc, nhwc, ndhwc)
            : format_tag::undef;

    const auto fmt_tag = src_d.matches_one_of_tag(
            blocked_fmt_tag, ncsp_fmt_tag, nspc_fmt_tag);

    if (!dst_d.matches_tag(fmt_tag)) return status::unimplemented;

    if (!post_ops_ok(jpp, attr, dst_d)) return status::unimplemented;

    if (fmt_tag == ncsp_fmt_tag) {
        // transform input to blocked f32, call f32 jit, transform result to
        // plain output
        jpp.is_bf16 = false;
        jpp.is_f16 = false;
        jpp.dt_size = types::data_type_size(data_type::f32);
        jpp.tag_kind = jit_memory_tag_kind_t::ncsp;

        // used to initialize binary post-ops
        if (ppd->is_fwd() && jpp.with_binary) {
            CHECK(memory_desc_init_by_tag(jpp.tmp_md, ndims, dst_d.md_->dims,
                    data_type::f32, blocked_fmt_tag));
        }
    } else {
        jpp.is_bf16 = (src_d.data_type() == data_type::bf16
                && dst_d.data_type() == data_type::bf16);
        jpp.is_f16 = (src_d.data_type() == data_type::f16
                && dst_d.data_type() == data_type::f16);
        jpp.dt_size = types::data_type_size(src_d.data_type());
        jpp.tag_kind = (fmt_tag == nspc_fmt_tag)
                ? jit_memory_tag_kind_t::nspc
                : jit_memory_tag_kind_t::blocked;
    }

    if (ppd->is_fwd() && jpp.with_binary) {
        CHECK(set_binary_postops_formats(attr.post_ops_,
                jpp.tag_kind == jit_memory_tag_kind_t::ncsp ? &jpp.tmp_md
                                                            : dst_d.md_));
    }

    jpp.isa = (jpp.is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16
                                                         : isa;

    const bool args_ok = true && mayiuse(isa) && (fmt_tag != format_tag::undef)
            && IMPLICATION(jpp.is_bf16,
                    utils::one_of(jpp.isa, avx512_core_bf16, avx512_core,
                            avx2_vnni_2))
            && IMPLICATION(jpp.is_f16,
                    utils::one_of(jpp.isa, avx512_core_fp16, avx2_vnni_2))
            && utils::one_of(pd.alg_kind, pooling_max,
                    pooling_avg_include_padding, pooling_avg_exclude_padding);
    if (!args_ok) return status::unimplemented;

    const bool is_xf16_avx2_vnni_2
            = (jpp.is_bf16 || jpp.is_f16) && isa == avx2_vnni_2;
    // note: avx2_vnni_2 only supports nxc format
    if (!IMPLICATION(is_xf16_avx2_vnni_2,
                jpp.tag_kind == jit_memory_tag_kind_t::nspc))
        return status::unimplemented;

    // note: avx2_vnni_2 only supports FWD direction
    if (!IMPLICATION(is_xf16_avx2_vnni_2, !jpp.is_backward))
        return status::unimplemented;

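    // Channel blocking: jpp.c is the (possibly padded) channel count the
    // kernel iterates over, nb_c the number of c_block chunks, c_tail the
    // remainder channels in the last chunk, and is_c_padded tells whether the
    // blocked layout already provides zero padding up to a full block.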
    jpp.c = jpp.tag_kind == jit_memory_tag_kind_t::blocked
            ? utils::rnd_up(jpp.c_without_padding, jpp.c_block)
            : jpp.c_without_padding;
    if (jpp.tag_kind == jit_memory_tag_kind_t::blocked)
        assert(src_d.padded_dims()[1] == jpp.c);
    jpp.nb_c = utils::div_up(jpp.c, jpp.c_block);
    jpp.c_tail = jpp.c_without_padding % jpp.c_block;
    jpp.is_c_padded = jpp.tag_kind == jit_memory_tag_kind_t::blocked
            && src_d.padded_dims()[1] != jpp.c_without_padding;

    jpp.stride_d = (ndims == 5) ? pd.strides[0] : 1;
    jpp.stride_h = (ndims == 3) ? 1 : pd.strides[ndims - 4];
    jpp.stride_w = pd.strides[ndims - 3];
    jpp.kd = (ndims == 5) ? pd.kernel[0] : 1;
    jpp.kh = (ndims == 3) ? 1 : pd.kernel[ndims - 4];
    jpp.kw = pd.kernel[ndims - 3];

    jpp.f_pad = (ndims == 5) ? pd.padding[0][0] : 0;
    jpp.t_pad = (ndims == 3) ? 0 : pd.padding[0][ndims - 4];
    jpp.l_pad = pd.padding[0][ndims - 3];

    const int back_pad = calculate_end_padding(
            jpp.f_pad, jpp.od, jpp.id, jpp.stride_d, jpp.kd);
    const int bottom_pad = calculate_end_padding(
            jpp.t_pad, jpp.oh, jpp.ih, jpp.stride_h, jpp.kh);
    const int right_pad = calculate_end_padding(
            jpp.l_pad, jpp.ow, jpp.iw, jpp.stride_w, jpp.kw);

    if (jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw
            || back_pad >= jpp.kd || bottom_pad >= jpp.kh
            || right_pad >= jpp.kw)
        return status::unimplemented;

    jpp.ind_dt = ppd->workspace_md() ? ppd->workspace_md()->data_type
                                     : data_type::undef;

    jpp.simple_alg = jpp.is_training
            || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d);

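    // jpp.ur is the unroll factor over the output width; the constants below
    // are register-budget heuristics per ISA (roughly, how many accumulators
    // fit next to the auxiliary registers each variant needs).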
    jpp.ur = 0;
    if (jpp.alg == pooling_max) {
        jpp.ur = is_avx512 ? 16 : 4;

        if (utils::one_of(isa, avx, avx2, avx2_vnni_2) && jpp.c_tail > 0)
            // Additional register needed for tail mask
            jpp.ur -= 1;

        if (jpp.is_training)
            jpp.ur = is_avx512 ? 9 : 3;
        else if (jpp.is_backward)
            jpp.ur = is_avx512 ? 6 : 3;
    } else {
        if (jpp.is_backward)
            jpp.ur = is_avx512 ? 12 : 6;
        else
            jpp.ur = is_avx512 ? 24 : 12;
    }
    if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) {
        jpp.ur = (!isa_has_bf16(jpp.isa))
                ? jpp.ur - 4 // Free registers for AVX512 emulation
                : jpp.ur - 1; // Free register for cvt from bf16/f16 to f32
    }

    // select jpp.ur_bc
    if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
        auto min_ur_w = nstl::max(1, utils::div_up(jpp.l_pad, jpp.stride_w));
        int min_ur_w1 = utils::div_up(right_pad, jpp.stride_w);
        if (min_ur_w < min_ur_w1) { min_ur_w = min_ur_w1; }
        jpp.ur_bc = nstl::min(jpp.nb_c, nstl::max(1, jpp.ur / min_ur_w));
        // Take threading into account - there must be enough work for
        // parallelization.
        float best_eff = 0;
        for (int ur_bc = jpp.ur_bc; ur_bc > 0; ur_bc--) {

            const auto nb2_c = utils::div_up(jpp.nb_c, ur_bc);
            auto work = jpp.is_backward
                    ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1)
                    : (ndims == 5 ? jpp.od : jpp.oh);
            work *= jpp.mb * nb2_c;
            auto eff = (float)work / utils::rnd_up(work, jpp.nthr);
            if (eff > best_eff) {

                best_eff = eff;
                jpp.ur_bc = ur_bc;
            }
            if (eff > 0.9f) break; // Heuristic threshold
        }

        // Take cache reuse after zeroing on backward into account.
        if (jpp.is_backward && ndims < 5) {
            const int L2 = platform::get_per_core_cache_size(2)
                    / sizeof(jpp.dt_size);
            int ur_bc = nstl::max(1, L2 / (jpp.kh * jpp.iw * jpp.c_block));
            jpp.ur_bc = nstl::min(jpp.ur_bc, ur_bc);
        }

        jpp.ur_bc_tail = jpp.nb_c % jpp.ur_bc;
    } else {
        jpp.ur_bc = 1;
        jpp.ur_bc_tail = 0;
    }

    // scratchpad for c_block slice of input and/or output
    using namespace memory_tracking::names;
    const int nscr = nstl::min(dnnl_get_max_threads(), jpp.mb * jpp.nb_c);
    if (jpp.tag_kind == jit_memory_tag_kind_t::ncsp) {
        scratchpad.book(key_pool_src_plain2blocked_cvt,
                static_cast<size_t>(jpp.c_block) * jpp.id * jpp.ih * jpp.iw
                        * nscr,
                jpp.dt_size);
        scratchpad.book(key_pool_dst_plain2blocked_cvt,
                static_cast<size_t>(jpp.c_block) * jpp.od * jpp.oh * jpp.ow
                        * nscr,
                jpp.dt_size);
        scratchpad.book<uint32_t>(key_pool_ind_plain2blocked_cvt,
                static_cast<size_t>(jpp.c_block) * jpp.od * jpp.oh * jpp.ow
                        * nscr);
    }

    jpp.post_ops = attr.post_ops_;

    return status::success;
}

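// Maps a (purpose, block, output-column) triple to a flat vector-register
// index: 'shift' selects the register group (accumulator, input, index, ...),
// 'bc' the channel block within the tile and 'j' the unrolled output column.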
static int reg_ind(int shift, int bc, int j, int ur_bc, int ur_w) noexcept {
    return shift * ur_bc * ur_w + bc * ur_w + j;
}

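// Prepares the channel-tail mask: an opmask with the low c_tail bits set on
// AVX512, or a vector mask loaded from a static 8-element pattern for the
// AVX/AVX2 family (used later with vmaskmovps/vblendvps).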
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::prepare_tail_mask() {
    if (is_superset(isa, avx512_core)) {
        size_t c_tail_mask = (1ULL << jpp.c_tail) - 1ULL;
        mov(tmp_gpr.cvt32(), c_tail_mask);
        kmovw(k_c_tail_mask, tmp_gpr.cvt32());
    } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
        constexpr int max_words_in_ymm = 8;

        // for 'avx2_vnni_2' mask works with 2 x xf16 elements,
        // in case of 'c_tail % 2 != 0' load/store an additional word
        // for the remaining element.
        auto dt_elem_div = isa == avx2_vnni_2 ? 2 : 1;
        auto mask_offset = max_words_in_ymm - (jpp.c_tail / dt_elem_div);
        auto mask_register
                = isa == avx2_vnni_2 ? xmm_c_tail_mask : vmm_c_tail_mask;
        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
                0, 0, 0, 0, 0, 0, 0};
        mov(tmp_gpr, reinterpret_cast<size_t>(&mask[mask_offset]));
        vmovups(mask_register, ptr[tmp_gpr]);
    }
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::put_one_in_vmm() {
    mov(tmp_gpr, 1);
    uni_broadcast_reg_val(tmp_gpr.getIdx(), vmm_one.getIdx());
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::uni_broadcast_reg_val(
        const int reg_idx, const int vmm_idx) {
    uni_vmovq(Xmm(vmm_idx), reg64_t(reg_idx));
    uni_vpbroadcastd(Vmm(vmm_idx), Xmm(vmm_idx));
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::push_vmm_val(const int idx) {
    Vmm val_to_store(idx);
    sub(rsp, val_to_store.getBit());
    uni_vmovups(ptr[rsp], val_to_store);
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::pop_vmm_val(const int idx) {
    Vmm val_to_load(idx);
    uni_vmovups(val_to_load, ptr[rsp]);
    add(rsp, val_to_load.getBit());
}

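// load() brings one vector of source data into Vmm(idx), converting bf16/f16
// to f32 on the fly and applying the channel-tail mask when the tail is not
// zero padded. The avx2_vnni_2 and sse41 specializations below handle the
// same cases with the narrower registers available on those ISAs.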
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::load(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (jpp.is_bf16) {
        /*TODO: maybe use vpmovzxwd + vpslld,
         * in order to free up vmm_idx() register */
        if (is_c_tail_proccessing && !jpp.is_c_padded) {
            Vmm vmm_to_load = Vmm(idx) | k_c_tail_mask | T_z;
            vpmovzxwd(vmm_to_load, ptr[reg_ptr + offset]);
            vpslld(vmm_to_load, vmm_to_load, 16);
        } else {
            vmovups(Ymm(idx), ptr[reg_ptr + offset]);
            vpermw(Vmm(idx) | k_mask_cvt | T_z, vmm_idx(), Vmm(idx));
        }
    } else if (jpp.is_f16) {
        Vmm vmm_to_load = is_c_tail_proccessing && !jpp.is_c_padded
                ? Vmm(idx) | k_c_tail_mask | T_z
                : Vmm(idx);
        vcvtph2psx(vmm_to_load, ptr[reg_ptr + offset]);
    } else {
        if (is_c_tail_proccessing && !jpp.is_c_padded) {
            if (isa == avx || isa == avx2) {
                vmaskmovps(Vmm(idx), vmm_c_tail_mask, ptr[reg_ptr + offset]);
            } else {
                vmovups(Zmm(idx) | k_c_tail_mask | T_z, ptr[reg_ptr + offset]);
            }
        } else {
            uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]);
        }
    }
}

template <>
inline void jit_uni_pool_kernel<avx2_vnni_2>::load(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (is_c_tail_proccessing) {
        vmaskmovps(Xmm(idx), xmm_c_tail_mask, ptr[reg_ptr + offset]);
        if (jpp.c_tail % 2 != 0) {
            const int tail_pos = jpp.c_tail - 1;
            auto word_addr
                    = ptr[reg_ptr + offset + tail_pos * sizeof(bfloat16_t)];
            vpinsrw(Xmm(idx), Xmm(idx), word_addr, tail_pos);
        }
    }
    if (jpp.is_bf16) {
        if (is_c_tail_proccessing)
            vpmovzxwd(Ymm(idx), Xmm(idx));
        else
            vpmovzxwd(Ymm(idx), ptr[reg_ptr + offset]);
        vpslld(Ymm(idx), Ymm(idx), 16);
    } else if (jpp.is_f16) {
        if (is_c_tail_proccessing)
            vcvtph2ps(Ymm(idx), Xmm(idx));
        else
            vcvtph2ps(Ymm(idx), ptr[reg_ptr + offset]);
    } else
        assert(!"invalid data type");
}

template <>
inline void jit_uni_pool_kernel<sse41>::load(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (is_c_tail_proccessing && !jpp.is_c_padded) {
        for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
            pinsrd(Xmm(idx), ptr[reg_ptr + offset + i * jpp.dt_size], i);
    } else
        uni_vmovups(Vmm(idx), ptr[reg_ptr + offset]);
}

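// store() is the mirror of load(): it writes Vmm(idx) back to memory,
// converting f32 results to bf16/f16 where needed, and either masks the
// channel tail or explicitly zeroes the padded part (needed when postops have
// touched the padded lanes).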
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::store(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (jpp.is_bf16 || jpp.is_f16) {
        if (is_c_tail_proccessing) {
            if (jpp.is_c_padded) {
                vmovdqu16(Ymm(idx) | k_c_tail_mask | T_z, Ymm(idx));
                vmovups(yword[reg_ptr + offset], Ymm(idx));
            } else
                vmovdqu16(ptr[reg_ptr + offset] | k_c_tail_mask, Ymm(idx));
        } else
            vmovups(yword[reg_ptr + offset], Ymm(idx));
    } else {
        if (is_c_tail_proccessing) {
            if (!jpp.is_c_padded) {
                if (isa == avx || isa == avx2)
                    vmaskmovps(
                            ptr[reg_ptr + offset], vmm_c_tail_mask, Vmm(idx));
                else
                    vmovups(ptr[reg_ptr + offset] | k_c_tail_mask, Zmm(idx));
            } else {
                if (jpp.with_postops) {
                    if (isa == avx || isa == avx2) {
                        uni_vxorps(ymm_tmp_1, ymm_tmp_1, ymm_tmp_1);
                        uni_vblendvps(
                                Vmm(idx), ymm_tmp_1, Vmm(idx), vmm_c_tail_mask);
                    } else
                        uni_vmovups(Vmm(idx) | k_c_tail_mask | T_z, Vmm(idx));
                }
                uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
            }
        } else
            uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
    }
}

template <>
inline void jit_uni_pool_kernel<avx2_vnni_2>::store(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (jpp.is_bf16 || jpp.is_f16) {
        if (is_c_tail_proccessing) {
            vmaskmovps(ptr[reg_ptr + offset], xmm_c_tail_mask, Xmm(idx));
            if (jpp.c_tail % 2 != 0) {
                const int tail_pos = jpp.c_tail - 1;
                auto word_addr = ptr[reg_ptr + offset + tail_pos * 2];
                vpextrw(word_addr, Xmm(idx), tail_pos);
            }
        } else
            vmovups(xword[reg_ptr + offset], Xmm(idx));
    } else
        assert(!"datatype not supported");
}

template <>
inline void jit_uni_pool_kernel<sse41>::store(const int idx,
        const reg64_t &reg_ptr, const int offset,
        const bool is_c_tail_proccessing) {
    if (is_c_tail_proccessing) {
        if (!jpp.is_c_padded) {
            for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
                pextrd(ptr[reg_ptr + offset + i * jpp.dt_size], Xmm(idx), i);
        } else {
            if (jpp.with_postops) {
                static constexpr auto xmm_half = 4;
                const auto tail_size = (jpp.c_without_padding > jpp.c_block)
                        ? jpp.c_without_padding % (jpp.c - jpp.c_block)
                        : jpp.c_without_padding;
                const auto tail_size_real = (tail_size >= xmm_half)
                        ? tail_size - xmm_half
                        : tail_size;
                uni_vxorps(xmm_tmp_1, xmm_tmp_1, xmm_tmp_1);
                if (tail_size <= xmm_half && sse_high_half) {
                    // just zero out upper half padding and don't write anything else
                    uni_vmovups(vmmword[reg_ptr + offset], xmm_tmp_1);
                    return;
                }

                if ((tail_size < xmm_half && !sse_high_half)
                        || (tail_size > xmm_half && sse_high_half)) {
                    std::bitset<8> tail_mask((1 << tail_size_real) - 1);
                    tail_mask.flip();
                    uni_vblendps(Vmm(idx), Vmm(idx), xmm_tmp_1,
                            tail_mask.to_ulong());
                }
            }
            uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
        }
    } else
        uni_vmovups(vmmword[reg_ptr + offset], Vmm(idx));
}

template <cpu_isa_t isa>
bool jit_uni_pool_kernel<isa>::post_ops_ok(jit_pool_conf_t &jpp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
    const auto &post_ops = attr.post_ops_;
    const auto &entries = post_ops.entry_;
    jpp.with_postops = false;
    jpp.with_eltwise = false;
    jpp.with_binary = false;

    if (!jpp.is_backward) {
        for (const auto &entry : entries) {
            if (entry.is_eltwise()) {
                const auto alg = entry.eltwise.alg;
                jpp.with_eltwise = eltwise_injector::is_supported(isa, alg);
            } else if (entry.is_binary()) {
                const bool is_bf16_ok = IMPLICATION(
                        entry.binary.src1_desc.data_type == data_type::bf16,
                        utils::one_of(isa, avx512_core, avx2_vnni_2));
                const bool is_f16_ok = IMPLICATION(
                        entry.binary.src1_desc.data_type == data_type::f16,
                        utils::one_of(isa, avx512_core_fp16, avx2_vnni_2));
                if (!(is_bf16_ok && is_f16_ok)) return false;

                jpp.with_binary = true;
            } else
                return false;
        }

        jpp.with_postops = jpp.with_eltwise || jpp.with_binary;
    }

    return binary_injector::binary_args_broadcast_supported(
            post_ops, dst_d, get_supported_bcast_strategies());
}

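// Applies the eltwise/binary post-op chain to the accumulator registers that
// hold the ur_bc x ur_w output tile. For binary post-ops the injector needs,
// per vmm, the output register and element offset it corresponds to (and
// whether it is a tail vector) so the rhs operand is broadcast correctly.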
template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::apply_postops(int ur_bc, int ur_w, int c_block,
        const std::function<bool(int, bool)> &is_tail_predicate) {
    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
    const int end_idx = vmm_idx_upper_bound() + 1;
    const int start_idx = end_idx - (ur_bc * ur_w);
    const bool sse41_postops_disabled
            = isa == sse41 && disable_postops_when_sse_high_half_processed_;

    if (jpp.with_binary && !sse41_postops_disabled) {

        const int c_off = (jpp.tag_kind == jit_memory_tag_kind_t::nspc)
                ? jpp.c
                : c_block;

        if (jpp.tag_kind == jit_memory_tag_kind_t::ncsp) {
            mov(tmp_gpr, reg_output);
            sub(tmp_gpr, ptr[reg_param + GET_OFF(dst)]);
            add(tmp_gpr, ptr[reg_param + GET_OFF(dst_po_helper)]);
        }

        for (int jj = 0; jj < ur_w; jj++) {
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto vmm_idx
                        = vreg(reg_ind(0, bci, jj, ur_bc, ur_w)).getIdx();

                const size_t output_offset
                        = jpp.dt_size * (jj * c_off + bci * c_block);

                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx,
                        jpp.tag_kind == jit_memory_tag_kind_t::ncsp
                                ? tmp_gpr
                                : reg_output);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, output_offset);
                if (is_tail_predicate
                        && is_tail_predicate(
                                bci, true /*process_with_postops*/)) {
                    rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
                }
            }
        }
    }
    postops_injector_->compute_vector_range(start_idx, end_idx, rhs_arg_params);
}

template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::maybe_recalculate_divisor(
        int jj, int ur_w, int pad_l, int pad_r, bool with_c_tail_proccessing) {
    if (jpp.alg == pooling_avg_exclude_padding) {
        int kw = jpp.kw;
        int stride_w = jpp.stride_w;

        int non_zero_kw = kw;
        non_zero_kw -= nstl::max(0, pad_l - jj * stride_w);
        non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj) * stride_w);

        if (non_zero_kw != prev_kw) {
            mov(tmp_gpr, float2int((float)non_zero_kw));
            uni_vmovq(xmm_tmp, tmp_gpr);
            uni_vbroadcastss(vmm_tmp, xmm_tmp);
            if (with_c_tail_proccessing
                    && (utils::one_of(isa, avx, avx2, avx2_vnni_2))) {
                push_vmm_val(vmm_c_tail_mask.getIdx());
                uni_broadcast_reg_val(
                        reg_ker_area_h.getIdx(), vmm_ker_area_h.getIdx());
            }
            uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h);
            if (with_c_tail_proccessing
                    && (utils::one_of(isa, avx, avx2, avx2_vnni_2))) {
                pop_vmm_val(vmm_c_tail_mask.getIdx());
            }
            prev_kw = non_zero_kw;
        }
    }
}

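// avg_step processes one ur_w x ur_bc tile for average pooling: it
// accumulates (or, on backward, distributes) values over the kd*kh*kw window
// and divides by the kernel area held in vmm_tmp, which is recomputed per
// output column for the exclude-padding variant.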
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int ur_bc, int pad_l,
        int pad_r, bool with_c_tail_proccessing) {

    auto iw = jpp.iw;
    auto kw = jpp.kw;
    auto stride_w = jpp.stride_w;
    auto c_block = jpp.c_block;
    auto dt_size = jpp.dt_size;
    const int c_off
            = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block;
    Label kd_label, kh_label;

    const auto is_tail_processing = [&](int bc,
                                            bool process_with_postops = false) {
        if (isa == sse41 && (!jpp.is_c_padded || process_with_postops)) {
            return with_c_tail_proccessing && bc == (ur_bc - 1)
                    && ((jpp.c_tail > (jpp.c_block / 2) && sse_high_half)
                            || (jpp.c_tail < (jpp.c_block / 2)
                                    && !sse_high_half));
        } else
            return with_c_tail_proccessing && bc == (ur_bc - 1);
    };

    for (int jj = 0; jj < ur_w; jj++) {
        if (jpp.is_backward)
            maybe_recalculate_divisor(
                    jj, ur_w, pad_l, pad_r, with_c_tail_proccessing);
        for (int bci = 0; bci < ur_bc; bci++) {
            const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
            auto accvr = vreg(accr_i);
            if (jpp.is_backward) {
                auto output_offset = dt_size * (jj * c_off + bci * c_block);
                load(accvr.getIdx(), reg_output, output_offset,
                        is_tail_processing(bci));
                uni_vdivps(accvr, accvr, vmm_tmp);
            } else {
                uni_vpxor(accvr, accvr, accvr);
            }
        }
    }

    if (jpp.simple_alg && jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                    - utils::div_up(
                            nstl::max(0, ki + pad_r - (kw - 1)), stride_w);

            for_(int jj = jj_start; jj < jj_end; jj++)
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto accvr = vreg(reg_ind(0, bci, jj, ur_bc, ur_w));
                const auto inpr_i = reg_ind(1, bci, jj, ur_bc, ur_w);
                auto inpvr = vreg(inpr_i);
                int aux_input_offset
                        = (ki + jj * stride_w - pad_l) * c_off + bci * c_block;
                if (aux_input_offset >= iw * c_off) continue;
                int input_offset = dt_size * aux_input_offset;
                if (jpp.is_backward) {
                    auto inpyr = yreg(inpr_i);
                    load(reg_idx(inpr_i), aux_reg_input, input_offset,
                            is_tail_processing(bci));
                    uni_vaddps(inpvr, inpvr, accvr);
                    if (jpp.is_bf16) {
                        if (!isa_has_bf16(jpp.isa))
                            bf16_emu_->vcvtneps2bf16(inpyr, zreg(inpr_i));
                        else
                            vcvtneps2bf16(inpyr, inpvr);
                    } else if (jpp.is_f16) {
                        vcvtps2ph(inpyr, inpvr, _op_mxcsr);
                    }
                    store(reg_idx(inpr_i), aux_reg_input, input_offset,
                            is_tail_processing(bci));
                } else {
                    if (jpp.is_bf16 || jpp.is_f16 || is_tail_processing(bci)
                            || (isa == sse41
                                    && c_off % (jpp.c_block / 2) != 0)) {
                        load(vmm_tmp_1.getIdx(), aux_reg_input, input_offset,
                                is_tail_processing(bci));

                        uni_vaddps(accvr, accvr, vmm_tmp_1);
                    } else {
                        uni_vaddps(accvr, accvr,
                                ptr[aux_reg_input + input_offset]);
                    }
                }
            }
        }
        add(aux_reg_input, jpp.dt_size * iw * c_off);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.simple_alg && jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_off);
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        pop(reg_output);
        pop(reg_input);
    }

    if (!jpp.is_backward) {
        for (int jj = 0; jj < ur_w; jj++) {
            maybe_recalculate_divisor(
                    jj, ur_w, pad_l, pad_r, with_c_tail_proccessing);
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
                const auto accvr = vreg(accr_i);
                uni_vdivps(accvr, accvr, vmm_tmp);
            }
        }

        if (jpp.with_postops)
            apply_postops(ur_bc, ur_w, c_block, is_tail_processing);

        for (int jj = 0; jj < ur_w; jj++) {
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
                const auto accvr = vreg(accr_i);
                const auto output_offset
                        = dt_size * (jj * c_off + bci * c_block);
                const auto accyr = yreg(accr_i);
                if (jpp.is_bf16) {
                    if (isa == avx2_vnni_2) {
                        auto accxr = xreg(accr_i);
                        vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
                    } else {
                        const auto acczr = zreg(accr_i);
                        if (!isa_has_bf16(jpp.isa))
                            bf16_emu_->vcvtneps2bf16(accyr, acczr);
                        else
                            vcvtneps2bf16(accyr, accvr);
                    }
                } else if (jpp.is_f16) {
                    if (isa == avx2_vnni_2) {
                        auto accxr = xreg(accr_i);
                        vcvtps2ph(accxr, accyr, _op_mxcsr);
                    } else
                        vcvtps2ph(accyr, accvr, _op_mxcsr);
                }
                store(reg_idx(accr_i), reg_output, output_offset,
                        is_tail_processing(bci));
            }
        }
    }
}

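// max_step_fwd processes one ur_w x ur_bc tile for forward max pooling:
// accumulators start at -FLT_MAX, each kernel position updates them with a
// compare-and-blend, and in training mode the winning kernel-position index
// is tracked in parallel and stored to the workspace (as u8 or as the
// workspace data type).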
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int ur_bc,
        int pad_l, int pad_r, bool with_c_tail_proccessing) {
    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    const int c_off
            = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block;
    Label kd_label, kh_label;

    auto is_tail_processing = [&](int bc, bool process_with_postops = false) {
        if (isa == sse41 && (!jpp.is_c_padded || process_with_postops)) {
            return with_c_tail_proccessing && bc == (ur_bc - 1)
                    && ((jpp.c_tail > (jpp.c_block / 2) && sse_high_half)
                            || (jpp.c_tail < (jpp.c_block / 2)
                                    && !sse_high_half));
        } else
            return with_c_tail_proccessing && bc == (ur_bc - 1);
    };

    mov(tmp_gpr, float2int(nstl::numeric_limits<float>::lowest()));
    uni_vmovq(xmm_tmp, tmp_gpr);
    uni_vbroadcastss(vmm_tmp, xmm_tmp);

    for_(int jj = 0; jj < ur_w; jj++)
    for (int bci = 0; bci < ur_bc; bci++) {
        const auto accvr = vreg(reg_ind(0, bci, jj, ur_bc, ur_w));
        uni_vmovups(accvr, vmm_tmp);
        if (jpp.is_training) {
            const auto indvr = vreg(reg_ind(2, bci, jj, ur_bc, ur_w));
            uni_vpxor(indvr, indvr, indvr);
        }
    }
    if (jpp.is_training) {
        uni_vmovq(xmm_tmp, reg_k_shift);
        uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
    }
    if (jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }
    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                    - utils::div_up(
                            nstl::max(0, ki + pad_r - (kw - 1)), stride_w);
            for_(int jj = jj_start; jj < jj_end; jj++)
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto accvr = vreg(reg_ind(0, bci, jj, ur_bc, ur_w));
                const auto inpr_i = reg_ind(1, bci, jj, ur_bc, ur_w);
                const auto inpvr = vreg(inpr_i);
                const auto indvr = vreg(reg_ind(2, bci, jj, ur_bc, ur_w));
                const auto cvtvr = vreg(reg_ind(3, bci, jj, ur_bc, ur_w));
                int aux_input_offset
                        = (ki + jj * stride_w - pad_l) * c_off + bci * c_block;
                if (aux_input_offset >= iw * c_off) continue;
                int input_offset = jpp.dt_size * aux_input_offset;
                load(reg_idx(inpr_i), aux_reg_input, input_offset,
                        is_tail_processing(bci));
                if (isa == sse41) {
                    movups(vmm_mask, accvr);
                    cmpps(vmm_mask, inpvr, _cmp_lt_os);
                    blendvps(accvr, inpvr);
                    if (jpp.is_training) blendvps(indvr, vmm_k_offset);
                } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
                    vcmpps(cvtvr, accvr, inpvr, _cmp_lt_os);
                    vblendvps(accvr, accvr, inpvr, cvtvr);
                    if (jpp.is_training)
                        vblendvps(indvr, indvr, vmm_k_offset, cvtvr);
                } else {
                    vcmpps(k_store_mask, accvr, inpvr, _cmp_lt_os);
                    vblendmps(accvr | k_store_mask, accvr, inpvr);
                    if (jpp.is_training)
                        vblendmps(indvr | k_store_mask, indvr, vmm_k_offset);
                }
            }
            if (jpp.is_training) {
                if (with_c_tail_proccessing
                        && (utils::one_of(isa, avx, avx2, avx2_vnni_2))) {
                    push_vmm_val(vmm_c_tail_mask.getIdx());
                    put_one_in_vmm();
                }

                if (isa == avx && !mayiuse(avx2))
                    avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
                else
                    uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);

                if (with_c_tail_proccessing
                        && (utils::one_of(isa, avx, avx2, avx2_vnni_2)))
                    pop_vmm_val(vmm_c_tail_mask.getIdx());
            }
        }
        add(aux_reg_input, jpp.dt_size * iw * c_off);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_off);
        if (jpp.is_training) {
            mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]);
            uni_vmovq(xmm_tmp, tmp_gpr);
            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
            if (isa == avx && !mayiuse(avx2)) {
                Xmm t(vmm_mask.getIdx());
                avx_vpadd1(vmm_k_offset, xmm_tmp, t);
            } else {
                uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
            }
        }

        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        pop(reg_output);
        pop(reg_input);
    }

    if (with_c_tail_proccessing && jpp.is_c_padded && isa == sse41)
        mov(tmp_gpr, 0); // needed zero to fill padded tail

    if (jpp.with_postops)
        apply_postops(ur_bc, ur_w, c_block, is_tail_processing);

    for_(int jj = 0; jj < ur_w; jj++)
    for (int bci = 0; bci < ur_bc; bci++) {
        const auto accr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
        const auto accvr = vreg(accr_i);
        const auto output_offset = jpp.dt_size * (jj * c_off + bci * c_block);
        auto accyr = yreg(accr_i);
        if (jpp.is_bf16) {
            if (isa == avx2_vnni_2) {
                auto accxr = xreg(accr_i);
                vcvtneps2bf16(accxr, accyr, Xbyak::VexEncoding);
            } else {
                auto acczr = zreg(accr_i);
                if (!isa_has_bf16(jpp.isa))
                    bf16_emu_->vcvtneps2bf16(accyr, acczr);
                else
                    vcvtneps2bf16(accyr, accvr);
            }
        } else if (jpp.is_f16) {
            if (isa == avx2_vnni_2) {
                auto accxr = xreg(accr_i);
                vcvtps2ph(accxr, accyr, _op_mxcsr);
            } else
                vcvtps2ph(accyr, accvr, _op_mxcsr);
        }
        store(reg_idx(accr_i), reg_output, output_offset,
                is_tail_processing(bci));

        if (jpp.is_training) {
            const size_t step_index = (jj * c_off + bci * c_block)
                    * types::data_type_size(jpp.ind_dt);

            const auto indr_i = reg_ind(2, bci, jj, ur_bc, ur_w);
            auto vr = vreg(indr_i);
            if (jpp.ind_dt == data_type::u8) {
                auto xr = xreg(indr_i);
                if (isa == sse41) {
                    for (int i = 0; i < (jpp.c_block / 2); ++i) {
                        if (is_tail_processing(bci)
                                && i + (sse_high_half ? (jpp.c_block / 2) : 0)
                                        >= jpp.c_tail) {
                            if (jpp.is_c_padded)
                                mov(ptr[reg_index + step_index + i],
                                        tmp_gpr.cvt8()); // fill padded tail with zeros
                            else
                                break; // tail end
                        } else {
                            // The bytes to store sit in the least significant
                            // 8 bits of each 32-bit element of the xmm, so
                            // bytes 0, 4, 8 and 12 of the xmm are stored.
                            pextrb(ptr[reg_index + step_index + i], xr, 4 * i);
                        }
                    }
                } else if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
                    auto yr = yreg(indr_i);
                    if (is_tail_processing(bci) && !jpp.is_c_padded) {
                        const int max_nr_of_vals
                                = jpp.c_tail > (jpp.c_block / 2)
                                ? (jpp.c_block / 2)
                                : jpp.c_tail;
                        for (int i = 0; i < max_nr_of_vals; ++i) {
                            // The bytes to store sit in the least significant
                            // 8 bits of each 32-bit element of the xmm, so
                            // bytes 0, 4, 8 and 12 of the xmm are stored.
                            vpextrb(ptr[reg_index + step_index + i], xr, 4 * i);
                        }

                        if (jpp.c_tail > (jpp.c_block / 2)) {
                            Xmm higher_128bits(vmm_mask.getIdx());
                            vextractf128(higher_128bits, yr, 1);
                            for (int i = 0; i < jpp.c_tail - (jpp.c_block / 2);
                                    ++i) {
                                // The bytes to store sit in the least
                                // significant 8 bits of each 32-bit element,
                                // so bytes 0, 4, 8 and 12 are stored.
                                vpextrb(ptr[reg_index + step_index
                                                + (jpp.c_block / 2) + i],
                                        higher_128bits, 4 * i);
                            }
                        }
                    } else {
                        if (is_tail_processing(bci)) {
                            assert(jpp.is_c_padded);
                            vandps(yr, yr, vmm_c_tail_mask);
                        }
                        if (jj == 0) {
                            vmovd(xmm_tmp, reg_shuf_mask);
                            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
                        }
                        if (mayiuse(avx2)) {
                            vpshufb(yr, yr, vmm_tmp);
                            vmovd(ptr[reg_index + step_index], xr);
                            vperm2i128(yr, yr, yr, 0x1u);
                            vmovd(ptr[reg_index + step_index
                                          + (jpp.c_block / 2)],
                                    xr);
                        } else {
                            Xmm t(vmm_mask.getIdx());
                            vextractf128(t, yr, 0);
                            vpshufb(t, t, xmm_tmp);
                            vmovd(ptr[reg_index + step_index], t);
                            vextractf128(t, yr, 1);
                            vpshufb(t, t,
                                    xmm_tmp); // ymm_tmp[255:128] == ymm_tmp[127:0]
                            vmovd(ptr[reg_index + step_index
                                          + (jpp.c_block / 2)],
                                    t);
                        }
                    }
                } else {
                    if (is_tail_processing(bci)) {
                        if (jpp.is_c_padded) {
                            knotw(k_c_tail_mask, k_c_tail_mask);
                            vpxord(vr | k_c_tail_mask, vr, vr);
                            knotw(k_c_tail_mask, k_c_tail_mask);
                            vpmovusdb(ptr[reg_index + step_index], vr);
                        } else
                            vpmovusdb(ptr[reg_index + step_index],
                                    vr | k_c_tail_mask);
                    } else {
                        vpmovusdb(ptr[reg_index + step_index], vr);
                    }
                }
            } else {
                store(vr.getIdx(), reg_index, step_index,
                        is_tail_processing(bci));
            }
        }
    }
}

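// max_step_bwd scatters diff_dst values back to the diff_src locations that
// won the forward max: the stored indices are compared against the running
// kernel-position offset (vmm_k_offset) and only the matching lanes are
// accumulated and written, using masked stores on each ISA.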
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int ur_bc,
        int pad_l, int pad_r, bool with_c_tail_proccessing) {

    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    const int c_off
            = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block;
    Label kd_label, kh_label;

    const auto is_tail_processing = [&](int bc) {
        if (isa == sse41) {
            return with_c_tail_proccessing && bc == (ur_bc - 1)
                    && ((jpp.c_tail > (jpp.c_block / 2) && sse_high_half)
                            || (jpp.c_tail < (jpp.c_block / 2)
                                    && !sse_high_half)
                            || (jpp.c_tail == (jpp.c_block / 2) && sse_high_half
                                    && jpp.is_c_padded));
        } else
            return with_c_tail_proccessing && bc == (ur_bc - 1);
    };

    for_(int jj = 0; jj < ur_w; jj++)
    for (int bci = 0; bci < ur_bc; bci++) {
        const auto outr_i = reg_ind(0, bci, jj, ur_bc, ur_w);
        auto out_offset = jpp.dt_size * (jj * c_off + bci * c_block);
        load(reg_idx(outr_i), reg_output, out_offset, is_tail_processing(bci));
        const size_t step_index = (jj * c_off + bci * c_block)
                * types::data_type_size(jpp.ind_dt);

        const auto indr_i = reg_ind(1, bci, jj, ur_bc, ur_w);
        auto indvr = vreg(indr_i);
        if (jpp.ind_dt == data_type::u8) {
            auto indxr = xreg(indr_i);
            if (isa == sse41) {
                if (is_tail_processing(bci) && !jpp.is_c_padded) {
                    for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2); i++)
                        pinsrb(indxr, ptr[reg_index + step_index + i], i);
                } else {
                    movd(indxr, ptr[reg_index + step_index]);
                }
                pmovzxbd(indvr, indxr);
            } else if (isa == avx || isa == avx2) {
                if (is_tail_processing(bci) && !jpp.is_c_padded) {
                    for (int i = 0; i < jpp.c_tail; i++)
                        vpinsrb(indxr, indxr, ptr[reg_index + step_index + i],
                                i);
                } else {
                    vmovq(indxr, ptr[reg_index + step_index]);
                }
                if (!mayiuse(avx2)) {
                    avx_pmovzxbd(indvr, indxr, xmm_tmp);
                } else {
                    vpmovzxbd(indvr, indxr);
                }
            } else {
                if (is_tail_processing(bci) && !jpp.is_c_padded) {
                    vpmovzxbd(indvr | k_c_tail_mask | T_z,
                            ptr[reg_index + step_index]);
                } else {
                    vpmovzxbd(indvr, ptr[reg_index + step_index]);
                }
            }
        } else {
            load(indvr.getIdx(), reg_index, step_index,
                    is_tail_processing(bci));
        }
    }
    uni_vmovq(xmm_tmp, reg_k_shift);
    uni_vpbroadcastd(vmm_k_offset, xmm_tmp);

    if (jpp.simple_alg && jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        if (isa == sse41) {
            // Save rdi since it is used in maskmovdqu
            assert(dst_ptr == rdi);
            push(dst_ptr);
        }
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                    - utils::div_up(
                            nstl::max(0, ki + pad_r - (kw - 1)), stride_w);
            for_(int jj = jj_start; jj < jj_end; jj++)
            for (int bci = 0; bci < ur_bc; bci++) {
                const auto outvr = vreg(reg_ind(0, bci, jj, ur_bc, ur_w));
                const auto indvr = vreg(reg_ind(1, bci, jj, ur_bc, ur_w));
                const auto inpr_i = reg_ind(2, bci, jj, ur_bc, ur_w);
                const auto inpvr = vreg(inpr_i);
                const auto cvtvr = vreg(reg_ind(3, bci, jj, ur_bc, ur_w));
                int aux_inp_offset
                        = (ki + jj * stride_w - pad_l) * c_off + bci * c_block;
                if (aux_inp_offset >= iw * c_off) continue;
                int inp_offset = jpp.dt_size * aux_inp_offset;
                load(reg_idx(inpr_i), aux_reg_input, inp_offset,
                        is_tail_processing(bci));
                if (isa == sse41) {
                    mov(dst_ptr, aux_reg_input);
                    add(dst_ptr, inp_offset);

                    movups(cvtvr, indvr);
                    pcmpeqd(cvtvr, vmm_k_offset);
                    addps(inpvr, outvr);
                    if (is_tail_processing(bci)) {
                        Label end_cond_move[4];
                        for (int i = 0; i < jpp.c_tail % (jpp.c_block / 2);
                                i++) {
                            pextrd(tmp_gpr.cvt32(), cvtvr, i);
                            cmp(tmp_gpr, 0);
                            je(end_cond_move[i], T_NEAR);
                            pextrd(ptr[dst_ptr + i * jpp.dt_size], inpvr, i);
                            L(end_cond_move[i]);
                        }
                    } else
                        maskmovdqu(inpvr, cvtvr);
                } else if (isa == avx || isa == avx2) {
                    if (mayiuse(avx2)) {
                        vpcmpeqd(cvtvr, indvr, vmm_k_offset);
                    } else {
                        avx_pcmpeqd(cvtvr, indvr, vmm_k_offset, xmm_tmp);
                    }
                    vaddps(inpvr, inpvr, outvr);
                    if (is_tail_processing(bci)) {
                        vandps(cvtvr, cvtvr, vmm_c_tail_mask);
                    }
                    vmaskmovps(
                            vmmword[aux_reg_input + inp_offset], cvtvr, inpvr);
                } else {
                    auto indzr = zreg(inpr_i);
                    auto indyr = yreg(inpr_i);
                    vpcmpeqd(k_store_mask, indvr, vmm_k_offset);
                    vblendmps(vmm_tmp | k_store_mask | T_z, outvr, outvr);
                    vaddps(inpvr, inpvr, vmm_tmp);
                    if (jpp.is_bf16) {
                        if (!isa_has_bf16(jpp.isa))
                            bf16_emu_->vcvtneps2bf16(indyr, indzr);
                        else
                            vcvtneps2bf16(indyr, inpvr);
                    } else if (jpp.is_f16) {
                        vcvtps2ph(indyr, inpvr, _op_mxcsr);
                    }
                    store(inpvr.getIdx(), aux_reg_input, inp_offset,
                            is_tail_processing(bci));
                }
            }

            if (with_c_tail_proccessing && (isa == avx || isa == avx2)) {
                push_vmm_val(vmm_c_tail_mask.getIdx());
                put_one_in_vmm();
            }

            if (isa == avx && !mayiuse(avx2)) {
                avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
            } else {
                uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
            }

            if (with_c_tail_proccessing && (isa == avx || isa == avx2))
                pop_vmm_val(vmm_c_tail_mask.getIdx());
        }
        add(aux_reg_input, jpp.dt_size * iw * c_off);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }
    if (jpp.simple_alg && jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_off);

        mov(tmp_gpr, reg_kd_pad_shift);
        uni_vmovq(xmm_tmp, tmp_gpr);
        uni_vpbroadcastd(vmm_tmp, xmm_tmp);
        if (isa == avx && !mayiuse(avx2)) {
            Xmm t(vmm_mask.getIdx());
            avx_vpadd1(vmm_k_offset, vmm_tmp, t);
        } else {
            uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
        }

        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        if (isa == sse41) {
            // Restore rdi since it is used in maskmovdqu
            assert(dst_ptr == rdi);
            pop(dst_ptr);
        }
        pop(reg_output);
        pop(reg_input);
    }
}

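// zero_diff_src zeroes the diff_src rows to be accumulated into before the
// backward max-pooling pass; the zero_id/zero_ih/zero_ptr arguments come from
// the caller through the call-argument structure.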
template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::zero_diff_src(
        int ur_bc, bool with_c_tail_proccessing) {
    const int c_off = (jpp.tag_kind == jit_memory_tag_kind_t::nspc)
            ? jpp.c
            : jpp.c_block;

    Label l_skip, l_ih_loop, l_id_loop;

    auto is_tail_processing = [&](int bc) {
        return with_c_tail_proccessing && bc == (ur_bc - 1);
    };

    mov(reg_zero_id, ptr[reg_param + GET_OFF(zero_id)]);
    cmp(reg_zero_id, 0);
    jz(l_skip, T_NEAR);

    mov(reg_zero_ih, ptr[reg_param + GET_OFF(zero_ih)]);
    cmp(reg_zero_ih, 0);
    jz(l_skip, T_NEAR);

    mov(reg_zero_ptr, ptr[reg_param + GET_OFF(zero_ptr)]);

    Vmm vzero = vmm_tmp;
    uni_vpxor(vzero, vzero, vzero);

    const int width_size = jpp.iw * c_off * jpp.dt_size;

    auto aux_reg_zero_ptr = tmp_gpr;

    L(l_id_loop);
    {
        mov(aux_reg_zero_ptr, reg_zero_ptr);
        mov(aux_reg_zero_ih, reg_zero_ih);
        L(l_ih_loop);
        {
            const auto vlen = cpu_isa_traits<isa>::vlen;
            const int step = c_off * jpp.dt_size;

            // TODO: this may generate a large amount of code
            for_(int i = 0; i < width_size; i += step)
            for (int bci = 0; bci < ur_bc; bci++) {
                const int offs = i + bci * jpp.c_block * jpp.dt_size;
                if (isa == sse41) {
                    bool is_needed_c_tail_processing = false;
                    if (is_tail_processing(bci)
                            && jpp.c_tail < (jpp.c_block / 2))
                        is_needed_c_tail_processing = true;
                    store(vzero.getIdx(), reg_zero_ptr, offs,
                            is_needed_c_tail_processing);
                    if (!is_tail_processing(bci)
                            || (is_tail_processing(bci)
                                    && (jpp.is_c_padded
                                            || jpp.c_tail
                                                    > (jpp.c_block / 2)))) {
                        store(vzero.getIdx(), reg_zero_ptr, offs + vlen,
                                is_tail_processing(bci));
                    }

                } else {
                    store(vzero.getIdx(), reg_zero_ptr, offs,
                            is_tail_processing(bci));
                }
            }
            add(reg_zero_ptr, width_size);
            dec(aux_reg_zero_ih);
            jnz(l_ih_loop, T_NEAR);
        }
        mov(reg_zero_ptr, aux_reg_zero_ptr);
        add(reg_zero_ptr, width_size * jpp.ih);
        dec(reg_zero_id);
        jnz(l_id_loop, T_NEAR);
    }

    L(l_skip);
}

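// generate() emits the kernel body: it loads the call arguments, prepares the
// bf16/f16 conversion masks and tables when needed, and emits the output
// width loop through perform_ker(), once for the regular ur_bc and, when
// needed, once more for the ur_bc tail / channel tail variants.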
template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::generate() {

    this->preamble();

    Label idx_table;

    int ow = jpp.ow;
    int iw = jpp.iw;
    int kw = jpp.kw;
    int kh = jpp.kh;
    int c_block = jpp.c_block;
    int stride_w = jpp.stride_w;
    int l_pad = jpp.l_pad;
    const int c_off
            = (jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c : c_block;

    int vlen = cpu_isa_traits<isa>::vlen;

#if defined(_WIN32)
    // Always mimic the Unix ABI (see the note about maskmovdqu in the header
    // file).
    xor_(rdi, rcx);
    xor_(rcx, rdi);
    xor_(rdi, rcx);
#endif
    if (use_bf16_emulation()) bf16_emu_->init_vcvtneps2bf16();

    mov(reg_input, ptr[reg_param + GET_OFF(src)]);
    mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
    if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
        mov(reg_index, ptr[reg_param + GET_OFF(indices)]);
    mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]);
    mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]);
    mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]);
    mov(reg_nbc, ptr[reg_param + GET_OFF(ur_bc)]);

    if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) {
        mov(tmp_gpr.cvt32(), 0xAAAAAAAA);
        kmovd(k_mask_cvt, tmp_gpr.cvt32());

        mov(tmp_gpr, idx_table);
        vmovups(vmm_idx(), ptr[tmp_gpr]);
    }

    auto process_oi = [&](int ur_w, int ur_bc, int lpad, int rpad,
                              bool with_c_tail_proccessing,
                              bool inc_reg = true) {
        step(ur_w, ur_bc, lpad, rpad, with_c_tail_proccessing);

        if (isa == sse41) {
            if (with_c_tail_proccessing && jpp.c_tail <= (jpp.c_block / 2)) {

                // In nspc format, when processing the channel tail and the
                // tail is 4 elements or fewer, the last high half-block does
                // not exist and does not have to be processed.
                if (!jpp.is_c_padded) ur_bc -= 1;
                /*
                 * When processing the channel tail with c_tail of 4 or fewer
                 * elements, applying postops never makes sense. For the
                 * blocked format it can overwrite the zero padding or
                 * segfault, because the elements corresponding to the padded
                 * piece do not exist in the binary postops arg1 tensor
                 * (nchw format) with the per_oc bcast strategy.
                 */
                disable_postops_when_sse_high_half_processed_
                        = jpp.tag_kind == jit_memory_tag_kind_t::blocked;
            }
            sse_high_half = true;
            step_high_half(ur_w, ur_bc, lpad, rpad, with_c_tail_proccessing);
            sse_high_half = false;
            disable_postops_when_sse_high_half_processed_ = false;
        }

        if (!inc_reg) return;

        auto dt_size = jpp.dt_size;
        auto shift = (isa == sse41) ? vlen : 0;
        add(reg_input,
                dt_size * nstl::max(0, ur_w * stride_w - lpad) * c_off - shift);
        add(reg_output, dt_size * ur_w * c_off - shift);
        if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
            auto ishift = (isa == sse41) ? jpp.c_block / 2 : 0;
            auto ind_dt_size = types::data_type_size(jpp.ind_dt);
            add(reg_index, (ur_w * c_off - ishift) * ind_dt_size);
        }
    };

    auto perform_ker = [&](int ur_bc, bool with_c_tail_processing) {
        prev_kw = 0; // re-initialize this value for avg steps

        if (jpp.is_backward && jpp.simple_alg)
            zero_diff_src(ur_bc, with_c_tail_processing);

        if (jpp.alg == pooling_avg_exclude_padding
                && (!with_c_tail_processing
                        || (!utils::one_of(isa, avx, avx2, avx2_vnni_2)))) {
            // vmm_ker_area_h and vmm_c_tail_mask share one register, so when
            // vmm_c_tail_mask is in use, vmm_ker_area_h has to be loaded
            // exactly where it is needed, with the vmm_c_tail_mask contents
            // saved first.
            uni_broadcast_reg_val(
                    reg_ker_area_h.getIdx(), vmm_ker_area_h.getIdx());
        }

        if (jpp.alg == pooling_avg_include_padding) {
            mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd)));
            uni_vmovq(xmm_tmp, tmp_gpr);
            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
        }

        if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
            if (!with_c_tail_processing
                    || (!utils::one_of(isa, avx, avx2, avx2_vnni_2))) {
                // The same situation as above (vmm_ker_area_h).
                put_one_in_vmm();
            }

            if (utils::one_of(isa, avx, avx2, avx2_vnni_2)) {
                mov(reg_shuf_mask, 0x0c080400);
            }
        }

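        // The output width is split into three regions that are unrolled
        // separately: leading iterations that still see left padding, a
        // middle region with no padding (emitted as a run-time loop when it
        // repeats), and trailing iterations that see right padding.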
        const int ur_w = nstl::min(jpp.ow, jpp.ur / jpp.ur_bc);
        const int n_oi_iterations = utils::div_up(ow, ur_w);
        const int ur_stride_w = ur_w * stride_w;
        const int l_pad_iterations
                = nstl::min(n_oi_iterations, utils::div_up(l_pad, ur_stride_w));

        for (int i = 0; i < l_pad_iterations; ++i) {
            const int ow_s = i * ur_w;
            const int ow_e = nstl::min(ow, ow_s + ur_w);
            const int cur_l_pad = l_pad - i * ur_stride_w;
            const int cur_r_pad = nstl::max(
                    0, calculate_end_padding(l_pad, ow_e, iw, stride_w, kw));
            const int cur_ur_w = ow_e - ow_s;
            process_oi(cur_ur_w, ur_bc, cur_l_pad, cur_r_pad,
                    with_c_tail_processing);
        }

        const int rem_n_oi_iters = n_oi_iterations - l_pad_iterations;
        const int cur_iw = l_pad_iterations * ur_stride_w - l_pad;
        const int cur_iw_rightmost_idx = cur_iw + kw - 1;
        const int no_pad_full_n_oi_iters = utils::saturate<int>(
                0, rem_n_oi_iters, (iw - cur_iw_rightmost_idx) / ur_stride_w);

        if (no_pad_full_n_oi_iters > 0) {
            Label ow_loop;
            if (no_pad_full_n_oi_iters > 1) xor_(oi_iter, oi_iter);
            L(ow_loop);
            {
                process_oi(ur_w, ur_bc, 0, 0, with_c_tail_processing);
                if (no_pad_full_n_oi_iters > 1) {
                    inc(oi_iter);
                    cmp(oi_iter, no_pad_full_n_oi_iters);
                    jl(ow_loop, T_NEAR);
                }
            }
        }

        for (int i = l_pad_iterations + no_pad_full_n_oi_iters;
                i < n_oi_iterations; ++i) {
            const int ow_s = i * ur_w;
            const int ow_e = nstl::min(ow, ow_s + ur_w);
            const int cur_r_pad = nstl::max(
                    0, calculate_end_padding(l_pad, ow_e, iw, stride_w, kw));
            const int cur_ur_w = ow_e - ow_s;
            process_oi(cur_ur_w, ur_bc, 0, cur_r_pad, with_c_tail_processing);
        }
    };
    Label ur_bc_tail_label, c_tail_processing_label, finish_label;

    if (jpp.ur_bc_tail > 0) {
        cmp(reg_nbc, jpp.ur_bc);
        jne(ur_bc_tail_label, T_NEAR);
    } else if (jpp.c_tail != 0) {
        // ur_bc contains the number of channel blocks to process and
        // b_c contains the number of channel blocks already processed.
        // If b_c + ur_bc == jpp.nb_c, channel tail processing will
        // probably be needed.
        mov(tmp_gpr, ptr[reg_param + GET_OFF(b_c)]);
        add(tmp_gpr, reg_nbc);
        cmp(tmp_gpr, jpp.nb_c);
        je(c_tail_processing_label, T_NEAR);
    }

    perform_ker(jpp.ur_bc, false);

    if (jpp.ur_bc_tail > 0) {
        jmp(finish_label, T_NEAR);

        // If ur_bc_tail exists, this is the last set of blocks to process,
        // and channel tail processing is needed when the number of channels
        // is not divisible by the number of channels in a block.
        L(ur_bc_tail_label);
        if (jpp.c_tail != 0) prepare_tail_mask();
        perform_ker(jpp.ur_bc_tail, jpp.c_tail != 0);

        L(finish_label);
    } else if (jpp.c_tail != 0) {
        jmp(finish_label, T_NEAR);

        L(c_tail_processing_label);
        prepare_tail_mask();
        perform_ker(jpp.ur_bc, true);

        L(finish_label);
    }

    this->postamble();

    if (jpp.with_eltwise && postops_injector_)
        postops_injector_->prepare_table();

    if ((jpp.is_bf16 || jpp.is_f16) && isa != avx2_vnni_2) {
        align(64);
        L(idx_table);
        const uint16_t _idx[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7,
                8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15};
        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
            dw(_idx[i]);
    }
}

template struct jit_uni_pool_kernel<sse41>;
template struct jit_uni_pool_kernel<avx>;
template struct jit_uni_pool_kernel<avx2>;
template struct jit_uni_pool_kernel<avx2_vnni_2>;
template struct jit_uni_pool_kernel<avx512_core>;
template struct jit_uni_pool_kernel<avx512_core_fp16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s