1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include "cpu/x64/jit_uni_i8i8_pooling.hpp"
17#include <math.h>
18
19#include "common/dnnl_thread.hpp"
20#include "common/utils.hpp"
21
22#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
23#include "cpu/x64/jit_generator.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
30static bcast_set_t get_supported_bcast_strategies() {
31 return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc};
32}
33
34static inline dim_t get_offset(
35 const memory_desc_wrapper &mdw, int n, int c, int d, int h, int w) {
36 switch (mdw.ndims()) {
37 case 3: return mdw.blk_off(n, c, w);
38 case 4: return mdw.blk_off(n, c, h, w);
39 case 5: return mdw.blk_off(n, c, d, h, w);
40 default: assert(!"Invalid tensor dimension in pooling");
41 }
42 return 0;
43}
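// Illustration (assumption, not from the original sources): for a dense 4D nhwc
// layout with dims {N, C, H, W}, blk_off(n, 0, h, w) evaluates to
// ((n * H + h) * W + w) * C elements; the callers below pass c = 0 and scale the
// result by data_type_size().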
44
45using namespace Xbyak;
46
47using namespace dnnl::impl::utils;
49using namespace dnnl::impl::types;
50using namespace alg_kind;
51
52#define GET_OFF(field) offsetof(jit_uni_i8i8_pool_call_params_t, field)
53
54struct jit_uni_i8i8_pool_call_params_t {
55 const char *src_i8;
56 const char *dst_i8;
57 const char *dst_orig;
58 const void *post_ops_binary_rhs_arg_vec;
59 size_t kd_range;
60 size_t kh_range;
61 size_t kw_range;
62 float idivider;
63 const char *src_safe_access;
64 const char *dst_safe_access;
65};
66
67template <cpu_isa_t isa>
68struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator {
69 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)
70
71 using Vmm = typename cpu_isa_traits<isa>::Vmm;
72 Xmm xreg(int idx) const { return Xmm(idx); }
73 Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
74 Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }
75
    // In case of avx2 with data type i8 we need to use the
    // maskmovdqu and maskmovq instructions, which have their destination hardcoded in rdi.
    // Windows ABI: abi_param1 is rcx - nothing else to do
    // Unix ABI: abi_param1 is rdi - copy it to rcx and use rcx as abi_param1
80 Reg64 reg_param = rcx; // Our "unified abi_param1"
81 Reg64 reg_ptr_src_i8 = r8;
82 Reg64 reg_ptr_dst_i8 = r9;
83 Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi
84
85 Reg64 reg_kd_index
86 = rdi; // shared with reg_ptr_maskmovdqu_dst; only used before store
87 Reg64 reg_kh_index = r11;
88 Reg64 reg_kw_index = r10;
89 Reg64 reg_kd = r14;
90 Reg64 reg_kh = r13;
91 Reg64 reg_kw = r12;
92 Reg64 c_iter = r15; // shared with reg_mask; only used after mask init
93
94 Reg64 aux_reg_src_d
95 = rdx; // shared with reg_tmp; loaded before each accum loop, unused during store
96 Reg64 aux_reg_src_h = rax;
97 Reg64 aux_reg_src_w = rbx;
98
99 Reg64 reg_tmp = rdx; // only used during mask init and store
100 Reg64 reg_src_safe_access = rbp;
101 Reg64 reg_dst_safe_access = rsi;
102
103 Reg64 reg_mask = r15; // only used during mask init
104
105 Opmask k_cmp_mask = Opmask(7);
106
107 Opmask mask(int idx) { return Opmask(6 - idx); }
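    // e.g. mask(0) -> k6, mask(1) -> k5, ..., mask(3) -> k3, so the per-ll tail
    // masks never collide with k_cmp_mask (k7) above.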
108
109 // ref to any of XYZ-regs via xreg/yreg/vreg functions
110 Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp
111 Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type
112 Vmm vreg_zeros = vreg(1);
113 Vmm vreg_tail = vreg(4);
114
115 // only in case of <isa> == avx2
116 Vmm vreg_mask = vreg(2); // full byte-mask
117 Xmm xreg_mask_lo = xreg(
118 2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask)
119 Xmm xreg_mask_hi = xreg(
120 3); // "max" - high 128-bits part of byte-mask (stored separately)
121
122 // vreg_mask shifted left (aligned left) to be used in tail processing.
123 // Example: idx [31..0]
124 // vreg_mask = [0,0,0,0,0,.....,0,x,x,x,x,x] ; x => byte mask (msb set)
125 // vreg_mask_2 = [x,x,x,x,x,0,0,0,0,0,.....,0]
126 Vmm vreg_mask_2 = vreg(5);
127 Xmm xreg_mask_2_lo = xreg(5); // similar to xreg_mask_lo
128 Xmm xreg_mask_2_hi = xreg(6); // similar to xreg_mask_hi
129
130 Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails
131 Mmx mmx_dst_i8 = Mmx(
132 0); // "avg" - Mmx reg for masked store results of s8/u8 operations
133 Mmx mmx_full_msk = Mmx(
134 1); // "avg" - Mmx reg for full mask (all 8 bytes) - used until not in tail
135 Mmx mmx_tmp = Mmx(2);
136 int post_op_tail_opmask_idx_ = -1;
137 jit_pool_conf_t jpp;
138 std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
139 postops_injector_;
140
141 enum : int { max_vidx_base = utils::one_of(isa, sse41, avx2) ? 7 : 2 };
    // "avg" pool uses more registers for unrolling.
143 enum : int { avg_vidx_base = utils::one_of(isa, sse41, avx2) ? 4 : 2 };
144
145 Vmm max_base_vr(int idx) const { return vreg(max_vidx_base + idx); }
146 Vmm avg_base_vr(int idx) const { return vreg(avg_vidx_base + idx); }
147
148 size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
149 size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
150
151 /* max pooling */
152 Vmm vreg_src(int idx) const { return max_base_vr(idx); } // [0 .. ur_c-1]
153 Vmm vreg_dst(int idx) const {
154 return max_base_vr(jpp.ur_c + idx);
155 } // [ur_c .. 2*ur_c-1]
156
157 /* avg pooling */
    // s8/u8 data is processed as s32,
    // thus we need to take the s32/i8 size ratio (= 4) into account
160 static constexpr data_type_t avg_proc_dt = data_type::s32;
161 enum : int {
162 s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
163 / sizeof(typename prec_traits<data_type::u8>::type),
164 max_num_ll = s32_to_i8_ratio,
165 mmx_msk_base_reg = 3
166 };
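    // Worked values: sizeof(s32) / sizeof(u8) = 4, so s32_to_i8_ratio = max_num_ll = 4;
    // one vreg of s8/u8 input expands into four s32 vregs ("ll" parts 0..3).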
167
168 inline size_t get_offset_dst(int jj, int ll) const {
169 size_t offset = 0;
170 switch (jpp.alg) {
171 case pooling_max: {
172 offset = jj * jpp.c_block * sizeof_dst_dt();
173 break;
174 }
175 case pooling_avg_include_padding:
176 case pooling_avg_exclude_padding: {
177 offset = (ll * (jpp.c_block / max_num_ll) + jj * jpp.c_block)
178 * sizeof_dst_dt();
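                // Illustration (values assumed): on avx2 with a u8 dst,
                // c_block = 32, max_num_ll = 4 and sizeof_dst_dt() = 1, so part
                // ll of block jj starts at byte jj * 32 + ll * 8.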
179 break;
180 }
181 default: assert(!"unsupported pooling algorithm");
182 }
183 return offset;
184 }
185
186 Vmm vreg_src_s32(int jj, int ll) {
187 return avg_base_vr(3 * max_num_ll * jj + ll + 0 * max_num_ll);
188 } // ll: 0..4 [0..3]
189
190 Vmm vreg_dst_s32(int jj, int ll) {
191 return avg_base_vr(3 * max_num_ll * jj + ll + 1 * max_num_ll);
192 } // ll: 0..4 [4..7]
193
194 Vmm vreg_dst_f32(int jj, int ll) {
195 return avg_base_vr(3 * max_num_ll * jj + ll + 2 * max_num_ll);
196 } // ll: 0..4 [8..11]
197
198 Mmx mmx_mask(int ll) {
199 return Mmx(mmx_msk_base_reg + ll);
200 }; // ll: 0..4 [Mmx(2)...Mmx(5)]
201
202 static bool post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr,
203 const memory_desc_wrapper &dst_d);
204
205 void init_tmp_reg();
206 void init_mask();
207
208 void load_vreg_mask_q(int ll) {};
209
210 void load_src_max_op(
211 int jj, int ll, size_t offset, bool masked, uint64_t msk);
212 void load_src_avg_op(
213 int jj, int ll, size_t offset, bool masked, uint64_t msk);
214 void load_src(int jj, int ll, int c_tail);
215
216 void store_dst_max_op(
217 int jj, int ll, size_t offset, bool masked, uint64_t msk);
218 void store_dst_avg_op(
219 int jj, int ll, size_t offset, bool masked, uint64_t msk);
220 void store_dst(int jj, int ll, int c_tail);
221
222 void compute_avg_step(int ur_c, int c_tail);
223 void compute_max_op(const int jj);
224 void compute_max_step(int ur_c, int c_tail);
225 void compute_step(int ur_c, int c_tail);
226
227 void compute_c_block();
228 void generate() override;
229
230 static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd);
231
232 jit_uni_i8i8_pooling_fwd_ker_t(
233 const jit_pool_conf_t &jpp_, const memory_desc_t *dst_md)
234 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
235 , jpp(jpp_)
236 , postops_injector_(nullptr) {
237
238 if (jpp.with_postops) {
239
240 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
241 const std::size_t c_tail_elems = jpp.c % simd_w;
242 post_op_tail_opmask_idx_ = 0;
243 if (c_tail_elems) {
244 for (int ll = max_num_ll - 1; ll >= 0; ll--) {
245 if (jpp.tail[ll] != 0) {
246 post_op_tail_opmask_idx_ = ll;
247 break;
248 }
249 }
250 };
251
252 static constexpr bool preserve_gpr = true;
253 static constexpr bool preserve_vmm = true;
254 static constexpr bool use_exact_tail_scalar_bcast = false;
255 static constexpr std::size_t tmp_vmm_injector = 0u;
256
257 const binary_injector::rhs_arg_static_params_t rhs_sp {
258 tmp_vmm_injector, r14, r15, r13, preserve_gpr, preserve_vmm,
259 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
260 memory_desc_wrapper(*dst_md), c_tail_elems,
261 mask(post_op_tail_opmask_idx_),
262 use_exact_tail_scalar_bcast};
263 const binary_injector::static_params_t bsp {
264 reg_param, get_supported_bcast_strategies(), rhs_sp};
265
266 postops_injector_ = utils::make_unique<
267 injector::jit_uni_postops_injector_t<isa>>(
268 this, jpp.post_ops, bsp);
269 }
270 }
271};
272
273template <>
274void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_vreg_mask_q(int ll) {};
275
276template <>
277void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {
278
279 // extract ll-th part of mask (ll-th QWORD)
280 vpblendd(vreg_mask_q, vreg_zeros, vreg_mask,
281 0x3 << 2 * ll); // 0x3 - mask for 2 x DWORD
282
283 // Move mask from ll-th pos to 0-th pos
284 if (ll > 0) vpermq(vreg_mask_q, vreg_mask_q, ll);
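    // Illustration (values assumed): for ll = 2 the blend immediate is 0x3 << 4 = 0x30,
    // which keeps dwords 4..5 (qword 2) of vreg_mask and zeroes the rest; vpermq with
    // imm8 = 2 then places qword 2 into qword 0, while the remaining qwords read
    // qword 0, which is already zero after the blend.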
285};
286
287template <>
288void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_src_max_op(
289 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
290 using namespace data_type;
291
292 if (masked) {
293 if (jpp.src_dt == s32)
294 for (int64_t i = 0; i < jpp.c_tail; i++)
295 pinsrd(vreg_src(jj),
296 ptr[aux_reg_src_w + offset + i * data_type_size(s32)],
297 i);
298 else
299 for (int i = 0; i < jpp.c_tail; i++)
300 pinsrb(vreg_src(jj), ptr[aux_reg_src_w + offset + i], i);
301 } else
302 movups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
303}
304
305template <>
306void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(
307 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
308 using namespace data_type;
309
310 if (masked) {
311 if (jpp.src_dt == s32) {
312 vpmaskmovd(vreg_src(jj), vreg_mask, ptr[aux_reg_src_w + offset]);
313 } else {
314 // Steps to access 'tail' section:
315 // 1) First load all data from the shifted src ptr
316 // 2) Now bring the required data from the end of reg to beginning.
317 // Example: idx=[31..0]
318 // vreg_src = [x,x,x,x,.....,x,-,-,-,-,-] ; x => byte data
319 // shift to transform vreg_src = [-,-,-,-,-,x,..,x,x,x,x,]
320 const uint8_t shift = cpu_isa_traits<avx2>::vlen - jpp.c_tail;
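            // Illustration (values assumed): with c_tail = 5, shift = 32 - 5 = 27;
            // the 32-byte load at (addr - 27) places the 5 valid tail bytes into
            // register bytes 27..31, and the vperm2i128 + vpalignr pair below
            // rotates them down to bytes 0..4.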
321
322 if (jpp.safe_c_tail) {
323
324 /* load src_tail at 'src_address - shift' so that it does not
325 * spill over the memory boundary */
326 vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset - shift]);
327
328 vperm2i128(vreg_tmp, vreg_src(jj), vreg_src(jj), 0x81);
329 vpalignr(vreg_src(jj), vreg_tmp, vreg_src(jj), shift);
330
331 } else {
332 Label load_data_safely, done;
333 add(aux_reg_src_w, offset);
334
335 // Check if mask crosses page boundary
336 cmp(aux_reg_src_w, reg_src_safe_access);
337 ja(load_data_safely, T_NEAR);
338
339 vpblendvb(
340 vreg_src(jj), vreg_tmp, byte[aux_reg_src_w], vreg_mask);
341 jmp(done, T_NEAR);
342
343 L(load_data_safely);
344
345 /* load src_tail at 'src_address - shift' so that it does not
346 * spill over the memory boundary */
347 vmovups(vreg_src(jj), ptr[aux_reg_src_w - shift]);
348
349 vperm2i128(vreg_tmp, vreg_src(jj), vreg_src(jj), 0x81);
350 vpalignr(vreg_src(jj), vreg_tmp, vreg_src(jj), shift);
351
352 L(done);
353 sub(aux_reg_src_w, offset);
354 }
355 }
356
357 } else
358 vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
359};
360
361template <>
362void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(
363 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
364 using namespace data_type;
365
366 if (masked) {
367 if (jpp.src_dt == s32)
368 vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
369 else
370 vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
371 } else
372 vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
373};
374
375template <>
376void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_src_avg_op(
377 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
378 using namespace data_type;
379
380 const Vmm &vr_src = vreg_src_s32(jj, ll);
381
382 if (jpp.src_dt == s32) {
383 if (masked)
384 for (int64_t i = 0; i < jpp.c_tail; i++)
385 pinsrd(vr_src,
386 ptr[aux_reg_src_w + offset + i * data_type_size(s32)],
387 i);
388 else
389 movups(vr_src, ptr[aux_reg_src_w + offset]);
390 } else if (utils::one_of(jpp.src_dt, s8, u8)) {
391 if (masked) {
392 const int copy_range = math::ilog2q(jpp.tail[ll] + 1);
393 for (int i = 0; i < copy_range; i++)
394 pinsrb(vr_src, ptr[aux_reg_src_w + offset + i], i);
395
396 if (jpp.src_dt == s8)
397 pmovsxbd(vr_src, vr_src);
398 else
399 pmovzxbd(vr_src, vr_src);
400 } else {
401 if (jpp.src_dt == s8)
402 pmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
403 else
404 pmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
405 }
406 } else
407 assert(!"unsupported src data type");
408}
409
410template <>
411void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(
412 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
413 using namespace data_type;
414
415 auto load_i8 = [&](bool is_signed, const Vmm &vr_src) {
        // Do we need to use the tail mask?
417 if (masked) {
418
419 // load ll-th part of mask into vreg_mask_q
420 load_vreg_mask_q(ll);
421
            // Steps to access the 'tail' section:
            // 1) First load all data from the shifted src ptr.
            // 2) Then bring the required data from the end of the reg to the beginning.
            // Example: idx=[31..0]
            // vreg_src = [x,x,x,x,.....,x,-,-,-,-,-] ; x => byte data
            // shift to transform vreg_src = [-,-,-,-,-,x,..,x,x,x,x,]
            // Re-purposing vreg_zeros here. Set it back to zero immediately.
429 const int msk_gran
430 = cpu_isa_traits<avx2>::vlen / data_type_size(avg_proc_dt);
431
432 const uint8_t shift = cpu_isa_traits<avx2>::vlen
433 - (jpp.c_tail > (ll + 1) * msk_gran
434 ? msk_gran
435 : jpp.c_tail - (ll * msk_gran));
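            // Illustration (values assumed): msk_gran = 8 s32 lanes per ymm; with
            // c_tail = 19 and ll = 2 only 19 - 2 * 8 = 3 bytes remain, so
            // shift = 32 - 3 = 29, while the full parts (ll = 0, 1) use
            // shift = 32 - 8 = 24.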
436 if (jpp.safe_c_tail) {
437 /* load src_tail at 'src_address - shift' so that it does not
438 * spill over the memory boundary */
439 vmovups(vr_src, ptr[aux_reg_src_w + offset - shift]);
440
441 vperm2i128(vreg_zeros, vr_src, vr_src, 0x81);
442 vpalignr(vr_src, vreg_zeros, vr_src, shift);
443 uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
444 } else {
445 Label load_data_safely, done;
446 // assume that it is not safe to load the src_tail
447
448 add(aux_reg_src_w, offset);
449
450 // Check if load crosses the memory boundary
451 cmp(aux_reg_src_w, reg_src_safe_access);
452 ja(load_data_safely, T_NEAR);
453
454 vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w], vreg_mask_q);
455 jmp(done, T_NEAR);
456
457 L(load_data_safely);
458
459 /* load src_tail at 'src_address - shift' so that it does not
460 * spill over the memory boundary */
461 vmovups(vr_src, ptr[aux_reg_src_w - shift]);
462
463 vperm2i128(vreg_zeros, vr_src, vr_src, 0x81);
464 vpalignr(vr_src, vreg_zeros, vr_src, shift);
465 uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
466
467 L(done);
468 sub(aux_reg_src_w, offset);
469 }
470
471 // Conversion s8/u8 -> s32
472 if (is_signed)
473 vpmovsxbd(vr_src, vr_src);
474 else
475 vpmovzxbd(vr_src, vr_src);
476 } else {
477
478 // Load from mem into vr_src with conversion
479 if (is_signed)
480 vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
481 else
482 vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
483 }
484 };
485
486 switch (jpp.src_dt) {
487 case s32:
488 if (masked)
489 vpmaskmovd(vreg_src_s32(jj, ll), vreg_mask,
490 ptr[aux_reg_src_w + offset]);
491 else
492 vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
493 break;
494 case s8: load_i8(true, vreg_src_s32(jj, ll)); break;
495 case u8: load_i8(false, vreg_src_s32(jj, ll)); break;
496 default: assert(!"unsupported src data type");
497 }
498};
499
500template <>
501void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(
502 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
503 using namespace data_type;
504
505 const Vmm &vr_src
506 = masked ? vreg_src_s32(jj, ll) | mask(ll) : vreg_src_s32(jj, ll);
507
508 switch (jpp.src_dt) {
509 case s32: vmovups(vr_src, ptr[aux_reg_src_w + offset]); break;
510 case s8: vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); break;
511 case u8: vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); break;
512 default: assert(!"unsupported src data type");
513 }
514};
515
516template <cpu_isa_t isa>
517void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
518 using namespace data_type;
519
520 int c_block = jpp.c_block;
521 int ur_c = jpp.ur_c;
522
523 switch (jpp.alg) {
524 case pooling_max: {
525 auto offset = jj * c_block * sizeof_src_dt();
526 bool masked = jj == ur_c - 1 && c_tail;
527 load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
528 break;
529 }
530 case pooling_avg_include_padding:
531 case pooling_avg_exclude_padding: {
532 auto offset = (ll * (c_block / max_num_ll) + jj * c_block)
533 * sizeof_src_dt();
534 bool masked = jj == ur_c - 1 && c_tail;
535 load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
536 break;
537 }
538 default: assert(!"unsupported algorithm");
539 }
540}
541
542template <>
543void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::store_dst_max_op(
544 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
545 using namespace data_type;
546
547 if (masked) {
548 if (jpp.src_dt == s32)
549 for (int i = 0; i < jpp.c_tail; i++)
550 pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)],
551 vreg_dst(jj), i);
552 else if (utils::one_of(jpp.src_dt, u8, s8))
553 for (int i = 0; i < jpp.c_tail; i++)
554 pextrb(ptr[reg_ptr_dst_i8 + offset + i], vreg_dst(jj), i);
555 else
556 assert(!"unsupported src data type");
557 } else
558 movups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
559}
560
561template <>
562void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(
563 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
564 using namespace data_type;
565
566 Label store_data_safely, done;
567
568 int c_block = jpp.c_block;
569
570 const uint64_t low_mask = (1ULL << (c_block / 2)) - 1;
571 const uint8_t shift = cpu_isa_traits<avx2>::vlen - jpp.c_tail;
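    // Illustration (values assumed): for the s8/u8 case c_block = 32, so
    // low_mask = 0xffff covers channels 0..15 (the low xmm half);
    // (msk & ~low_mask) != 0 means some tail channels fall into the upper half.
    // With c_tail = 20, shift = 12.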
572
573 if (masked) {
574 switch (jpp.src_dt) {
575 case s32:
576 vpmaskmovd(
577 ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
578 break;
579 case s8:
580 case u8: {
581
582 lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
583
584 if (!jpp.safe_c_tail) {
585 Xmm xreg_dst = Xmm(vreg_dst(jj).getIdx());
586
587 cmp(reg_ptr_maskmovdqu_dst, reg_dst_safe_access);
588 ja(store_data_safely, T_NEAR);
589
590 // Store low half by mask (bytes 0...15)
591 vmaskmovdqu(xreg_dst, xreg_mask_lo);
592
593 // Do we need to store high half (bytes 16...31) ?
594 if (msk & ~low_mask) {
595 vextracti128(xreg_dst, vreg_dst(jj), 1);
596 add(reg_ptr_maskmovdqu_dst, c_block / 2);
597 vmaskmovdqu(xreg_dst, xreg_mask_hi);
598 }
599 jmp(done, T_NEAR);
600 }
601
602 L(store_data_safely);
603
604 vperm2i128(vreg_tail, vreg_dst(jj), vreg_dst(jj), 0x08);
605 if (shift <= 16) {
606 vpalignr(vreg_tail, vreg_dst(jj), vreg_tail, 16 - shift);
607 } else {
608 vpalignr(vreg_tail, vreg_tail, vreg_zeros, 32 - shift);
609 }
610
611 Xmm xreg_tail = Xmm(vreg_tail.getIdx());
612 // Do we need to store low half (bytes 0...15) ?
613 if (msk & ~low_mask) {
614 sub(reg_ptr_maskmovdqu_dst, shift);
615 vmaskmovdqu(xreg_tail, xreg_mask_2_lo);
616 add(reg_ptr_maskmovdqu_dst, c_block / 2);
617 } else {
618 add(reg_ptr_maskmovdqu_dst, (c_block / 2) - shift);
619 }
620
621 // Store high half by mask (bytes 16..31)
622 vextracti128(xreg_tail, vreg_tail, 1);
623 vmaskmovdqu(xreg_tail, xreg_mask_2_hi);
624
625 L(done);
626 } break;
627 default: assert(!"unsupported src data type");
628 }
629 } else
630 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
631}
632
633template <>
634void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(
635 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
636 using namespace data_type;
637
638 if (masked) {
639 switch (jpp.src_dt) {
640 case s32:
641 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
642 break;
643 case s8:
644 case u8:
645 vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
646 break;
647 default: assert(!"unsupported src data type");
648 }
649 } else
650 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
651}
652
653template <>
654void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::store_dst_avg_op(
655 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
656 using namespace data_type;
657
658 // Don't generate useless code
659 if (masked && !msk) return;
660
661 const Vmm &vr_dst = vreg_dst_s32(jj, ll);
662
663 if (jpp.src_dt == s32) {
664 if (masked)
665 for (int i = 0; i < jpp.c_tail; i++)
666 pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)],
667 vr_dst, i);
668 else
669 movups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
670 } else if (utils::one_of(jpp.src_dt, s8, u8)) {
671 packssdw(vr_dst, vr_dst);
672 if (jpp.src_dt == s8)
673 packsswb(vr_dst, vr_dst);
674 else
675 packuswb(vr_dst, vr_dst);
676
677 const int copy_range = masked
678 ? math::ilog2q(jpp.tail[ll] + 1)
679 : cpu_isa_traits<sse41>::vlen / data_type_size(avg_proc_dt);
680 for (int i = 0; i < copy_range; i++)
681 pextrb(ptr[reg_ptr_dst_i8 + offset + i], vr_dst, i);
682 } else
683 assert(!"unsupported src data type");
684}
685
686template <>
687void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(
688 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
689 using namespace data_type;
690
691 // Don't generate useless code
692 if (masked && !msk) return;
693
694 auto s32_to_i8 = [&](bool is_signed, const Vmm &vr_dst) {
695 // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
696 // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
697 if (is_signed)
698 vpackssdw(vr_dst, vr_dst, vreg_zeros);
699 else
700 vpackusdw(vr_dst, vr_dst, vreg_zeros);
701
702 // Permute qwords to restore original order
703 // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
704 vpermq(vr_dst, vr_dst, 0x58);
705
706 // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
707 // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
708 if (is_signed)
709 vpacksswb(vr_dst, vr_dst, vreg_zeros);
710 else
711 vpackuswb(vr_dst, vr_dst, vreg_zeros);
712 };
713
714 auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm &vr_dst) {
715 // Conversion s32 -> s8/u8
716 s32_to_i8(is_signed, vr_dst);
717
718 // early-out for non-masked cases
719 if (!is_masked) {
720 vmovlps(ptr[reg_ptr_dst_i8 + offset], Xmm(vr_dst.getIdx()));
721 return;
722 }
723 // store 8 bytes
724 lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
725
        // Need to use an mmx 8-byte operation to avoid memory violations.
        // NOTICE: it was discovered that the Intel SSE and Intel AVX instructions
        //       maskmovdqu/vmaskmovdqu
        //       with a low 8-byte mask throw an exception if the high 8 bytes belong to a write-protected page.
        // NOTE: use an indirect move via gpr to avoid a transition penalty
731 vmovq(reg_tmp, Xmm(vr_dst.getIdx()));
732 movq(mmx_dst_i8, reg_tmp);
733
734 // mmx_full_msk - mask for all 8 bytes in zero-tail case
735 // mmx_mask(ll) - ll-th mask of tail in non-zero-tail case
736
737 const int msk_gran
738 = cpu_isa_traits<avx2>::vlen / data_type_size(avg_proc_dt);
739
740 const int ll_end = (ll + 1) * msk_gran; // ((ll + 1) * 8)
741
742 if (is_masked && (ll_end > jpp.c_tail)) { //implies this tail not full.
743 Label store_data_safely, done;
744 const uint8_t shift = msk_gran - jpp.c_tail % msk_gran;
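            // Illustration (values assumed): msk_gran = 8; with c_tail = 19 the
            // last part holds 19 % 8 = 3 valid bytes, so shift = 5: mask and data
            // are shifted left by 5 bytes and the store pointer is moved back by
            // 5 bytes, so the 3 valid bytes still land at their original addresses.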
745
746 if (!jpp.safe_c_tail) {
747 cmp(reg_ptr_maskmovdqu_dst, reg_dst_safe_access);
748 ja(store_data_safely, T_NEAR);
749
750 /* store dst_tail with overlap outside the channel dimension,
751 * but assume it's within the memory boundary. */
752 maskmovq(mmx_dst_i8, mmx_mask(ll));
753 jmp(done, T_NEAR);
754 }
755
756 L(store_data_safely);
757
758 /* store dst_tail at 'dst_address - shift' so that it does not
759 * spill over the memory boundary */
760 movq(mmx_tmp, mmx_mask(ll));
761 psllq(mmx_tmp, shift * 8); // multiply with 8 (bits/byte)
762 psllq(mmx_dst_i8, shift * 8);
763 sub(reg_ptr_maskmovdqu_dst, shift);
764 maskmovq(mmx_dst_i8, mmx_tmp);
765
766 L(done);
767 } else {
768 maskmovq(mmx_dst_i8, mmx_full_msk);
769 }
770 };
771
772 switch (jpp.dst_dt) {
773 case s32:
774 if (masked) {
775 vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask,
776 vreg_dst_s32(jj, ll));
777 } else
778 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
779 break;
780 case s8: store_i8(true, masked, vreg_dst_s32(jj, ll)); break;
781 case u8: store_i8(false, masked, vreg_dst_s32(jj, ll)); break;
        default: assert(!"unsupported dst data_type");
783 }
784}
785
786template <>
787void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(
788 int jj, int ll, size_t offset, bool masked, uint64_t msk) {
789 using namespace data_type;
790
791 // Don't generate useless code
792 if (masked && !msk) return;
793
794 const Vmm &vr_dst
795 = masked ? vreg_dst_s32(jj, ll) | mask(ll) : vreg_dst_s32(jj, ll);
796
797 switch (jpp.dst_dt) {
798 case s32: vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
799 case s8: vpmovsdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
800 case u8: vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
801 default: assert(!"unsupported dst data_type");
802 }
803}
804
805template <cpu_isa_t isa>
806void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(
807 int jj, int ll, int c_tail) {
808 using namespace data_type;
809
810 int c_block = jpp.c_block;
811 int ur_c = jpp.ur_c;
812
813 switch (jpp.alg) {
814 case pooling_max: {
815 auto offset = jj * c_block * sizeof_dst_dt();
816 bool masked = jj == ur_c - 1 && c_tail;
817 store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
818 break;
819 }
820 case pooling_avg_include_padding:
821 case pooling_avg_exclude_padding: {
822 auto offset = (ll * (c_block / max_num_ll) + jj * c_block)
823 * sizeof_dst_dt();
824 bool masked = jj == ur_c - 1 && c_tail;
825 store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
826 break;
827 }
828 default: assert(!"unsupported pooling algorithm");
829 }
830}
831
832template <>
833void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::compute_max_op(const int jj) {
834 using namespace data_type;
835 switch (jpp.src_dt) {
836 case s32: pmaxsd(vreg_dst(jj), vreg_src(jj)); break;
837 case s8: pmaxsb(vreg_dst(jj), vreg_src(jj)); break;
838 case u8: pmaxub(vreg_dst(jj), vreg_src(jj)); break;
839 default: assert(!"unsupported src data type");
840 }
841}
842
843template <>
844void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj) {
845 using namespace data_type;
846 switch (jpp.src_dt) {
847 case s32: vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
848 case s8: vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
849 case u8: vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
850 default: assert(!"unsupported src data type");
851 }
852}
853
854template <>
855void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj) {
856 using namespace data_type;
857
858 // Compare
859 switch (jpp.src_dt) {
860 case s32:
861 vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
862 break;
863 case s8:
864 vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
865 break;
866 case u8:
867 vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
868 break;
869 default: assert(!"unsupported src data type");
870 }
871
872 // move max values into vreg_dst
873 if (jpp.src_dt == s32)
874 vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
875 else
876 vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
877}
878
879template <cpu_isa_t isa>
880void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(
881 int ur_c, int c_tail) {
882 Label l_kd, l_kh, l_kw;
883
884 int ih = jpp.ih;
885 int iw = jpp.iw;
886 int c = jpp.c;
887
888 for (int jj = 0; jj < ur_c; jj++)
889 uni_vmovups(vreg_dst(jj), vreg_tmp);
890
891 mov(aux_reg_src_d, reg_ptr_src_i8);
892 xor_(reg_kd_index, reg_kd_index);
893 L(l_kd);
894 {
895 mov(aux_reg_src_h, aux_reg_src_d);
896 xor_(reg_kh_index, reg_kh_index);
897 L(l_kh);
898 {
899 mov(aux_reg_src_w, aux_reg_src_h);
900 xor_(reg_kw_index, reg_kw_index);
901 L(l_kw);
902 {
903 for (int jj = 0; jj < ur_c; jj++) {
904 load_src(jj, 0, c_tail);
905 compute_max_op(jj);
906 }
907 add(aux_reg_src_w, c * sizeof_src_dt());
908 inc(reg_kw_index);
909 cmp(reg_kw_index, reg_kw);
910 jl(l_kw, T_NEAR);
911 }
912 add(aux_reg_src_h, iw * c * sizeof_src_dt());
913 inc(reg_kh_index);
914 cmp(reg_kh_index, reg_kh);
915 jl(l_kh, T_NEAR);
916 }
917 add(aux_reg_src_d, ih * iw * c * sizeof_src_dt());
918 inc(reg_kd_index);
919 cmp(reg_kd_index, reg_kd);
920 jl(l_kd, T_NEAR);
921 }
922
923 for (int jj = 0; jj < ur_c; jj++)
924 store_dst(jj, 0, c_tail);
925}
926
927template <cpu_isa_t isa>
928void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(
929 int ur_c, int c_tail) {
930 using namespace data_type;
931
932 Label l_kd, l_kh, l_kw;
933
934 int ih = jpp.ih;
935 int iw = jpp.iw;
936 int c = jpp.c;
937
938 const int num_ll = data_type_size(avg_proc_dt) / data_type_size(jpp.src_dt);
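    // e.g. num_ll = 4 for s8/u8 sources (four s32 parts per byte-wide vreg)
    // and num_ll = 1 for s32 sources.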
939
940 for (int jj = 0; jj < ur_c; jj++) {
941 for (int ll = 0; ll < num_ll; ll++) {
942 bool masked = jj == ur_c - 1 && c_tail;
943 size_t msk = jpp.tail[ll];
944 if (!(masked && !msk)) {
                // Clearing of the src regs is not needed as they are written before being read
946 uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
947 vreg_dst_s32(jj, ll));
948 }
949 }
950 }
951
952 mov(aux_reg_src_d, reg_ptr_src_i8);
953 xor_(reg_kd_index, reg_kd_index);
954 L(l_kd);
955 {
956 mov(aux_reg_src_h, aux_reg_src_d);
957 xor_(reg_kh_index, reg_kh_index);
958 L(l_kh);
959 {
960 mov(aux_reg_src_w, aux_reg_src_h);
961 xor_(reg_kw_index, reg_kw_index);
962 L(l_kw);
963 {
964 for (int jj = 0; jj < ur_c; jj++) {
965 for (int ll = 0; ll < num_ll; ll++) {
966 bool masked = jj == ur_c - 1 && c_tail;
967 size_t msk = jpp.tail[ll];
968 if (!(masked && !msk)) {
969 load_src(jj, ll, c_tail);
970 uni_vpaddd(vreg_dst_s32(jj, ll),
971 vreg_dst_s32(jj, ll), vreg_src_s32(jj, ll));
972 }
973 }
974 }
975 add(aux_reg_src_w, c * sizeof_src_dt());
976 inc(reg_kw_index);
977 cmp(reg_kw_index, reg_kw);
978 jl(l_kw, T_NEAR);
979 }
980 add(aux_reg_src_h, iw * c * sizeof_src_dt());
981 inc(reg_kh_index);
982 cmp(reg_kh_index, reg_kh);
983 jl(l_kh, T_NEAR);
984 }
985 add(aux_reg_src_d, ih * iw * c * sizeof_src_dt());
986 inc(reg_kd_index);
987 cmp(reg_kd_index, reg_kd);
988 jl(l_kd, T_NEAR);
989 }
990
991 for (int jj = 0; jj < ur_c; jj++) {
992 for (int ll = 0; ll < num_ll; ll++) {
993 const bool masked = jj == ur_c - 1 && c_tail;
994 const size_t msk = jpp.tail[ll];
995 if (!(masked && !msk)) {
996 const auto &reg_dst_f32 = vreg_dst_f32(jj, ll);
997 const auto &reg_dst_s32 = vreg_dst_s32(jj, ll);
998 uni_vcvtdq2ps(reg_dst_f32, reg_dst_s32);
999 uni_vfmadd132ps(reg_dst_f32, vreg_zeros, vreg_tmp);
1000
1001 if (jpp.with_postops) {
1002 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
1003 if (jpp.with_binary) {
1004 rhs_arg_params.vmm_idx_to_out_reg.emplace(
1005 reg_dst_f32.getIdx(), reg_ptr_dst_i8);
1006 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
1007 reg_dst_f32.getIdx(), get_offset_dst(jj, ll));
1008 const bool tail = ll == post_op_tail_opmask_idx_;
1009 if (tail && masked)
1010 rhs_arg_params.vmm_tail_idx_.emplace(
1011 reg_dst_f32.getIdx());
1012 }
1013 postops_injector_->compute_vector(
1014 reg_dst_f32.getIdx(), rhs_arg_params);
1015 }
1016
1017 uni_vcvtps2dq(reg_dst_s32, reg_dst_f32);
1018
1019 if (jpp.with_postops)
1020 if (jpp.dst_dt == u8) {
1021 uni_vpmaxsd(reg_dst_s32, reg_dst_s32, vreg_zeros);
1022 }
1023 store_dst(jj, ll, c_tail);
1024 }
1025 }
1026 }
1027}
1028
1029template <cpu_isa_t isa>
1030void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
1031 switch (jpp.alg) {
1032 case pooling_max: compute_max_step(ur_c, c_tail); break;
1033 case pooling_avg_include_padding:
1034 case pooling_avg_exclude_padding: compute_avg_step(ur_c, c_tail); break;
1035 default: assert(!"unsupported pooling algorithm");
1036 }
1037}
1038
1039template <cpu_isa_t isa>
1040void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block() {
1041 Label l_main_loop;
1042
1043 int nb_c = jpp.nb_c;
1044 int c_block = jpp.c_block;
1045 int ur_c = jpp.ur_c;
1046 int ur_c_tail = jpp.ur_c_tail;
1047 int c_steps = nb_c / ur_c;
1048 int c_tail = jpp.c_tail;
1049
1050 xor_(c_iter, c_iter);
1051 if (c_steps > 0) {
1052 L(l_main_loop);
1053 {
1054 compute_step(ur_c, 0);
1055 add(reg_ptr_src_i8, ur_c * c_block * sizeof_src_dt());
1056 add(reg_ptr_dst_i8, ur_c * c_block * sizeof_dst_dt());
1057 inc(c_iter);
1058 cmp(c_iter, c_steps);
1059 jl(l_main_loop, T_NEAR);
1060 }
1061 }
1062
1063 if (ur_c_tail != 0) { compute_step(ur_c_tail, c_tail); }
1064}
1065
1066template <>
1067void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::init_mask() {}
1068
1069template <>
1070void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
1071 using namespace data_type;
1072 using cpu_isa = cpu_isa_traits<avx2>;
1073
1074 // AVX2 mask initialization: mask stored in Ymm-regs
1075 auto init = [&](uint64_t bit_mask, bool need_ymm_mask = true,
1076 bool need_mmx_mask = false) {
1077 const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);
1078
1079 const size_t DBITS = 8 * sizeof_src_dt();
1080 const uint64_t VMSK = 1ULL << (DBITS - 1);
1081 const size_t D_PER_QW = (8 * sizeof(uint64_t)) / DBITS;
1082 uint64_t vmask[QW_PER_VREG];
1083 for (size_t i = 0; i < QW_PER_VREG; i++) {
1084 uint64_t qw_vmask = 0ULL;
1085 for (size_t j = 0; j < D_PER_QW; j++) {
1086 if (bit_mask & 1) qw_vmask |= VMSK << DBITS * j;
1087 bit_mask >>= 1;
1088 }
1089 vmask[i] = qw_vmask;
1090 }
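        // Illustration (values assumed): for s8/u8, DBITS = 8, VMSK = 0x80,
        // D_PER_QW = 8 and QW_PER_VREG = 4, so a bit_mask whose low byte is
        // 0b00000101 yields qw_vmask = 0x0000000000800080 (byte MSB set for
        // elements 0 and 2); for s32, DBITS = 32, VMSK = 0x80000000, D_PER_QW = 2.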
1091
1092 // Need mask in Ymm regs ?
1093 if (need_ymm_mask) {
1094
1095 // Put QWORDS with target mask into xmm regs
1096 const int xdst_i[QW_PER_VREG]
1097 = {xreg_mask_lo.getIdx(), xreg_mask_lo.getIdx(),
1098 xreg_mask_hi.getIdx(), xreg_mask_hi.getIdx()};
1099 const int xsrc_i[QW_PER_VREG] = {
1100 vreg_zeros
1101 .getIdx(), // 0-th qword insert in zeros -> {qw0, 0}
1102 xreg_mask_lo
1103 .getIdx(), // 1-st and 0-th merge -> {qw0,qw1}
1104 vreg_zeros.getIdx(), xreg_mask_hi.getIdx()};
1105 const uint8 qw_dst_idx[QW_PER_VREG]
1106 = {0, 1, 0, 1}; // qword index in 128-bit xreg
1107
1108 for (size_t i = 0; i < QW_PER_VREG; i++) {
1109 mov(reg_mask, vmask[i]);
1110 vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask,
1111 qw_dst_idx[i]);
1112
1113 // Need mask in MMX regs also?
1114 if (need_mmx_mask)
1115 movq(mmx_mask(i), reg_mask); // reuse value in reg_mask
1116 }
1117
1118 // Merge Low (xreg_mask_lo alias for vreg_mask.xreg)
1119 // and High (xreg_mask_hi) into full vreg_mask
1120 // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
1121 vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);
1122
            // Compute a mask aligned to the left from vreg_mask and store it in vreg_mask_2 to be used for tail processing.
1124 const uint8_t shift = 32 - jpp.c_tail;
1125 vperm2i128(vreg_mask_2, vreg_mask, vreg_mask, 0x08);
1126 if (shift <= 16) {
1127 vpalignr(vreg_mask_2, vreg_mask, vreg_mask_2, 16 - shift);
1128 } else {
1129 vpalignr(vreg_mask_2, vreg_mask_2, vreg_zeros, 32 - shift);
1130 }
1131 vextracti128(xreg_mask_2_hi, vreg_mask_2, 0x1);
1132 }
1133
1134 // Need mask in MMX regs ?
1135 if (need_mmx_mask) {
1136
1137 // Only in MMX regs ?
1138 if (!need_ymm_mask)
1139 for (size_t i = 0; i < QW_PER_VREG; i++) {
1140 mov(reg_mask, vmask[i]);
1141 movq(mmx_mask(i), reg_mask);
1142 }
1143
1144 // Form full mask for one QWORD
1145 uint64_t qw_full_vmask = 0ULL;
1146 for (size_t i = 0; i < D_PER_QW; i++)
1147 qw_full_vmask |= VMSK << DBITS * i;
1148
1149 mov(reg_mask, qw_full_vmask);
1150 movq(mmx_full_msk, reg_mask);
1151 }
1152 };
1153
1154 uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
1155 switch (jpp.alg) {
1156 case pooling_max:
1157 // For "max" we need mask only in case of non-zero tail
1158 if (tail_mask) init(tail_mask);
1159 break;
1160 case pooling_avg_include_padding:
1161 case pooling_avg_exclude_padding:
1162 // For "avg" we need mask:
1163 // - s32 - in case of the non-zero tail
1164 // - s8/u8 - irrespective of the tail in MMX regs (always store by mask)
1165 // - for non-zero tail in Ymm regs (for load)
1166 switch (jpp.src_dt) {
1167 case s32:
1168 if (tail_mask) init(tail_mask);
1169 break;
1170 case s8:
1171 case u8:
1172 init(tail_mask ? tail_mask : ~0ULL, tail_mask != 0, true);
1173 break;
1174 default: assert(!"unsupported src data type");
1175 }
1176 break;
1177 default: assert(!"unsupported pooling algorithm");
1178 }
1179}
1180
1181template <>
1182void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {
1183
1184 for (int ll = 0; ll < max_num_ll; ll++) {
1185 mov(reg_mask, jpp.tail[ll]);
1186 kmovq(mask(ll), reg_mask);
1187 }
1188}
1189
1190template <cpu_isa_t isa>
1191void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
1192 using namespace data_type;
1193
1194 switch (jpp.alg) {
1195 case pooling_avg_include_padding:
1196 case pooling_avg_exclude_padding:
1197 mov(reg_tmp,
1198 ptr[reg_param
1199 + offsetof(jit_uni_i8i8_pool_call_params_t,
1200 idivider)]);
1201 uni_vmovq(xmm_tmp, reg_tmp);
1202 uni_vpbroadcastd(vreg_tmp, xmm_tmp);
1203 break;
1204 case pooling_max:
1205 switch (jpp.src_dt) {
1206 case s32:
1207 mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
1208 break;
1209 case s8:
1210 mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
1211 break;
1212 case u8:
1213 mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
1214 break;
1215 default: assert(!"unsupported src data_type");
1216 }
1217
1218 uni_vmovq(xmm_tmp, reg_tmp);
1219 if (jpp.src_dt == s32)
1220 uni_vpbroadcastd(vreg_tmp, xmm_tmp);
1221 else if (mayiuse(avx2))
1222 vpbroadcastb(vreg_tmp, xmm_tmp);
1223 else
1224 pshufb(xmm_tmp, vreg_zeros);
1225 break;
1226 default: assert(!"unsupported pooling algorithm");
1227 }
1228}
1229
1230template <cpu_isa_t isa>
1231void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
1232 preamble();
1233
1234#if !defined(_WIN32)
1235 // Always use rcx as abi_param1 -
1236 // see the note about maskmovdqu/maskmovq near reg_param.
1237 mov(rcx, rdi);
1238#endif
1239
1240#define READ_PARAM(reg, field) \
1241 mov(reg, ptr[reg_param + offsetof(jit_uni_i8i8_pool_call_params_t, field)])
1242 READ_PARAM(reg_ptr_src_i8, src_i8);
1243 READ_PARAM(reg_ptr_dst_i8, dst_i8);
1244 READ_PARAM(reg_kd, kd_range);
1245 READ_PARAM(reg_kh, kh_range);
1246 READ_PARAM(reg_kw, kw_range);
1247 READ_PARAM(reg_src_safe_access, src_safe_access);
1248 READ_PARAM(reg_dst_safe_access, dst_safe_access);
1249
1250#undef READ_PARAM
1251
1252 uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
1253
1254 init_mask();
1255
1256 init_tmp_reg();
1257
1258 compute_c_block();
1259
1260 emms();
1261 postamble();
1262
1263 if (jpp.with_eltwise && postops_injector_)
1264 postops_injector_->prepare_table();
1265}
1266
1267template <cpu_isa_t isa>
1268status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(
1269 jit_pool_conf_t &jpp, const pooling_pd_t *ppd) {
1270 if (!mayiuse(isa)) return status::unimplemented;
1271
1272 const auto &pd = *ppd->desc();
1273 const memory_desc_wrapper src_d(ppd->src_md());
1274 const memory_desc_wrapper dst_d(ppd->dst_md());
1275 const int ndims = src_d.ndims();
1276 const bool is_1d = ndims == 3;
1277 const bool is_3d = ndims == 5;
1278
1279 jpp.mb = src_d.dims()[0];
1280 jpp.c = src_d.dims()[1];
1281
1282 jpp.id = is_3d ? src_d.dims()[ndims - 3] : 1;
1283 jpp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
1284 jpp.iw = src_d.dims()[ndims - 1];
1285
1286 jpp.od = is_3d ? dst_d.dims()[ndims - 3] : 1;
1287 jpp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
1288 jpp.ow = dst_d.dims()[ndims - 1];
1289
1290 jpp.stride_d = is_3d ? pd.strides[ndims - 5] : 1;
1291 jpp.stride_h = is_1d ? 1 : pd.strides[ndims - 4];
1292 jpp.stride_w = pd.strides[ndims - 3];
1293
1294 jpp.kd = is_3d ? pd.kernel[ndims - 5] : 1;
1295 jpp.kh = is_1d ? 1 : pd.kernel[ndims - 4];
1296 jpp.kw = pd.kernel[ndims - 3];
1297
1298 jpp.f_pad = is_3d ? pd.padding[0][ndims - 5] : 0;
1299 jpp.t_pad = is_1d ? 0 : pd.padding[0][ndims - 4];
1300 jpp.l_pad = pd.padding[0][ndims - 3];
1301
1302 int back_pad = calculate_end_padding(
1303 jpp.f_pad, jpp.od, jpp.id, jpp.stride_d, jpp.kd);
1304 int bottom_pad = calculate_end_padding(
1305 jpp.t_pad, jpp.oh, jpp.ih, jpp.stride_h, jpp.kh);
1306 int right_pad = calculate_end_padding(
1307 jpp.l_pad, jpp.ow, jpp.iw, jpp.stride_w, jpp.kw);
1308
1309 if (jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw
1310 || back_pad >= jpp.kd || bottom_pad >= jpp.kh
1311 || right_pad >= jpp.kw)
1312 return status::unimplemented;
1313
1314 jpp.alg = pd.alg_kind;
1315
1316 jpp.src_dt = pd.src_desc.data_type;
1317 jpp.dst_dt = pd.dst_desc.data_type;
1318
1319 // data_type items per one vreg on the <isa>
1320 // isa == sse41 : 16 bytes -> 16 for s8/u8, 4 for s32
1321 // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32
1322 // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
1323 int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);
1324
1325 /* Verify that vlen-sized memory access happens within the tensor's
1326 * size, otherwise load/store will always spill outside the memory
1327 * boundary.*/
1328 bool safe_load_n_store = IMPLICATION(utils::one_of(isa, avx2, sse41),
1329 jpp.mb * jpp.c * nstl::min(jpp.id, jpp.od)
1330 * nstl::min(jpp.ih, jpp.oh)
1331 * nstl::min(jpp.iw, jpp.ow)
1332 >= simd_w);
1333 if (!safe_load_n_store) return status::unimplemented;
1334
1335 jpp.c_block = simd_w;
1336 jpp.c_tail = jpp.c % jpp.c_block;
1337 jpp.nb_c = jpp.c / jpp.c_block;
1338 jpp.ur_c = 1;
1339 jpp.ur_c_tail = jpp.c_tail != 0;
1340
1341 size_t tail_mask = (1ULL << jpp.c_tail) - 1;
1342
    /* If the channel size is at least vlen, we can safely assume there is no
     * underflow of the memory boundary, so always perform the c_tail access
     * and save a couple of compute cycles. */
1346 jpp.safe_c_tail = jpp.c_tail > 0 && jpp.c >= simd_w;
1347
1348 switch (jpp.alg) {
1349 case pooling_max:
1350 jpp.tail[0] = tail_mask;
1351 jpp.tail[1] = 0;
1352 jpp.tail[2] = 0;
1353 jpp.tail[3] = 0;
1354 break;
1355 case pooling_avg_include_padding:
1356 case pooling_avg_exclude_padding: {
1357 // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32)
1358 // sse : 4, avx2 : 8, avx512 : 16
1359 const size_t msk_gran
1360 = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
1361 const size_t msk_msk = (1ULL << msk_gran) - 1;
1362 size_t m = tail_mask;
1363 for (size_t ll = 0; ll < max_num_ll; ll++) {
1364 jpp.tail[ll] = m & msk_msk;
1365 m = m >> msk_gran;
1366 }
1367 break;
1368 }
1369 default: return status::unimplemented;
1370 }
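    // Illustration (values assumed): avx2, u8, C = 83 -> c_block = 32,
    // c_tail = 19, tail_mask = 0x7ffff; with msk_gran = 8 the avg split is
    // tail[0] = 0xff, tail[1] = 0xff, tail[2] = 0x07, tail[3] = 0.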
1371
1372 if (!post_ops_ok(jpp, *ppd->attr(), dst_d)) return status::unimplemented;
1373
1374 return status::success;
1375}
1376
1377template <cpu_isa_t isa>
1378bool jit_uni_i8i8_pooling_fwd_ker_t<isa>::post_ops_ok(jit_pool_conf_t &jpp,
1379 const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
1380 const auto &post_ops = attr.post_ops_;
1381 const auto &entries = post_ops.entry_;
1382 jpp.with_postops = false;
1383 jpp.with_eltwise = false;
1384 jpp.with_binary = false;
1385
1386 if (entries.empty()) return true;
1387
1388 for (const auto &entry : entries) {
1389 if (entry.is_eltwise()) {
1390 const auto alg = entry.eltwise.alg;
1391 jpp.with_eltwise = eltwise_injector::is_supported(isa, alg);
1392 } else if (entry.is_binary()) {
1393 if (isa != avx512_core
1394 && entry.binary.src1_desc.data_type == data_type::bf16)
1395 return false;
1396 jpp.with_binary = true;
1397 } else
1398 return false;
1399 }
1400
1401 jpp.with_postops = jpp.with_eltwise || jpp.with_binary;
1402 jpp.post_ops = post_ops;
1403
1404 /*
     * TODO Currently the eltwise/binary injectors assume that data in vmm has f32 dt.
1406 * In max pooling data remains in i8 data type.
1407 */
1408 return IMPLICATION(jpp.with_postops, jpp.alg != pooling_max)
1409 && binary_injector::binary_args_broadcast_supported(
1410 post_ops, dst_d, get_supported_bcast_strategies());
1411}
1412
1413template <cpu_isa_t isa>
1414status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
1415 return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_, this);
1416}
1417
1418template <cpu_isa_t isa>
1419jit_uni_i8i8_pooling_fwd_t<isa>::jit_uni_i8i8_pooling_fwd_t(const pd_t *apd)
1420 : primitive_t(apd), ker_(nullptr) {}
1421
1422template <cpu_isa_t isa>
1423jit_uni_i8i8_pooling_fwd_t<isa>::~jit_uni_i8i8_pooling_fwd_t() = default;
1424
1425template <cpu_isa_t isa>
1426status_t jit_uni_i8i8_pooling_fwd_t<isa>::init(engine_t *engine) {
1427 CHECK(safe_ptr_assign(ker_,
1428 new jit_uni_i8i8_pooling_fwd_ker_t<isa>(
1429 pd()->jpp_, pd()->invariant_dst_md())));
1430 return ker_->create_kernel();
1431}
1432
1433template <cpu_isa_t isa>
1434status_t jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward(
1435 const exec_ctx_t &ctx) const {
1436 auto src_i8 = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
1437 auto dst_i8 = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1438
1439 const memory_desc_wrapper src_d(pd()->src_md());
1440 const memory_desc_wrapper dst_d(pd()->dst_md());
1441
1442 const auto &jpp = pd()->jpp_;
1443 const auto post_ops_binary_rhs_arg_vec
1444 = binary_injector::prepare_binary_args(jpp.post_ops, ctx);
    /* Calculate whether the memory access may happen outside of the memory
     * boundary; if so, compute a safe memory access. */
1447 const auto src_safe_access = reinterpret_cast<char *>(
1448 reinterpret_cast<ptrdiff_t>(src_i8 + src_d.size() - 1)
1449 - (cpu_isa_traits<isa>::vlen - 1));
1450
1451 const auto dst_safe_access = reinterpret_cast<char *>(
1452 reinterpret_cast<ptrdiff_t>(dst_i8 + dst_d.size() - 1)
1453 - (cpu_isa_traits<isa>::vlen - 1));
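    /* e.g. with vlen = 32 these point 31 bytes before the last valid byte, i.e.
     * the highest address at which a full 32-byte vector access still ends inside
     * the buffer; the kernel compares against them before unmasked accesses near
     * tails. */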
1454
1455 parallel_nd(jpp.mb, jpp.od, jpp.oh, jpp.ow,
1456 [&](dim_t n, dim_t od, dim_t oh, dim_t ow) {
1457 dim_t id = nstl::max(od * jpp.stride_d - jpp.f_pad, dim_t(0));
1458 dim_t ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, dim_t(0));
1459 dim_t iw = nstl::max(ow * jpp.stride_w - jpp.l_pad, dim_t(0));
1460
1461 dim_t kd_start
1462 = nstl::max(dim_t(0), jpp.f_pad - od * jpp.stride_d);
1463 dim_t kd_end = nstl::min(
1464 dim_t(jpp.kd), jpp.id + jpp.f_pad - od * jpp.stride_d);
1465 dim_t kh_start
1466 = nstl::max(dim_t(0), jpp.t_pad - oh * jpp.stride_h);
1467 dim_t kh_end = nstl::min(
1468 dim_t(jpp.kh), jpp.ih + jpp.t_pad - oh * jpp.stride_h);
1469 dim_t kw_start
1470 = nstl::max(dim_t(0), jpp.l_pad - ow * jpp.stride_w);
1471 dim_t kw_end = nstl::min(
1472 dim_t(jpp.kw), jpp.iw + jpp.l_pad - ow * jpp.stride_w);
1473
1474 auto p = jit_uni_i8i8_pool_call_params_t();
1475 p.src_i8 = &src_i8[get_offset(src_d, n, 0, id, ih, iw)
1476 * src_d.data_type_size()];
1477 p.dst_i8 = &dst_i8[get_offset(dst_d, n, 0, od, oh, ow)
1478 * dst_d.data_type_size()];
1479 p.dst_orig = dst_i8;
1480 p.kd_range = kd_end - kd_start;
1481 p.kh_range = kh_end - kh_start;
1482 p.kw_range = kw_end - kw_start;
1483 p.idivider = 1.0f
1484 / ((jpp.alg == pooling_avg_exclude_padding)
1485 ? p.kd_range * p.kh_range * p.kw_range
1486 : jpp.kd * jpp.kh * jpp.kw);
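                // e.g. for a 3x3 kernel (kd = 1) at a corner with 1-pixel
                // padding, kh_range = kw_range = 2, so idivider = 1/4 when
                // excluding padding and 1/9 when including it (illustrative
                // values).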
1487 p.src_safe_access = src_safe_access;
1488 p.dst_safe_access = dst_safe_access;
1489 p.post_ops_binary_rhs_arg_vec
1490 = post_ops_binary_rhs_arg_vec.data();
1491 (*ker_)(&p);
1492 });
1493 return status::success;
1494}
1495
1496// Explicit instantiation only for supported <isa> values.
1497//
1498template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
1499template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;
1500
1501template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
1502template struct jit_uni_i8i8_pooling_fwd_t<avx2>;
1503
1504template struct jit_uni_i8i8_pooling_fwd_ker_t<sse41>;
1505template struct jit_uni_i8i8_pooling_fwd_t<sse41>;
1506
1507} // namespace x64
1508} // namespace cpu
1509} // namespace impl
1510} // namespace dnnl
1511