1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include "cpu/x64/jit_uni_i8i8_pooling.hpp" |
17 | #include <math.h> |
18 | |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
23 | #include "cpu/x64/jit_generator.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | static bcast_set_t get_supported_bcast_strategies() { |
31 | return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc}; |
32 | } |
33 | |
34 | static inline dim_t get_offset( |
35 | const memory_desc_wrapper &mdw, int n, int c, int d, int h, int w) { |
36 | switch (mdw.ndims()) { |
37 | case 3: return mdw.blk_off(n, c, w); |
38 | case 4: return mdw.blk_off(n, c, h, w); |
39 | case 5: return mdw.blk_off(n, c, d, h, w); |
        default: assert(!"Invalid tensor dimension in pooling");
41 | } |
42 | return 0; |
43 | } |
44 | |
45 | using namespace Xbyak; |
46 | |
using namespace dnnl::impl::utils;
49 | using namespace dnnl::impl::types; |
50 | using namespace alg_kind; |
51 | |
52 | #define GET_OFF(field) offsetof(jit_uni_i8i8_pool_call_params_t, field) |
53 | |
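// Arguments passed to the generated kernel for every output point;
// the fields are filled in execute_forward() below.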
54 | struct jit_uni_i8i8_pool_call_params_t { |
55 | const char *src_i8; |
56 | const char *dst_i8; |
57 | const char *dst_orig; |
58 | const void *post_ops_binary_rhs_arg_vec; |
59 | size_t kd_range; |
60 | size_t kh_range; |
61 | size_t kw_range; |
62 | float idivider; |
63 | const char *src_safe_access; |
64 | const char *dst_safe_access; |
65 | }; |
66 | |
67 | template <cpu_isa_t isa> |
68 | struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator { |
69 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t) |
70 | |
71 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
72 | Xmm xreg(int idx) const { return Xmm(idx); } |
73 | Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); } |
74 | Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); } |
75 | |
    // In case of avx2 with i8 data we need to use the
    // maskmovdqu and maskmovq instructions, which have their
    // destination hardcoded to rdi.
    // Windows ABI: abi_param1 is rcx - nothing else to do
    // Unix ABI: abi_param1 is rdi - copy it to rcx and use rcx as abi_param1
80 | Reg64 reg_param = rcx; // Our "unified abi_param1" |
81 | Reg64 reg_ptr_src_i8 = r8; |
82 | Reg64 reg_ptr_dst_i8 = r9; |
83 | Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi |
84 | |
85 | Reg64 reg_kd_index |
86 | = rdi; // shared with reg_ptr_maskmovdqu_dst; only used before store |
87 | Reg64 reg_kh_index = r11; |
88 | Reg64 reg_kw_index = r10; |
89 | Reg64 reg_kd = r14; |
90 | Reg64 reg_kh = r13; |
91 | Reg64 reg_kw = r12; |
92 | Reg64 c_iter = r15; // shared with reg_mask; only used after mask init |
93 | |
94 | Reg64 aux_reg_src_d |
95 | = rdx; // shared with reg_tmp; loaded before each accum loop, unused during store |
96 | Reg64 aux_reg_src_h = rax; |
97 | Reg64 aux_reg_src_w = rbx; |
98 | |
99 | Reg64 reg_tmp = rdx; // only used during mask init and store |
100 | Reg64 reg_src_safe_access = rbp; |
101 | Reg64 reg_dst_safe_access = rsi; |
102 | |
103 | Reg64 reg_mask = r15; // only used during mask init |
104 | |
105 | Opmask k_cmp_mask = Opmask(7); |
106 | |
107 | Opmask mask(int idx) { return Opmask(6 - idx); } |
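    // avx512: mask(ll) maps ll = 0..3 to opmasks k6..k3 (set in init_mask())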
108 | |
109 | // ref to any of XYZ-regs via xreg/yreg/vreg functions |
110 | Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp |
111 | Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type |
112 | Vmm vreg_zeros = vreg(1); |
113 | Vmm vreg_tail = vreg(4); |
114 | |
115 | // only in case of <isa> == avx2 |
116 | Vmm vreg_mask = vreg(2); // full byte-mask |
117 | Xmm xreg_mask_lo = xreg( |
118 | 2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask) |
119 | Xmm xreg_mask_hi = xreg( |
120 | 3); // "max" - high 128-bits part of byte-mask (stored separately) |
121 | |
122 | // vreg_mask shifted left (aligned left) to be used in tail processing. |
123 | // Example: idx [31..0] |
124 | // vreg_mask = [0,0,0,0,0,.....,0,x,x,x,x,x] ; x => byte mask (msb set) |
125 | // vreg_mask_2 = [x,x,x,x,x,0,0,0,0,0,.....,0] |
126 | Vmm vreg_mask_2 = vreg(5); |
127 | Xmm xreg_mask_2_lo = xreg(5); // similar to xreg_mask_lo |
128 | Xmm xreg_mask_2_hi = xreg(6); // similar to xreg_mask_hi |
129 | |
130 | Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails |
131 | Mmx mmx_dst_i8 = Mmx( |
132 | 0); // "avg" - Mmx reg for masked store results of s8/u8 operations |
    Mmx mmx_full_msk = Mmx(
            1); // "avg" - Mmx reg with the full mask (all 8 bytes) - used while not in the tail
135 | Mmx mmx_tmp = Mmx(2); |
136 | int post_op_tail_opmask_idx_ = -1; |
137 | jit_pool_conf_t jpp; |
138 | std::unique_ptr<injector::jit_uni_postops_injector_t<isa>> |
139 | postops_injector_; |
140 | |
141 | enum : int { max_vidx_base = utils::one_of(isa, sse41, avx2) ? 7 : 2 }; |
142 | //"avg" pool uses more registers for unrolling. |
143 | enum : int { avg_vidx_base = utils::one_of(isa, sse41, avx2) ? 4 : 2 }; |
144 | |
145 | Vmm max_base_vr(int idx) const { return vreg(max_vidx_base + idx); } |
146 | Vmm avg_base_vr(int idx) const { return vreg(avg_vidx_base + idx); } |
147 | |
148 | size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); } |
149 | size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); } |
150 | |
151 | /* max pooling */ |
152 | Vmm vreg_src(int idx) const { return max_base_vr(idx); } // [0 .. ur_c-1] |
153 | Vmm vreg_dst(int idx) const { |
154 | return max_base_vr(jpp.ur_c + idx); |
155 | } // [ur_c .. 2*ur_c-1] |
156 | |
157 | /* avg pooling */ |
    // s32 is used for processing s8/u8 data,
    // so we need to take into account the size ratio s32/i8 = 4
160 | static constexpr data_type_t avg_proc_dt = data_type::s32; |
161 | enum : int { |
162 | s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type) |
163 | / sizeof(typename prec_traits<data_type::u8>::type), |
164 | max_num_ll = s32_to_i8_ratio, |
165 | mmx_msk_base_reg = 3 |
166 | }; |
167 | |
168 | inline size_t get_offset_dst(int jj, int ll) const { |
169 | size_t offset = 0; |
170 | switch (jpp.alg) { |
171 | case pooling_max: { |
172 | offset = jj * jpp.c_block * sizeof_dst_dt(); |
173 | break; |
174 | } |
175 | case pooling_avg_include_padding: |
176 | case pooling_avg_exclude_padding: { |
177 | offset = (ll * (jpp.c_block / max_num_ll) + jj * jpp.c_block) |
178 | * sizeof_dst_dt(); |
179 | break; |
180 | } |
            default: assert(!"unsupported pooling algorithm");
182 | } |
183 | return offset; |
184 | } |
185 | |
186 | Vmm vreg_src_s32(int jj, int ll) { |
187 | return avg_base_vr(3 * max_num_ll * jj + ll + 0 * max_num_ll); |
    } // ll: 0..3 [0..3]
189 | |
190 | Vmm vreg_dst_s32(int jj, int ll) { |
191 | return avg_base_vr(3 * max_num_ll * jj + ll + 1 * max_num_ll); |
    } // ll: 0..3 [4..7]
193 | |
194 | Vmm vreg_dst_f32(int jj, int ll) { |
195 | return avg_base_vr(3 * max_num_ll * jj + ll + 2 * max_num_ll); |
    } // ll: 0..3 [8..11]
197 | |
    Mmx mmx_mask(int ll) {
        return Mmx(mmx_msk_base_reg + ll);
    }; // ll: 0..3 [Mmx(3)...Mmx(6)]
201 | |
202 | static bool post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr, |
203 | const memory_desc_wrapper &dst_d); |
204 | |
205 | void init_tmp_reg(); |
206 | void init_mask(); |
207 | |
208 | void load_vreg_mask_q(int ll) {}; |
209 | |
210 | void load_src_max_op( |
211 | int jj, int ll, size_t offset, bool masked, uint64_t msk); |
212 | void load_src_avg_op( |
213 | int jj, int ll, size_t offset, bool masked, uint64_t msk); |
214 | void load_src(int jj, int ll, int c_tail); |
215 | |
216 | void store_dst_max_op( |
217 | int jj, int ll, size_t offset, bool masked, uint64_t msk); |
218 | void store_dst_avg_op( |
219 | int jj, int ll, size_t offset, bool masked, uint64_t msk); |
220 | void store_dst(int jj, int ll, int c_tail); |
221 | |
222 | void compute_avg_step(int ur_c, int c_tail); |
223 | void compute_max_op(const int jj); |
224 | void compute_max_step(int ur_c, int c_tail); |
225 | void compute_step(int ur_c, int c_tail); |
226 | |
227 | void compute_c_block(); |
228 | void generate() override; |
229 | |
230 | static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd); |
231 | |
232 | jit_uni_i8i8_pooling_fwd_ker_t( |
233 | const jit_pool_conf_t &jpp_, const memory_desc_t *dst_md) |
234 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa) |
235 | , jpp(jpp_) |
236 | , postops_injector_(nullptr) { |
237 | |
238 | if (jpp.with_postops) { |
239 | |
240 | const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float); |
241 | const std::size_t c_tail_elems = jpp.c % simd_w; |
242 | post_op_tail_opmask_idx_ = 0; |
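            // Pick the highest ll with a non-zero tail mask: that sub-vector
            // carries the channel tail, so its mask is handed to the binary
            // injector for tail processing.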
243 | if (c_tail_elems) { |
244 | for (int ll = max_num_ll - 1; ll >= 0; ll--) { |
245 | if (jpp.tail[ll] != 0) { |
246 | post_op_tail_opmask_idx_ = ll; |
247 | break; |
248 | } |
249 | } |
250 | }; |
251 | |
252 | static constexpr bool preserve_gpr = true; |
253 | static constexpr bool preserve_vmm = true; |
254 | static constexpr bool use_exact_tail_scalar_bcast = false; |
255 | static constexpr std::size_t tmp_vmm_injector = 0u; |
256 | |
257 | const binary_injector::rhs_arg_static_params_t rhs_sp { |
258 | tmp_vmm_injector, r14, r15, r13, preserve_gpr, preserve_vmm, |
259 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
260 | memory_desc_wrapper(*dst_md), c_tail_elems, |
261 | mask(post_op_tail_opmask_idx_), |
262 | use_exact_tail_scalar_bcast}; |
263 | const binary_injector::static_params_t bsp { |
264 | reg_param, get_supported_bcast_strategies(), rhs_sp}; |
265 | |
266 | postops_injector_ = utils::make_unique< |
267 | injector::jit_uni_postops_injector_t<isa>>( |
268 | this, jpp.post_ops, bsp); |
269 | } |
270 | } |
271 | }; |
272 | |
273 | template <> |
274 | void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_vreg_mask_q(int ll) {}; |
275 | |
276 | template <> |
277 | void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) { |
278 | |
279 | // extract ll-th part of mask (ll-th QWORD) |
280 | vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, |
281 | 0x3 << 2 * ll); // 0x3 - mask for 2 x DWORD |
282 | |
283 | // Move mask from ll-th pos to 0-th pos |
284 | if (ll > 0) vpermq(vreg_mask_q, vreg_mask_q, ll); |
285 | }; |
286 | |
287 | template <> |
288 | void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_src_max_op( |
289 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
290 | using namespace data_type; |
291 | |
292 | if (masked) { |
293 | if (jpp.src_dt == s32) |
294 | for (int64_t i = 0; i < jpp.c_tail; i++) |
295 | pinsrd(vreg_src(jj), |
296 | ptr[aux_reg_src_w + offset + i * data_type_size(s32)], |
297 | i); |
298 | else |
299 | for (int i = 0; i < jpp.c_tail; i++) |
300 | pinsrb(vreg_src(jj), ptr[aux_reg_src_w + offset + i], i); |
301 | } else |
302 | movups(vreg_src(jj), ptr[aux_reg_src_w + offset]); |
303 | } |
304 | |
305 | template <> |
306 | void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op( |
307 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
308 | using namespace data_type; |
309 | |
310 | if (masked) { |
311 | if (jpp.src_dt == s32) { |
312 | vpmaskmovd(vreg_src(jj), vreg_mask, ptr[aux_reg_src_w + offset]); |
313 | } else { |
314 | // Steps to access 'tail' section: |
315 | // 1) First load all data from the shifted src ptr |
            // 2) Now bring the required data from the end of the reg to the beginning.
317 | // Example: idx=[31..0] |
318 | // vreg_src = [x,x,x,x,.....,x,-,-,-,-,-] ; x => byte data |
319 | // shift to transform vreg_src = [-,-,-,-,-,x,..,x,x,x,x,] |
320 | const uint8_t shift = cpu_isa_traits<avx2>::vlen - jpp.c_tail; |
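            // e.g. vlen = 32, c_tail = 5 -> shift = 27: the 5 valid bytes
            // land in the top of the register and are rotated down to byte
            // positions 0..4 below.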
321 | |
322 | if (jpp.safe_c_tail) { |
323 | |
324 | /* load src_tail at 'src_address - shift' so that it does not |
325 | * spill over the memory boundary */ |
326 | vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset - shift]); |
327 | |
328 | vperm2i128(vreg_tmp, vreg_src(jj), vreg_src(jj), 0x81); |
329 | vpalignr(vreg_src(jj), vreg_tmp, vreg_src(jj), shift); |
330 | |
331 | } else { |
332 | Label load_data_safely, done; |
333 | add(aux_reg_src_w, offset); |
334 | |
335 | // Check if mask crosses page boundary |
336 | cmp(aux_reg_src_w, reg_src_safe_access); |
337 | ja(load_data_safely, T_NEAR); |
338 | |
339 | vpblendvb( |
340 | vreg_src(jj), vreg_tmp, byte[aux_reg_src_w], vreg_mask); |
341 | jmp(done, T_NEAR); |
342 | |
343 | L(load_data_safely); |
344 | |
345 | /* load src_tail at 'src_address - shift' so that it does not |
346 | * spill over the memory boundary */ |
347 | vmovups(vreg_src(jj), ptr[aux_reg_src_w - shift]); |
348 | |
349 | vperm2i128(vreg_tmp, vreg_src(jj), vreg_src(jj), 0x81); |
350 | vpalignr(vreg_src(jj), vreg_tmp, vreg_src(jj), shift); |
351 | |
352 | L(done); |
353 | sub(aux_reg_src_w, offset); |
354 | } |
355 | } |
356 | |
357 | } else |
358 | vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); |
359 | }; |
360 | |
361 | template <> |
362 | void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op( |
363 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
364 | using namespace data_type; |
365 | |
366 | if (masked) { |
367 | if (jpp.src_dt == s32) |
368 | vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); |
369 | else |
370 | vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); |
371 | } else |
372 | vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); |
373 | }; |
374 | |
375 | template <> |
376 | void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_src_avg_op( |
377 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
378 | using namespace data_type; |
379 | |
380 | const Vmm &vr_src = vreg_src_s32(jj, ll); |
381 | |
382 | if (jpp.src_dt == s32) { |
383 | if (masked) |
384 | for (int64_t i = 0; i < jpp.c_tail; i++) |
385 | pinsrd(vr_src, |
386 | ptr[aux_reg_src_w + offset + i * data_type_size(s32)], |
387 | i); |
388 | else |
389 | movups(vr_src, ptr[aux_reg_src_w + offset]); |
390 | } else if (utils::one_of(jpp.src_dt, s8, u8)) { |
391 | if (masked) { |
392 | const int copy_range = math::ilog2q(jpp.tail[ll] + 1); |
393 | for (int i = 0; i < copy_range; i++) |
394 | pinsrb(vr_src, ptr[aux_reg_src_w + offset + i], i); |
395 | |
396 | if (jpp.src_dt == s8) |
397 | pmovsxbd(vr_src, vr_src); |
398 | else |
399 | pmovzxbd(vr_src, vr_src); |
400 | } else { |
401 | if (jpp.src_dt == s8) |
402 | pmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); |
403 | else |
404 | pmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); |
405 | } |
406 | } else |
        assert(!"unsupported src data type");
408 | } |
409 | |
410 | template <> |
411 | void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op( |
412 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
413 | using namespace data_type; |
414 | |
415 | auto load_i8 = [&](bool is_signed, const Vmm &vr_src) { |
416 | // Need to use mask of tail? |
417 | if (masked) { |
418 | |
419 | // load ll-th part of mask into vreg_mask_q |
420 | load_vreg_mask_q(ll); |
421 | |
422 | // Steps to access 'tail' section: |
423 | // 1) First load all data from the shifted src ptr |
            // 2) Now bring the required data from the end of the reg to the beginning.
425 | // Example: idx=[31..0] |
426 | // vreg_src = [x,x,x,x,.....,x,-,-,-,-,-] ; x => byte data |
427 | // shift to transform vreg_src = [-,-,-,-,-,x,..,x,x,x,x,] |
            // Re-purposing vreg_zeros here. Set it back to zero immediately.
429 | const int msk_gran |
430 | = cpu_isa_traits<avx2>::vlen / data_type_size(avg_proc_dt); |
431 | |
432 | const uint8_t shift = cpu_isa_traits<avx2>::vlen |
433 | - (jpp.c_tail > (ll + 1) * msk_gran |
434 | ? msk_gran |
435 | : jpp.c_tail - (ll * msk_gran)); |
436 | if (jpp.safe_c_tail) { |
437 | /* load src_tail at 'src_address - shift' so that it does not |
438 | * spill over the memory boundary */ |
439 | vmovups(vr_src, ptr[aux_reg_src_w + offset - shift]); |
440 | |
441 | vperm2i128(vreg_zeros, vr_src, vr_src, 0x81); |
442 | vpalignr(vr_src, vreg_zeros, vr_src, shift); |
443 | uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros); |
444 | } else { |
445 | Label load_data_safely, done; |
446 | // assume that it is not safe to load the src_tail |
447 | |
448 | add(aux_reg_src_w, offset); |
449 | |
450 | // Check if load crosses the memory boundary |
451 | cmp(aux_reg_src_w, reg_src_safe_access); |
452 | ja(load_data_safely, T_NEAR); |
453 | |
454 | vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w], vreg_mask_q); |
455 | jmp(done, T_NEAR); |
456 | |
457 | L(load_data_safely); |
458 | |
459 | /* load src_tail at 'src_address - shift' so that it does not |
460 | * spill over the memory boundary */ |
461 | vmovups(vr_src, ptr[aux_reg_src_w - shift]); |
462 | |
463 | vperm2i128(vreg_zeros, vr_src, vr_src, 0x81); |
464 | vpalignr(vr_src, vreg_zeros, vr_src, shift); |
465 | uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros); |
466 | |
467 | L(done); |
468 | sub(aux_reg_src_w, offset); |
469 | } |
470 | |
471 | // Conversion s8/u8 -> s32 |
472 | if (is_signed) |
473 | vpmovsxbd(vr_src, vr_src); |
474 | else |
475 | vpmovzxbd(vr_src, vr_src); |
476 | } else { |
477 | |
478 | // Load from mem into vr_src with conversion |
479 | if (is_signed) |
480 | vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); |
481 | else |
482 | vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); |
483 | } |
484 | }; |
485 | |
486 | switch (jpp.src_dt) { |
487 | case s32: |
488 | if (masked) |
489 | vpmaskmovd(vreg_src_s32(jj, ll), vreg_mask, |
490 | ptr[aux_reg_src_w + offset]); |
491 | else |
492 | vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]); |
493 | break; |
494 | case s8: load_i8(true, vreg_src_s32(jj, ll)); break; |
495 | case u8: load_i8(false, vreg_src_s32(jj, ll)); break; |
        default: assert(!"unsupported src data type");
497 | } |
498 | }; |
499 | |
500 | template <> |
501 | void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op( |
502 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
503 | using namespace data_type; |
504 | |
505 | const Vmm &vr_src |
506 | = masked ? vreg_src_s32(jj, ll) | mask(ll) : vreg_src_s32(jj, ll); |
507 | |
508 | switch (jpp.src_dt) { |
509 | case s32: vmovups(vr_src, ptr[aux_reg_src_w + offset]); break; |
510 | case s8: vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); break; |
511 | case u8: vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); break; |
        default: assert(!"unsupported src data type");
513 | } |
514 | }; |
515 | |
516 | template <cpu_isa_t isa> |
517 | void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) { |
518 | using namespace data_type; |
519 | |
520 | int c_block = jpp.c_block; |
521 | int ur_c = jpp.ur_c; |
522 | |
523 | switch (jpp.alg) { |
524 | case pooling_max: { |
525 | auto offset = jj * c_block * sizeof_src_dt(); |
526 | bool masked = jj == ur_c - 1 && c_tail; |
527 | load_src_max_op(jj, ll, offset, masked, jpp.tail[0]); |
528 | break; |
529 | } |
530 | case pooling_avg_include_padding: |
531 | case pooling_avg_exclude_padding: { |
532 | auto offset = (ll * (c_block / max_num_ll) + jj * c_block) |
533 | * sizeof_src_dt(); |
534 | bool masked = jj == ur_c - 1 && c_tail; |
535 | load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]); |
536 | break; |
537 | } |
        default: assert(!"unsupported algorithm");
539 | } |
540 | } |
541 | |
542 | template <> |
543 | void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::store_dst_max_op( |
544 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
545 | using namespace data_type; |
546 | |
547 | if (masked) { |
548 | if (jpp.src_dt == s32) |
549 | for (int i = 0; i < jpp.c_tail; i++) |
550 | pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)], |
551 | vreg_dst(jj), i); |
552 | else if (utils::one_of(jpp.src_dt, u8, s8)) |
553 | for (int i = 0; i < jpp.c_tail; i++) |
554 | pextrb(ptr[reg_ptr_dst_i8 + offset + i], vreg_dst(jj), i); |
555 | else |
            assert(!"unsupported src data type");
557 | } else |
558 | movups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); |
559 | } |
560 | |
561 | template <> |
562 | void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op( |
563 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
564 | using namespace data_type; |
565 | |
566 | Label store_data_safely, done; |
567 | |
568 | int c_block = jpp.c_block; |
569 | |
570 | const uint64_t low_mask = (1ULL << (c_block / 2)) - 1; |
571 | const uint8_t shift = cpu_isa_traits<avx2>::vlen - jpp.c_tail; |
572 | |
573 | if (masked) { |
574 | switch (jpp.src_dt) { |
575 | case s32: |
576 | vpmaskmovd( |
577 | ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj)); |
578 | break; |
579 | case s8: |
580 | case u8: { |
581 | |
582 | lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); |
583 | |
584 | if (!jpp.safe_c_tail) { |
585 | Xmm xreg_dst = Xmm(vreg_dst(jj).getIdx()); |
586 | |
587 | cmp(reg_ptr_maskmovdqu_dst, reg_dst_safe_access); |
588 | ja(store_data_safely, T_NEAR); |
589 | |
590 | // Store low half by mask (bytes 0...15) |
591 | vmaskmovdqu(xreg_dst, xreg_mask_lo); |
592 | |
593 | // Do we need to store high half (bytes 16...31) ? |
594 | if (msk & ~low_mask) { |
595 | vextracti128(xreg_dst, vreg_dst(jj), 1); |
596 | add(reg_ptr_maskmovdqu_dst, c_block / 2); |
597 | vmaskmovdqu(xreg_dst, xreg_mask_hi); |
598 | } |
599 | jmp(done, T_NEAR); |
600 | } |
601 | |
602 | L(store_data_safely); |
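                // Safe path: left-align the data so the tail ends at the
                // last byte to be written, step the pointer back by 'shift',
                // and store through the left-aligned masks (xreg_mask_2_*).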
603 | |
604 | vperm2i128(vreg_tail, vreg_dst(jj), vreg_dst(jj), 0x08); |
605 | if (shift <= 16) { |
606 | vpalignr(vreg_tail, vreg_dst(jj), vreg_tail, 16 - shift); |
607 | } else { |
608 | vpalignr(vreg_tail, vreg_tail, vreg_zeros, 32 - shift); |
609 | } |
610 | |
611 | Xmm xreg_tail = Xmm(vreg_tail.getIdx()); |
612 | // Do we need to store low half (bytes 0...15) ? |
613 | if (msk & ~low_mask) { |
614 | sub(reg_ptr_maskmovdqu_dst, shift); |
615 | vmaskmovdqu(xreg_tail, xreg_mask_2_lo); |
616 | add(reg_ptr_maskmovdqu_dst, c_block / 2); |
617 | } else { |
618 | add(reg_ptr_maskmovdqu_dst, (c_block / 2) - shift); |
619 | } |
620 | |
621 | // Store high half by mask (bytes 16..31) |
622 | vextracti128(xreg_tail, vreg_tail, 1); |
623 | vmaskmovdqu(xreg_tail, xreg_mask_2_hi); |
624 | |
625 | L(done); |
626 | } break; |
            default: assert(!"unsupported src data type");
628 | } |
629 | } else |
630 | vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); |
631 | } |
632 | |
633 | template <> |
634 | void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op( |
635 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
636 | using namespace data_type; |
637 | |
638 | if (masked) { |
639 | switch (jpp.src_dt) { |
640 | case s32: |
641 | vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); |
642 | break; |
643 | case s8: |
644 | case u8: |
645 | vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); |
646 | break; |
            default: assert(!"unsupported src data type");
648 | } |
649 | } else |
650 | vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); |
651 | } |
652 | |
653 | template <> |
654 | void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::store_dst_avg_op( |
655 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
656 | using namespace data_type; |
657 | |
658 | // Don't generate useless code |
659 | if (masked && !msk) return; |
660 | |
661 | const Vmm &vr_dst = vreg_dst_s32(jj, ll); |
662 | |
663 | if (jpp.src_dt == s32) { |
664 | if (masked) |
665 | for (int i = 0; i < jpp.c_tail; i++) |
666 | pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)], |
667 | vr_dst, i); |
668 | else |
669 | movups(ptr[reg_ptr_dst_i8 + offset], vr_dst); |
670 | } else if (utils::one_of(jpp.src_dt, s8, u8)) { |
671 | packssdw(vr_dst, vr_dst); |
672 | if (jpp.src_dt == s8) |
673 | packsswb(vr_dst, vr_dst); |
674 | else |
675 | packuswb(vr_dst, vr_dst); |
676 | |
677 | const int copy_range = masked |
678 | ? math::ilog2q(jpp.tail[ll] + 1) |
679 | : cpu_isa_traits<sse41>::vlen / data_type_size(avg_proc_dt); |
680 | for (int i = 0; i < copy_range; i++) |
681 | pextrb(ptr[reg_ptr_dst_i8 + offset + i], vr_dst, i); |
682 | } else |
        assert(!"unsupported src data type");
684 | } |
685 | |
686 | template <> |
687 | void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op( |
688 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
689 | using namespace data_type; |
690 | |
691 | // Don't generate useless code |
692 | if (masked && !msk) return; |
693 | |
694 | auto s32_to_i8 = [&](bool is_signed, const Vmm &vr_dst) { |
695 | // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16} |
696 | // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0} |
697 | if (is_signed) |
698 | vpackssdw(vr_dst, vr_dst, vreg_zeros); |
699 | else |
700 | vpackusdw(vr_dst, vr_dst, vreg_zeros); |
701 | |
702 | // Permute qwords to restore original order |
703 | // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0} |
704 | vpermq(vr_dst, vr_dst, 0x58); |
705 | |
706 | // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8} |
707 | // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx} |
708 | if (is_signed) |
709 | vpacksswb(vr_dst, vr_dst, vreg_zeros); |
710 | else |
711 | vpackuswb(vr_dst, vr_dst, vreg_zeros); |
712 | }; |
713 | |
714 | auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm &vr_dst) { |
715 | // Conversion s32 -> s8/u8 |
716 | s32_to_i8(is_signed, vr_dst); |
717 | |
718 | // early-out for non-masked cases |
719 | if (!is_masked) { |
720 | vmovlps(ptr[reg_ptr_dst_i8 + offset], Xmm(vr_dst.getIdx())); |
721 | return; |
722 | } |
723 | // store 8 bytes |
724 | lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); |
725 | |
        // Need to use an mmx 8-byte operation to avoid memory violations.
        // NOTICE: it was discovered that the Intel SSE and Intel AVX
        //         maskmovdqu/vmaskmovdqu instructions with a low-8-bytes mask
        //         throw an exception if the high 8 bytes belong to a
        //         write-protected page.
        // NOTE: use an indirect move via gpr to avoid a transition penalty
731 | vmovq(reg_tmp, Xmm(vr_dst.getIdx())); |
732 | movq(mmx_dst_i8, reg_tmp); |
733 | |
734 | // mmx_full_msk - mask for all 8 bytes in zero-tail case |
735 | // mmx_mask(ll) - ll-th mask of tail in non-zero-tail case |
736 | |
737 | const int msk_gran |
738 | = cpu_isa_traits<avx2>::vlen / data_type_size(avg_proc_dt); |
739 | |
740 | const int ll_end = (ll + 1) * msk_gran; // ((ll + 1) * 8) |
741 | |
        if (is_masked && (ll_end > jpp.c_tail)) { // implies this tail is not full
743 | Label store_data_safely, done; |
744 | const uint8_t shift = msk_gran - jpp.c_tail % msk_gran; |
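            // e.g. msk_gran = 8, c_tail = 11, ll = 1: this tail mask covers
            // 11 % 8 = 3 bytes, so shift = 5 unused mask bytes.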
745 | |
746 | if (!jpp.safe_c_tail) { |
747 | cmp(reg_ptr_maskmovdqu_dst, reg_dst_safe_access); |
748 | ja(store_data_safely, T_NEAR); |
749 | |
750 | /* store dst_tail with overlap outside the channel dimension, |
751 | * but assume it's within the memory boundary. */ |
752 | maskmovq(mmx_dst_i8, mmx_mask(ll)); |
753 | jmp(done, T_NEAR); |
754 | } |
755 | |
756 | L(store_data_safely); |
757 | |
758 | /* store dst_tail at 'dst_address - shift' so that it does not |
759 | * spill over the memory boundary */ |
760 | movq(mmx_tmp, mmx_mask(ll)); |
            psllq(mmx_tmp, shift * 8); // multiply by 8 (bits/byte)
762 | psllq(mmx_dst_i8, shift * 8); |
763 | sub(reg_ptr_maskmovdqu_dst, shift); |
764 | maskmovq(mmx_dst_i8, mmx_tmp); |
765 | |
766 | L(done); |
767 | } else { |
768 | maskmovq(mmx_dst_i8, mmx_full_msk); |
769 | } |
770 | }; |
771 | |
772 | switch (jpp.dst_dt) { |
773 | case s32: |
774 | if (masked) { |
775 | vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, |
776 | vreg_dst_s32(jj, ll)); |
777 | } else |
778 | vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll)); |
779 | break; |
780 | case s8: store_i8(true, masked, vreg_dst_s32(jj, ll)); break; |
781 | case u8: store_i8(false, masked, vreg_dst_s32(jj, ll)); break; |
        default: assert(!"unsupported dst data_type");
783 | } |
784 | } |
785 | |
786 | template <> |
787 | void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op( |
788 | int jj, int ll, size_t offset, bool masked, uint64_t msk) { |
789 | using namespace data_type; |
790 | |
791 | // Don't generate useless code |
792 | if (masked && !msk) return; |
793 | |
794 | const Vmm &vr_dst |
795 | = masked ? vreg_dst_s32(jj, ll) | mask(ll) : vreg_dst_s32(jj, ll); |
796 | |
797 | switch (jpp.dst_dt) { |
798 | case s32: vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; |
799 | case s8: vpmovsdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; |
800 | case u8: vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break; |
        default: assert(!"unsupported dst data_type");
802 | } |
803 | } |
804 | |
805 | template <cpu_isa_t isa> |
806 | void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst( |
807 | int jj, int ll, int c_tail) { |
808 | using namespace data_type; |
809 | |
810 | int c_block = jpp.c_block; |
811 | int ur_c = jpp.ur_c; |
812 | |
813 | switch (jpp.alg) { |
814 | case pooling_max: { |
815 | auto offset = jj * c_block * sizeof_dst_dt(); |
816 | bool masked = jj == ur_c - 1 && c_tail; |
817 | store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]); |
818 | break; |
819 | } |
820 | case pooling_avg_include_padding: |
821 | case pooling_avg_exclude_padding: { |
822 | auto offset = (ll * (c_block / max_num_ll) + jj * c_block) |
823 | * sizeof_dst_dt(); |
824 | bool masked = jj == ur_c - 1 && c_tail; |
825 | store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]); |
826 | break; |
827 | } |
        default: assert(!"unsupported pooling algorithm");
829 | } |
830 | } |
831 | |
832 | template <> |
833 | void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::compute_max_op(const int jj) { |
834 | using namespace data_type; |
835 | switch (jpp.src_dt) { |
836 | case s32: pmaxsd(vreg_dst(jj), vreg_src(jj)); break; |
837 | case s8: pmaxsb(vreg_dst(jj), vreg_src(jj)); break; |
838 | case u8: pmaxub(vreg_dst(jj), vreg_src(jj)); break; |
        default: assert(!"unsupported src data type");
840 | } |
841 | } |
842 | |
843 | template <> |
844 | void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj) { |
845 | using namespace data_type; |
846 | switch (jpp.src_dt) { |
847 | case s32: vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break; |
848 | case s8: vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break; |
849 | case u8: vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break; |
        default: assert(!"unsupported src data type");
851 | } |
852 | } |
853 | |
854 | template <> |
855 | void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj) { |
856 | using namespace data_type; |
857 | |
858 | // Compare |
859 | switch (jpp.src_dt) { |
860 | case s32: |
861 | vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); |
862 | break; |
863 | case s8: |
864 | vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); |
865 | break; |
866 | case u8: |
867 | vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); |
868 | break; |
        default: assert(!"unsupported src data type");
870 | } |
871 | |
872 | // move max values into vreg_dst |
873 | if (jpp.src_dt == s32) |
874 | vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); |
875 | else |
876 | vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); |
877 | } |
878 | |
879 | template <cpu_isa_t isa> |
880 | void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step( |
881 | int ur_c, int c_tail) { |
882 | Label l_kd, l_kh, l_kw; |
883 | |
884 | int ih = jpp.ih; |
885 | int iw = jpp.iw; |
886 | int c = jpp.c; |
887 | |
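    // Initialize the running maxima with the data type's lowest value
    // (broadcast into vreg_tmp by init_tmp_reg()), then scan the
    // kd x kh x kw window accumulating element-wise maxima.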
888 | for (int jj = 0; jj < ur_c; jj++) |
889 | uni_vmovups(vreg_dst(jj), vreg_tmp); |
890 | |
891 | mov(aux_reg_src_d, reg_ptr_src_i8); |
892 | xor_(reg_kd_index, reg_kd_index); |
893 | L(l_kd); |
894 | { |
895 | mov(aux_reg_src_h, aux_reg_src_d); |
896 | xor_(reg_kh_index, reg_kh_index); |
897 | L(l_kh); |
898 | { |
899 | mov(aux_reg_src_w, aux_reg_src_h); |
900 | xor_(reg_kw_index, reg_kw_index); |
901 | L(l_kw); |
902 | { |
903 | for (int jj = 0; jj < ur_c; jj++) { |
904 | load_src(jj, 0, c_tail); |
905 | compute_max_op(jj); |
906 | } |
907 | add(aux_reg_src_w, c * sizeof_src_dt()); |
908 | inc(reg_kw_index); |
909 | cmp(reg_kw_index, reg_kw); |
910 | jl(l_kw, T_NEAR); |
911 | } |
912 | add(aux_reg_src_h, iw * c * sizeof_src_dt()); |
913 | inc(reg_kh_index); |
914 | cmp(reg_kh_index, reg_kh); |
915 | jl(l_kh, T_NEAR); |
916 | } |
917 | add(aux_reg_src_d, ih * iw * c * sizeof_src_dt()); |
918 | inc(reg_kd_index); |
919 | cmp(reg_kd_index, reg_kd); |
920 | jl(l_kd, T_NEAR); |
921 | } |
922 | |
923 | for (int jj = 0; jj < ur_c; jj++) |
924 | store_dst(jj, 0, c_tail); |
925 | } |
926 | |
927 | template <cpu_isa_t isa> |
928 | void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step( |
929 | int ur_c, int c_tail) { |
930 | using namespace data_type; |
931 | |
932 | Label l_kd, l_kh, l_kw; |
933 | |
934 | int ih = jpp.ih; |
935 | int iw = jpp.iw; |
936 | int c = jpp.c; |
937 | |
938 | const int num_ll = data_type_size(avg_proc_dt) / data_type_size(jpp.src_dt); |
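    // num_ll = number of s32 sub-vectors needed per channel block:
    // 4 for s8/u8 inputs (s32/i8 size ratio), 1 for s32 inputs.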
939 | |
940 | for (int jj = 0; jj < ur_c; jj++) { |
941 | for (int ll = 0; ll < num_ll; ll++) { |
942 | bool masked = jj == ur_c - 1 && c_tail; |
943 | size_t msk = jpp.tail[ll]; |
944 | if (!(masked && !msk)) { |
                // Clearing of src regs is not needed as they are written before being read
946 | uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), |
947 | vreg_dst_s32(jj, ll)); |
948 | } |
949 | } |
950 | } |
951 | |
952 | mov(aux_reg_src_d, reg_ptr_src_i8); |
953 | xor_(reg_kd_index, reg_kd_index); |
954 | L(l_kd); |
955 | { |
956 | mov(aux_reg_src_h, aux_reg_src_d); |
957 | xor_(reg_kh_index, reg_kh_index); |
958 | L(l_kh); |
959 | { |
960 | mov(aux_reg_src_w, aux_reg_src_h); |
961 | xor_(reg_kw_index, reg_kw_index); |
962 | L(l_kw); |
963 | { |
964 | for (int jj = 0; jj < ur_c; jj++) { |
965 | for (int ll = 0; ll < num_ll; ll++) { |
966 | bool masked = jj == ur_c - 1 && c_tail; |
967 | size_t msk = jpp.tail[ll]; |
968 | if (!(masked && !msk)) { |
969 | load_src(jj, ll, c_tail); |
970 | uni_vpaddd(vreg_dst_s32(jj, ll), |
971 | vreg_dst_s32(jj, ll), vreg_src_s32(jj, ll)); |
972 | } |
973 | } |
974 | } |
975 | add(aux_reg_src_w, c * sizeof_src_dt()); |
976 | inc(reg_kw_index); |
977 | cmp(reg_kw_index, reg_kw); |
978 | jl(l_kw, T_NEAR); |
979 | } |
980 | add(aux_reg_src_h, iw * c * sizeof_src_dt()); |
981 | inc(reg_kh_index); |
982 | cmp(reg_kh_index, reg_kh); |
983 | jl(l_kh, T_NEAR); |
984 | } |
985 | add(aux_reg_src_d, ih * iw * c * sizeof_src_dt()); |
986 | inc(reg_kd_index); |
987 | cmp(reg_kd_index, reg_kd); |
988 | jl(l_kd, T_NEAR); |
989 | } |
990 | |
991 | for (int jj = 0; jj < ur_c; jj++) { |
992 | for (int ll = 0; ll < num_ll; ll++) { |
993 | const bool masked = jj == ur_c - 1 && c_tail; |
994 | const size_t msk = jpp.tail[ll]; |
995 | if (!(masked && !msk)) { |
996 | const auto ®_dst_f32 = vreg_dst_f32(jj, ll); |
997 | const auto ®_dst_s32 = vreg_dst_s32(jj, ll); |
998 | uni_vcvtdq2ps(reg_dst_f32, reg_dst_s32); |
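                // dst_f32 = dst_f32 * idivider + 0 (vreg_tmp holds the
                // broadcast averaging factor, vreg_zeros the addend)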
999 | uni_vfmadd132ps(reg_dst_f32, vreg_zeros, vreg_tmp); |
1000 | |
1001 | if (jpp.with_postops) { |
1002 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
1003 | if (jpp.with_binary) { |
1004 | rhs_arg_params.vmm_idx_to_out_reg.emplace( |
1005 | reg_dst_f32.getIdx(), reg_ptr_dst_i8); |
1006 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( |
1007 | reg_dst_f32.getIdx(), get_offset_dst(jj, ll)); |
1008 | const bool tail = ll == post_op_tail_opmask_idx_; |
1009 | if (tail && masked) |
1010 | rhs_arg_params.vmm_tail_idx_.emplace( |
1011 | reg_dst_f32.getIdx()); |
1012 | } |
1013 | postops_injector_->compute_vector( |
1014 | reg_dst_f32.getIdx(), rhs_arg_params); |
1015 | } |
1016 | |
1017 | uni_vcvtps2dq(reg_dst_s32, reg_dst_f32); |
1018 | |
1019 | if (jpp.with_postops) |
1020 | if (jpp.dst_dt == u8) { |
1021 | uni_vpmaxsd(reg_dst_s32, reg_dst_s32, vreg_zeros); |
1022 | } |
1023 | store_dst(jj, ll, c_tail); |
1024 | } |
1025 | } |
1026 | } |
1027 | } |
1028 | |
1029 | template <cpu_isa_t isa> |
1030 | void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) { |
1031 | switch (jpp.alg) { |
1032 | case pooling_max: compute_max_step(ur_c, c_tail); break; |
1033 | case pooling_avg_include_padding: |
1034 | case pooling_avg_exclude_padding: compute_avg_step(ur_c, c_tail); break; |
        default: assert(!"unsupported pooling algorithm");
1036 | } |
1037 | } |
1038 | |
1039 | template <cpu_isa_t isa> |
1040 | void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block() { |
1041 | Label l_main_loop; |
1042 | |
1043 | int nb_c = jpp.nb_c; |
1044 | int c_block = jpp.c_block; |
1045 | int ur_c = jpp.ur_c; |
1046 | int ur_c_tail = jpp.ur_c_tail; |
1047 | int c_steps = nb_c / ur_c; |
1048 | int c_tail = jpp.c_tail; |
1049 | |
1050 | xor_(c_iter, c_iter); |
1051 | if (c_steps > 0) { |
1052 | L(l_main_loop); |
1053 | { |
1054 | compute_step(ur_c, 0); |
1055 | add(reg_ptr_src_i8, ur_c * c_block * sizeof_src_dt()); |
1056 | add(reg_ptr_dst_i8, ur_c * c_block * sizeof_dst_dt()); |
1057 | inc(c_iter); |
1058 | cmp(c_iter, c_steps); |
1059 | jl(l_main_loop, T_NEAR); |
1060 | } |
1061 | } |
1062 | |
1063 | if (ur_c_tail != 0) { compute_step(ur_c_tail, c_tail); } |
1064 | } |
1065 | |
1066 | template <> |
1067 | void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::init_mask() {} |
1068 | |
1069 | template <> |
1070 | void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() { |
1071 | using namespace data_type; |
1072 | using cpu_isa = cpu_isa_traits<avx2>; |
1073 | |
1074 | // AVX2 mask initialization: mask stored in Ymm-regs |
1075 | auto init = [&](uint64_t bit_mask, bool need_ymm_mask = true, |
1076 | bool need_mmx_mask = false) { |
1077 | const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t); |
1078 | |
1079 | const size_t DBITS = 8 * sizeof_src_dt(); |
1080 | const uint64_t VMSK = 1ULL << (DBITS - 1); |
1081 | const size_t D_PER_QW = (8 * sizeof(uint64_t)) / DBITS; |
1082 | uint64_t vmask[QW_PER_VREG]; |
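        // Expand each bit of bit_mask into a data_type-sized element with
        // its msb set, as maskmovdqu/vpblendvb expect. Example for s8/u8
        // (DBITS = 8, VMSK = 0x80): bit_mask = 0b101 gives
        // qw_vmask = 0x0000000000800080.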
1083 | for (size_t i = 0; i < QW_PER_VREG; i++) { |
1084 | uint64_t qw_vmask = 0ULL; |
1085 | for (size_t j = 0; j < D_PER_QW; j++) { |
1086 | if (bit_mask & 1) qw_vmask |= VMSK << DBITS * j; |
1087 | bit_mask >>= 1; |
1088 | } |
1089 | vmask[i] = qw_vmask; |
1090 | } |
1091 | |
1092 | // Need mask in Ymm regs ? |
1093 | if (need_ymm_mask) { |
1094 | |
1095 | // Put QWORDS with target mask into xmm regs |
1096 | const int xdst_i[QW_PER_VREG] |
1097 | = {xreg_mask_lo.getIdx(), xreg_mask_lo.getIdx(), |
1098 | xreg_mask_hi.getIdx(), xreg_mask_hi.getIdx()}; |
1099 | const int xsrc_i[QW_PER_VREG] = { |
1100 | vreg_zeros |
1101 | .getIdx(), // 0-th qword insert in zeros -> {qw0, 0} |
1102 | xreg_mask_lo |
1103 | .getIdx(), // 1-st and 0-th merge -> {qw0,qw1} |
1104 | vreg_zeros.getIdx(), xreg_mask_hi.getIdx()}; |
1105 | const uint8 qw_dst_idx[QW_PER_VREG] |
1106 | = {0, 1, 0, 1}; // qword index in 128-bit xreg |
1107 | |
1108 | for (size_t i = 0; i < QW_PER_VREG; i++) { |
1109 | mov(reg_mask, vmask[i]); |
1110 | vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, |
1111 | qw_dst_idx[i]); |
1112 | |
1113 | // Need mask in MMX regs also? |
1114 | if (need_mmx_mask) |
1115 | movq(mmx_mask(i), reg_mask); // reuse value in reg_mask |
1116 | } |
1117 | |
1118 | // Merge Low (xreg_mask_lo alias for vreg_mask.xreg) |
1119 | // and High (xreg_mask_hi) into full vreg_mask |
1120 | // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg} |
1121 | vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1); |
1122 | |
            // Compute a left-aligned copy of vreg_mask and store it in
            // vreg_mask_2, to be used for tail processing.
1124 | const uint8_t shift = 32 - jpp.c_tail; |
1125 | vperm2i128(vreg_mask_2, vreg_mask, vreg_mask, 0x08); |
1126 | if (shift <= 16) { |
1127 | vpalignr(vreg_mask_2, vreg_mask, vreg_mask_2, 16 - shift); |
1128 | } else { |
1129 | vpalignr(vreg_mask_2, vreg_mask_2, vreg_zeros, 32 - shift); |
1130 | } |
1131 | vextracti128(xreg_mask_2_hi, vreg_mask_2, 0x1); |
1132 | } |
1133 | |
1134 | // Need mask in MMX regs ? |
1135 | if (need_mmx_mask) { |
1136 | |
1137 | // Only in MMX regs ? |
1138 | if (!need_ymm_mask) |
1139 | for (size_t i = 0; i < QW_PER_VREG; i++) { |
1140 | mov(reg_mask, vmask[i]); |
1141 | movq(mmx_mask(i), reg_mask); |
1142 | } |
1143 | |
1144 | // Form full mask for one QWORD |
1145 | uint64_t qw_full_vmask = 0ULL; |
1146 | for (size_t i = 0; i < D_PER_QW; i++) |
1147 | qw_full_vmask |= VMSK << DBITS * i; |
1148 | |
1149 | mov(reg_mask, qw_full_vmask); |
1150 | movq(mmx_full_msk, reg_mask); |
1151 | } |
1152 | }; |
1153 | |
1154 | uint64_t tail_mask = (1ULL << jpp.c_tail) - 1; |
1155 | switch (jpp.alg) { |
1156 | case pooling_max: |
            // For "max" we need a mask only in case of a non-zero tail
1158 | if (tail_mask) init(tail_mask); |
1159 | break; |
1160 | case pooling_avg_include_padding: |
1161 | case pooling_avg_exclude_padding: |
            // For "avg" we need masks:
            // - s32 - only in case of a non-zero tail
            // - s8/u8 - in MMX regs irrespective of the tail (stores always
            //   go through a mask), and in Ymm regs for a non-zero tail
            //   (for loads)
1166 | switch (jpp.src_dt) { |
1167 | case s32: |
1168 | if (tail_mask) init(tail_mask); |
1169 | break; |
1170 | case s8: |
1171 | case u8: |
1172 | init(tail_mask ? tail_mask : ~0ULL, tail_mask != 0, true); |
1173 | break; |
                default: assert(!"unsupported src data type");
1175 | } |
1176 | break; |
        default: assert(!"unsupported pooling algorithm");
1178 | } |
1179 | } |
1180 | |
1181 | template <> |
1182 | void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() { |
1183 | |
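    // The jpp.tail[ll] bit masks are used directly as opmasks k6..k3
    // (see mask()); no byte-mask expansion is needed on avx512.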
1184 | for (int ll = 0; ll < max_num_ll; ll++) { |
1185 | mov(reg_mask, jpp.tail[ll]); |
1186 | kmovq(mask(ll), reg_mask); |
1187 | } |
1188 | } |
1189 | |
1190 | template <cpu_isa_t isa> |
1191 | void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() { |
1192 | using namespace data_type; |
1193 | |
1194 | switch (jpp.alg) { |
1195 | case pooling_avg_include_padding: |
1196 | case pooling_avg_exclude_padding: |
1197 | mov(reg_tmp, |
1198 | ptr[reg_param |
1199 | + offsetof(jit_uni_i8i8_pool_call_params_t, |
1200 | idivider)]); |
1201 | uni_vmovq(xmm_tmp, reg_tmp); |
1202 | uni_vpbroadcastd(vreg_tmp, xmm_tmp); |
1203 | break; |
1204 | case pooling_max: |
1205 | switch (jpp.src_dt) { |
1206 | case s32: |
1207 | mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest()); |
1208 | break; |
1209 | case s8: |
1210 | mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest()); |
1211 | break; |
1212 | case u8: |
1213 | mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest()); |
1214 | break; |
                default: assert(!"unsupported src data_type");
1216 | } |
1217 | |
1218 | uni_vmovq(xmm_tmp, reg_tmp); |
1219 | if (jpp.src_dt == s32) |
1220 | uni_vpbroadcastd(vreg_tmp, xmm_tmp); |
1221 | else if (mayiuse(avx2)) |
1222 | vpbroadcastb(vreg_tmp, xmm_tmp); |
1223 | else |
1224 | pshufb(xmm_tmp, vreg_zeros); |
1225 | break; |
        default: assert(!"unsupported pooling algorithm");
1227 | } |
1228 | } |
1229 | |
1230 | template <cpu_isa_t isa> |
1231 | void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() { |
1232 | preamble(); |
1233 | |
1234 | #if !defined(_WIN32) |
1235 | // Always use rcx as abi_param1 - |
1236 | // see the note about maskmovdqu/maskmovq near reg_param. |
1237 | mov(rcx, rdi); |
1238 | #endif |
1239 | |
1240 | #define READ_PARAM(reg, field) \ |
1241 | mov(reg, ptr[reg_param + offsetof(jit_uni_i8i8_pool_call_params_t, field)]) |
1242 | READ_PARAM(reg_ptr_src_i8, src_i8); |
1243 | READ_PARAM(reg_ptr_dst_i8, dst_i8); |
1244 | READ_PARAM(reg_kd, kd_range); |
1245 | READ_PARAM(reg_kh, kh_range); |
1246 | READ_PARAM(reg_kw, kw_range); |
1247 | READ_PARAM(reg_src_safe_access, src_safe_access); |
1248 | READ_PARAM(reg_dst_safe_access, dst_safe_access); |
1249 | |
1250 | #undef READ_PARAM |
1251 | |
1252 | uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros); |
1253 | |
1254 | init_mask(); |
1255 | |
1256 | init_tmp_reg(); |
1257 | |
1258 | compute_c_block(); |
1259 | |
1260 | emms(); |
1261 | postamble(); |
1262 | |
1263 | if (jpp.with_eltwise && postops_injector_) |
1264 | postops_injector_->prepare_table(); |
1265 | } |
1266 | |
1267 | template <cpu_isa_t isa> |
1268 | status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf( |
1269 | jit_pool_conf_t &jpp, const pooling_pd_t *ppd) { |
1270 | if (!mayiuse(isa)) return status::unimplemented; |
1271 | |
1272 | const auto &pd = *ppd->desc(); |
1273 | const memory_desc_wrapper src_d(ppd->src_md()); |
1274 | const memory_desc_wrapper dst_d(ppd->dst_md()); |
1275 | const int ndims = src_d.ndims(); |
1276 | const bool is_1d = ndims == 3; |
1277 | const bool is_3d = ndims == 5; |
1278 | |
1279 | jpp.mb = src_d.dims()[0]; |
1280 | jpp.c = src_d.dims()[1]; |
1281 | |
1282 | jpp.id = is_3d ? src_d.dims()[ndims - 3] : 1; |
1283 | jpp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; |
1284 | jpp.iw = src_d.dims()[ndims - 1]; |
1285 | |
1286 | jpp.od = is_3d ? dst_d.dims()[ndims - 3] : 1; |
1287 | jpp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; |
1288 | jpp.ow = dst_d.dims()[ndims - 1]; |
1289 | |
1290 | jpp.stride_d = is_3d ? pd.strides[ndims - 5] : 1; |
1291 | jpp.stride_h = is_1d ? 1 : pd.strides[ndims - 4]; |
1292 | jpp.stride_w = pd.strides[ndims - 3]; |
1293 | |
1294 | jpp.kd = is_3d ? pd.kernel[ndims - 5] : 1; |
1295 | jpp.kh = is_1d ? 1 : pd.kernel[ndims - 4]; |
1296 | jpp.kw = pd.kernel[ndims - 3]; |
1297 | |
1298 | jpp.f_pad = is_3d ? pd.padding[0][ndims - 5] : 0; |
1299 | jpp.t_pad = is_1d ? 0 : pd.padding[0][ndims - 4]; |
1300 | jpp.l_pad = pd.padding[0][ndims - 3]; |
1301 | |
1302 | int back_pad = calculate_end_padding( |
1303 | jpp.f_pad, jpp.od, jpp.id, jpp.stride_d, jpp.kd); |
1304 | int bottom_pad = calculate_end_padding( |
1305 | jpp.t_pad, jpp.oh, jpp.ih, jpp.stride_h, jpp.kh); |
1306 | int right_pad = calculate_end_padding( |
1307 | jpp.l_pad, jpp.ow, jpp.iw, jpp.stride_w, jpp.kw); |
1308 | |
1309 | if (jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw |
1310 | || back_pad >= jpp.kd || bottom_pad >= jpp.kh |
1311 | || right_pad >= jpp.kw) |
1312 | return status::unimplemented; |
1313 | |
1314 | jpp.alg = pd.alg_kind; |
1315 | |
1316 | jpp.src_dt = pd.src_desc.data_type; |
1317 | jpp.dst_dt = pd.dst_desc.data_type; |
1318 | |
1319 | // data_type items per one vreg on the <isa> |
1320 | // isa == sse41 : 16 bytes -> 16 for s8/u8, 4 for s32 |
1321 | // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32 |
1322 | // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32 |
1323 | int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt); |
1324 | |
    /* Verify that vlen-sized memory access happens within the tensor's
     * size, otherwise load/store will always spill outside the memory
     * boundary. */
1328 | bool safe_load_n_store = IMPLICATION(utils::one_of(isa, avx2, sse41), |
1329 | jpp.mb * jpp.c * nstl::min(jpp.id, jpp.od) |
1330 | * nstl::min(jpp.ih, jpp.oh) |
1331 | * nstl::min(jpp.iw, jpp.ow) |
1332 | >= simd_w); |
1333 | if (!safe_load_n_store) return status::unimplemented; |
1334 | |
1335 | jpp.c_block = simd_w; |
1336 | jpp.c_tail = jpp.c % jpp.c_block; |
1337 | jpp.nb_c = jpp.c / jpp.c_block; |
1338 | jpp.ur_c = 1; |
1339 | jpp.ur_c_tail = jpp.c_tail != 0; |
1340 | |
1341 | size_t tail_mask = (1ULL << jpp.c_tail) - 1; |
1342 | |
    /* If the channel size is bigger than vlen, we can safely assume there is
     * no underflow of the memory boundary, so always perform c_tail and save
     * a couple of compute cycles. */
1346 | jpp.safe_c_tail = jpp.c_tail > 0 && jpp.c >= simd_w; |
1347 | |
1348 | switch (jpp.alg) { |
1349 | case pooling_max: |
1350 | jpp.tail[0] = tail_mask; |
1351 | jpp.tail[1] = 0; |
1352 | jpp.tail[2] = 0; |
1353 | jpp.tail[3] = 0; |
1354 | break; |
1355 | case pooling_avg_include_padding: |
1356 | case pooling_avg_exclude_padding: { |
1357 | // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32) |
1358 | // sse : 4, avx2 : 8, avx512 : 16 |
1359 | const size_t msk_gran |
1360 | = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt); |
1361 | const size_t msk_msk = (1ULL << msk_gran) - 1; |
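            // Example: avx2 with c_tail = 11 -> tail_mask = 0x7ff and
            // msk_gran = 8, giving tail[0] = 0xff, tail[1] = 0x07,
            // tail[2] = tail[3] = 0.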
1362 | size_t m = tail_mask; |
1363 | for (size_t ll = 0; ll < max_num_ll; ll++) { |
1364 | jpp.tail[ll] = m & msk_msk; |
1365 | m = m >> msk_gran; |
1366 | } |
1367 | break; |
1368 | } |
1369 | default: return status::unimplemented; |
1370 | } |
1371 | |
1372 | if (!post_ops_ok(jpp, *ppd->attr(), dst_d)) return status::unimplemented; |
1373 | |
1374 | return status::success; |
1375 | } |
1376 | |
1377 | template <cpu_isa_t isa> |
1378 | bool jit_uni_i8i8_pooling_fwd_ker_t<isa>::post_ops_ok(jit_pool_conf_t &jpp, |
1379 | const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) { |
1380 | const auto &post_ops = attr.post_ops_; |
1381 | const auto &entries = post_ops.entry_; |
1382 | jpp.with_postops = false; |
1383 | jpp.with_eltwise = false; |
1384 | jpp.with_binary = false; |
1385 | |
1386 | if (entries.empty()) return true; |
1387 | |
1388 | for (const auto &entry : entries) { |
1389 | if (entry.is_eltwise()) { |
1390 | const auto alg = entry.eltwise.alg; |
1391 | jpp.with_eltwise = eltwise_injector::is_supported(isa, alg); |
1392 | } else if (entry.is_binary()) { |
1393 | if (isa != avx512_core |
1394 | && entry.binary.src1_desc.data_type == data_type::bf16) |
1395 | return false; |
1396 | jpp.with_binary = true; |
1397 | } else |
1398 | return false; |
1399 | } |
1400 | |
1401 | jpp.with_postops = jpp.with_eltwise || jpp.with_binary; |
1402 | jpp.post_ops = post_ops; |
1403 | |
1404 | /* |
1405 | * TODO Currently eltwise/binary injectors assumes that data in vmm has f32 dt. |
1406 | * In max pooling data remains in i8 data type. |
1407 | */ |
1408 | return IMPLICATION(jpp.with_postops, jpp.alg != pooling_max) |
1409 | && binary_injector::binary_args_broadcast_supported( |
1410 | post_ops, dst_d, get_supported_bcast_strategies()); |
1411 | } |
1412 | |
1413 | template <cpu_isa_t isa> |
1414 | status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() { |
1415 | return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_, this); |
1416 | } |
1417 | |
1418 | template <cpu_isa_t isa> |
1419 | jit_uni_i8i8_pooling_fwd_t<isa>::jit_uni_i8i8_pooling_fwd_t(const pd_t *apd) |
1420 | : primitive_t(apd), ker_(nullptr) {} |
1421 | |
1422 | template <cpu_isa_t isa> |
1423 | jit_uni_i8i8_pooling_fwd_t<isa>::~jit_uni_i8i8_pooling_fwd_t() = default; |
1424 | |
1425 | template <cpu_isa_t isa> |
1426 | status_t jit_uni_i8i8_pooling_fwd_t<isa>::init(engine_t *engine) { |
1427 | CHECK(safe_ptr_assign(ker_, |
1428 | new jit_uni_i8i8_pooling_fwd_ker_t<isa>( |
1429 | pd()->jpp_, pd()->invariant_dst_md()))); |
1430 | return ker_->create_kernel(); |
1431 | } |
1432 | |
1433 | template <cpu_isa_t isa> |
1434 | status_t jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward( |
1435 | const exec_ctx_t &ctx) const { |
1436 | auto src_i8 = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
1437 | auto dst_i8 = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
1438 | |
1439 | const memory_desc_wrapper src_d(pd()->src_md()); |
1440 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
1441 | |
1442 | const auto &jpp = pd()->jpp_; |
1443 | const auto post_ops_binary_rhs_arg_vec |
1444 | = binary_injector::prepare_binary_args(jpp.post_ops, ctx); |
    /* Calculate when the memory access will happen outside of the memory
     * boundary; if so, compute a safe memory access. */
1447 | const auto src_safe_access = reinterpret_cast<char *>( |
1448 | reinterpret_cast<ptrdiff_t>(src_i8 + src_d.size() - 1) |
1449 | - (cpu_isa_traits<isa>::vlen - 1)); |
1450 | |
1451 | const auto dst_safe_access = reinterpret_cast<char *>( |
1452 | reinterpret_cast<ptrdiff_t>(dst_i8 + dst_d.size() - 1) |
1453 | - (cpu_isa_traits<isa>::vlen - 1)); |
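    // src/dst_safe_access point vlen - 1 bytes before the last buffer byte,
    // i.e. the last address from which a full vector load/store still fits
    // inside the buffer; the kernel compares its pointers against them
    // before issuing unmasked tail accesses.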
1454 | |
1455 | parallel_nd(jpp.mb, jpp.od, jpp.oh, jpp.ow, |
1456 | [&](dim_t n, dim_t od, dim_t oh, dim_t ow) { |
1457 | dim_t id = nstl::max(od * jpp.stride_d - jpp.f_pad, dim_t(0)); |
1458 | dim_t ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, dim_t(0)); |
1459 | dim_t iw = nstl::max(ow * jpp.stride_w - jpp.l_pad, dim_t(0)); |
1460 | |
1461 | dim_t kd_start |
1462 | = nstl::max(dim_t(0), jpp.f_pad - od * jpp.stride_d); |
1463 | dim_t kd_end = nstl::min( |
1464 | dim_t(jpp.kd), jpp.id + jpp.f_pad - od * jpp.stride_d); |
1465 | dim_t kh_start |
1466 | = nstl::max(dim_t(0), jpp.t_pad - oh * jpp.stride_h); |
1467 | dim_t kh_end = nstl::min( |
1468 | dim_t(jpp.kh), jpp.ih + jpp.t_pad - oh * jpp.stride_h); |
1469 | dim_t kw_start |
1470 | = nstl::max(dim_t(0), jpp.l_pad - ow * jpp.stride_w); |
1471 | dim_t kw_end = nstl::min( |
1472 | dim_t(jpp.kw), jpp.iw + jpp.l_pad - ow * jpp.stride_w); |
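                // k*_range is the number of kernel taps overlapping the
                // source tensor; for pooling_avg_exclude_padding it also
                // serves as the averaging divisor via p.idivider below.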
1473 | |
1474 | auto p = jit_uni_i8i8_pool_call_params_t(); |
1475 | p.src_i8 = &src_i8[get_offset(src_d, n, 0, id, ih, iw) |
1476 | * src_d.data_type_size()]; |
1477 | p.dst_i8 = &dst_i8[get_offset(dst_d, n, 0, od, oh, ow) |
1478 | * dst_d.data_type_size()]; |
1479 | p.dst_orig = dst_i8; |
1480 | p.kd_range = kd_end - kd_start; |
1481 | p.kh_range = kh_end - kh_start; |
1482 | p.kw_range = kw_end - kw_start; |
1483 | p.idivider = 1.0f |
1484 | / ((jpp.alg == pooling_avg_exclude_padding) |
1485 | ? p.kd_range * p.kh_range * p.kw_range |
1486 | : jpp.kd * jpp.kh * jpp.kw); |
1487 | p.src_safe_access = src_safe_access; |
1488 | p.dst_safe_access = dst_safe_access; |
1489 | p.post_ops_binary_rhs_arg_vec |
1490 | = post_ops_binary_rhs_arg_vec.data(); |
1491 | (*ker_)(&p); |
1492 | }); |
1493 | return status::success; |
1494 | } |
1495 | |
1496 | // Explicit instantiation only for supported <isa> values. |
1497 | // |
1498 | template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>; |
1499 | template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>; |
1500 | |
1501 | template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>; |
1502 | template struct jit_uni_i8i8_pooling_fwd_t<avx2>; |
1503 | |
1504 | template struct jit_uni_i8i8_pooling_fwd_ker_t<sse41>; |
1505 | template struct jit_uni_i8i8_pooling_fwd_t<sse41>; |
1506 | |
1507 | } // namespace x64 |
1508 | } // namespace cpu |
1509 | } // namespace impl |
1510 | } // namespace dnnl |
1511 | |