1/*******************************************************************************
2* Copyright 2017 - 2022 Intel Corporation
3* Licensed under the Apache License, Version 2.0 (the "License");
4* you may not use this file except in compliance with the License.
5* You may obtain a copy of the License at
6*
7* http://www.apache.org/licenses/LICENSE-2.0
8*
9* Unless required by applicable law or agreed to in writing, software
10* distributed under the License is distributed on an "AS IS" BASIS,
11* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12* See the License for the specific language governing permissions and
13* limitations under the License.
14*******************************************************************************/
15
16#include <functional>
17#include <new>
18
19#include "oneapi/dnnl/dnnl_types.h"
20
21#include "common/c_types_map.hpp"
22#include "common/dnnl_thread.hpp"
23#include "common/nstl.hpp"
24#include "common/type_helpers.hpp"
25
26#include "cpu/x64/jit_uni_pooling.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
33namespace jit_uni_pooling_utils {
34
// Wraps up to three JIT transposition kernels (tr::kernel_t) that together
// transpose a 2d (ysize x xsize) matrix, optionally converting between the
// input and output data types. The matrix is processed in 8x8 tiles; the
// leftover columns (x_tail_) and rows (y_tail_) get dedicated tail kernels.
struct trans_wrapper_t {
    // inp_str/out_str are the leading-dimension strides (in elements) of the
    // input and output matrices respectively.
    trans_wrapper_t(data_type_t inp_dt, dim_t inp_str, data_type_t out_dt,
            dim_t out_str, dim_t ysize, dim_t xsize)
        : inp_dt_size_(types::data_type_size(inp_dt))
        , out_dt_size_(types::data_type_size(out_dt))
        , inp_str_(inp_str)
        , out_str_(out_str)
        , nb_x_(xsize / 8)
        , nb_y_(ysize / 8)
        , x_tail_(xsize % 8)
        , y_tail_(ysize % 8) {
        using namespace cpu::x64::tr;

        // Builds a 2-node tr:: problem of shape (ys x xs) with the given
        // per-axis input/output strides and creates a JIT kernel for it.
        auto create_ker = [=](dim_t ys, dim_t y_inp_str, dim_t y_out_str,
                                  dim_t xs, dim_t x_inp_str, dim_t x_out_str) {
            tr::prb_t prb;
            kernel_t::desc_t desc;

            prb.ndims = 2;
            prb.ioff = 0;
            prb.ooff = 0;
            prb.src_scale_type = scale_type_t::NONE;
            prb.dst_scale_type = scale_type_t::NONE;
            prb.beta = 0;
            prb.nodes[0].ss = prb.nodes[1].ss = 1;

            prb.itype = inp_dt;
            prb.otype = out_dt;

            // Node 0: y axis; node 1: x axis.
            prb.nodes[0].n = ys;
            prb.nodes[0].is = y_inp_str;
            prb.nodes[0].os = y_out_str;

            prb.nodes[1].n = xs;
            prb.nodes[1].is = x_inp_str;
            prb.nodes[1].os = x_out_str;

            prb.full_ndims = prb.ndims;

            kernel_t::desc_init(desc, prb, 2);
            return kernel_t::create(desc);
        };

        // Main kernel for full 8x8 tiles (only when at least one exists).
        if (nb_x_ * nb_y_ > 0)
            ker_.reset(create_ker(8, inp_str_, 1, 8, 1, out_str_));

        // Tail kernel for the rightmost partial-x tiles (8 rows high).
        if (x_tail_)
            ker_x_tail_.reset(create_ker(8, inp_str_, 1, x_tail_, 1, out_str_));

        // Tail kernel for the bottom partial-y stripe (full xsize wide).
        if (y_tail_)
            ker_y_tail_.reset(
                    create_ker(y_tail_, inp_str_, 1, xsize, 1, out_str_));
    }

    // Generates machine code for whichever kernels were instantiated above.
    status_t create_kernel() {
        if (ker_) CHECK(ker_->create_kernel());
        if (ker_x_tail_) CHECK(ker_x_tail_->create_kernel());
        if (ker_y_tail_) CHECK(ker_y_tail_->create_kernel());
        return status::success;
    }

    // Runs the transposition: all full 8x8 tiles, then the x tail per tile
    // row, then the y tail. Note the (y, x) coordinates are swapped on the
    // output side - that swap is the transposition itself.
    void exec(const void *inp, void *out) {
        dim_t x_blocked = nb_x_ * 8;
        dim_t y_blocked = nb_y_ * 8;

        // Invokes `ker` on the sub-tile whose top-left corner sits at
        // (inp_y, inp_x) in the input and (out_y, out_x) in the output;
        // offsets are computed in bytes from the element sizes.
        auto call_ker = [&](tr::kernel_t &ker, dim_t inp_y, dim_t inp_x,
                                dim_t out_y, dim_t out_x) {
            tr::call_param_t cp;
            cp.src_scales = nullptr;
            cp.dst_scales = nullptr;

            dim_t inp_off = (inp_y * inp_str_ + inp_x) * inp_dt_size_;
            dim_t out_off = (out_y * out_str_ + out_x) * out_dt_size_;
            cp.in = (uint8_t *)inp + inp_off;
            cp.out = (uint8_t *)out + out_off;
            (ker)(&cp);
        };

        for (dim_t by = 0; by < nb_y_; by++) {
            for (dim_t bx = 0; bx < nb_x_; bx++)
                call_ker(*ker_, 8 * by, 8 * bx, 8 * bx, 8 * by);

            if (x_tail_)
                call_ker(*ker_x_tail_, 8 * by, x_blocked, x_blocked, 8 * by);
        }
        if (y_tail_) call_ker(*ker_y_tail_, y_blocked, 0, 0, y_blocked);
    }

    ~trans_wrapper_t() = default;

private:
    std::unique_ptr<tr::kernel_t> ker_; // full 8x8 tile kernel
    std::unique_ptr<tr::kernel_t> ker_x_tail_; // partial-x tile kernel
    std::unique_ptr<tr::kernel_t> ker_y_tail_; // partial-y stripe kernel

    const size_t inp_dt_size_; // input element size in bytes
    const size_t out_dt_size_; // output element size in bytes

    const dim_t inp_str_; // input leading-dimension stride (elements)
    const dim_t out_str_; // output leading-dimension stride (elements)
    const dim_t nb_x_; // number of full 8-wide tile columns
    const dim_t nb_y_; // number of full 8-high tile rows
    const dim_t x_tail_; // leftover columns: xsize % 8
    const dim_t y_tail_; // leftover rows: ysize % 8
};
140
141struct trans_context_t {
142 std::unique_ptr<trans_wrapper_t> src_trans_ = nullptr;
143 std::unique_ptr<trans_wrapper_t> src_tail_trans_ = nullptr;
144 std::unique_ptr<trans_wrapper_t> ind_trans_ = nullptr;
145 std::unique_ptr<trans_wrapper_t> ind_tail_trans_ = nullptr;
146 std::unique_ptr<trans_wrapper_t> dst_trans_ = nullptr;
147 std::unique_ptr<trans_wrapper_t> dst_tail_trans_ = nullptr;
148
149 // NOLINTNEXTLINE(readability-make-member-function-const)
150 status_t create_kernel() {
151 if (src_trans_) CHECK(src_trans_->create_kernel());
152 if (src_tail_trans_) CHECK(src_tail_trans_->create_kernel());
153 if (ind_trans_) CHECK(ind_trans_->create_kernel());
154 if (ind_tail_trans_) CHECK(ind_tail_trans_->create_kernel());
155 if (dst_trans_) CHECK(dst_trans_->create_kernel());
156 if (dst_tail_trans_) CHECK(dst_tail_trans_->create_kernel());
157 return status::success;
158 }
159};
160
161static void trans_exec(trans_wrapper_t *trans, trans_wrapper_t *trans_tail,
162 dim_t cs, const void *inp, void *out, dim_t c_block) {
163
164 if (cs == c_block)
165 trans->exec(inp, out);
166 else
167 trans_tail->exec(inp, out);
168};
169
170template <typename src_data_t, typename dst_data_t>
171struct transpose_ncsp_to_block_fmt_t {
172 transpose_ncsp_to_block_fmt_t(trans_wrapper_t *transposer,
173 trans_wrapper_t *transposer_tail, const src_data_t *src_nscp_base,
174 const memory_desc_wrapper &src_nscp_desc,
175 dst_data_t *__restrict dst_blocked_base, dim_t block_size,
176 const jit_pool_conf_t &jpp, std::size_t offset_multiplier = 1u)
177 : transposer_(transposer)
178 , transposer_tail_(transposer_tail)
179 , c_without_padding_(jpp.c_without_padding)
180 , c_block_(jpp.c_block)
181 , src_nscp_base_(src_nscp_base)
182 , src_nscp_desc_(src_nscp_desc)
183 , dst_blocked_base_(dst_blocked_base)
184 , block_size_(block_size)
185 , offset_multiplier_(offset_multiplier) {}
186
187 void operator()(std::size_t ithr, int n, int b_c) const {
188 const dim_t cs
189 = nstl::min(c_without_padding_ - b_c * c_block_, c_block_);
190 const src_data_t *src_nscp = src_nscp_base_
191 + src_nscp_desc_.blk_off(n, b_c * c_block_, 0)
192 * offset_multiplier_;
193 dst_data_t *dst_blocked
194 = dst_blocked_base_ + ithr * block_size_ * offset_multiplier_;
195 trans_exec(transposer_, transposer_tail_, cs, src_nscp, dst_blocked,
196 c_block_);
197 }
198
199private:
200 trans_wrapper_t *transposer_;
201 trans_wrapper_t *transposer_tail_;
202 const int c_without_padding_;
203 const int c_block_;
204 const src_data_t *src_nscp_base_;
205 const memory_desc_wrapper &src_nscp_desc_;
206 dst_data_t *__restrict dst_blocked_base_;
207 const dim_t block_size_;
208 std::size_t offset_multiplier_;
209};
210
211template <typename src_data_t, typename dst_data_t>
212struct transpose_block_fmt_to_ncsp_t {
213
214 transpose_block_fmt_to_ncsp_t(trans_wrapper_t *transposer,
215 trans_wrapper_t *transposer_tail,
216 const src_data_t *__restrict src_blocked_base, dim_t block_size,
217 dst_data_t *dst_ncsp_base, const memory_desc_wrapper &dst_nscp_desc,
218 const jit_pool_conf_t &jpp, std::size_t offset_multiplier = 1u)
219 : transposer_(transposer)
220 , transposer_tail_(transposer_tail)
221 , c_without_padding_(jpp.c_without_padding)
222 , c_block_(jpp.c_block)
223 , src_blocked_base_(src_blocked_base)
224 , block_size_(block_size)
225 , dst_ncsp_base_(dst_ncsp_base)
226 , dst_nscp_desc_(dst_nscp_desc)
227 , offset_multiplier_(offset_multiplier) {}
228
229 void operator()(std::size_t ithr, int n, int b_c) const {
230 const dim_t cs
231 = nstl::min(c_without_padding_ - b_c * c_block_, c_block_);
232 const src_data_t *src_blocked
233 = src_blocked_base_ + ithr * block_size_ * offset_multiplier_;
234 dst_data_t *dst_ncsp = dst_ncsp_base_
235 + dst_nscp_desc_.blk_off(n, b_c * c_block_, 0)
236 * offset_multiplier_;
237 trans_exec(transposer_, transposer_tail_, cs, src_blocked, dst_ncsp,
238 c_block_);
239 }
240
241private:
242 trans_wrapper_t *transposer_;
243 trans_wrapper_t *transposer_tail_;
244 const int c_without_padding_;
245 const int c_block_;
246 const src_data_t *__restrict src_blocked_base_;
247 const dim_t block_size_;
248 dst_data_t *dst_ncsp_base_;
249 const memory_desc_wrapper &dst_nscp_desc_;
250 std::size_t offset_multiplier_;
251};
252
// Base class for the plain(ncsp) <-> blocked transposition facades used
// around the pooling kernel. It decides whether transposition is needed
// (only for the ncsp tag kind), holds pointers into the scratchpad where
// per-thread converted slices live, and exposes address helpers the driver
// passes to the JIT kernel. Derived classes install the actual transposition
// callbacks (execute_transpose_input_/output_).
template <typename wsp_data_t, impl::data_type_t d_type>
class transpose_facade_base_t {
public:
    transpose_facade_base_t(const jit_pool_conf_t &jpp,
            const memory_desc_wrapper &src_d, const memory_desc_wrapper &dst_d,
            const memory_desc_wrapper &indices_d, const char *indices,
            const data_type_t wsp_dt, const exec_ctx_t &ctx)
        : src_sp_(static_cast<dim_t>(jpp.id) * jpp.ih * jpp.iw)
        , dst_sp_(static_cast<dim_t>(jpp.od) * jpp.oh * jpp.ow)
        , src_slice_(src_sp_ * jpp.c_block)
        , dst_slice_(dst_sp_ * jpp.c_block)
        , transpose_src_(jpp.tag_kind == jit_memory_tag_kind_t::ncsp)
        , transpose_dst_(jpp.tag_kind == jit_memory_tag_kind_t::ncsp)
        , src_d_(src_d)
        , dst_d_(dst_d)
        , indices_d_(indices_d)
        , ind_dt_size_(
                  indices ? types::data_type_size(indices_d_.data_type()) : 0)
        , cvt_slice_src_wsp_(nullptr)
        , cvt_slice_dst_wsp_(nullptr)
        , cvt_slice_ind_wsp_(nullptr)
        , execute_transpose_input_(nullptr)
        , execute_transpose_output_(nullptr) {

        // Scratchpad buffers were reserved at pd creation time; fetch them
        // only when the corresponding transposition will actually happen.
        auto scratchpad = ctx.get_scratchpad_grantor();

        if (transpose_src_)
            cvt_slice_src_wsp_ = scratchpad.template get<wsp_data_t>(
                    memory_tracking::names::key_pool_src_plain2blocked_cvt);

        if (transpose_dst_) {
            cvt_slice_dst_wsp_ = scratchpad.template get<wsp_data_t>(
                    memory_tracking::names::key_pool_dst_plain2blocked_cvt);
            cvt_slice_ind_wsp_ = scratchpad.template get<char>(
                    memory_tracking::names::key_pool_ind_plain2blocked_cvt);
        }
    }

    inline bool should_transpose_src() const noexcept { return transpose_src_; }
    inline bool should_transpose_dst() const noexcept { return transpose_dst_; }

    // Address of input row `ih` inside thread `ithr`'s converted src slice.
    const void *get_src_addr(
            std::size_t ithr, int ih, const jit_pool_conf_t &jpp) const {
        const wsp_data_t *const wsp = cvt_slice_src_wsp_ + ithr * src_slice_;
        return static_cast<const void *>(&wsp[ih * jpp.iw * jpp.c_block]);
    }

    // Address of output row `oh` inside thread `ithr`'s converted dst slice.
    const void *get_dst_addr(
            std::size_t ithr, int oh, const jit_pool_conf_t &jpp) const {
        const wsp_data_t *const wsp = cvt_slice_dst_wsp_ + ithr * dst_slice_;
        return static_cast<const void *>(&wsp[oh * jpp.ow * jpp.c_block]);
    }

    // Address of output row `oh` inside thread `ithr`'s converted indices
    // slice; offsets are in bytes (scaled by the index element size).
    const void *get_indices_addr(
            std::size_t ithr, int oh, const jit_pool_conf_t &jpp) const {
        const char *const wsp
                = cvt_slice_ind_wsp_ + ithr * dst_slice_ * ind_dt_size_;
        return static_cast<const void *>(
                &wsp[oh * jpp.ow * jpp.c_block * ind_dt_size_]);
    }

    // 3d variants of the helpers above: add a depth (id/od) component.
    const void *get_src_addr_3d(std::size_t ithr, int id, int ih,
            const jit_pool_conf_t &jpp) const {
        const wsp_data_t *const wsp = cvt_slice_src_wsp_ + ithr * src_slice_;
        return static_cast<const void *>(&wsp[ih * jpp.iw * jpp.c_block
                + id * jpp.ih * jpp.iw * jpp.c_block]);
    }

    const void *get_dst_addr_3d(std::size_t ithr, int od, int oh,
            const jit_pool_conf_t &jpp) const {
        const wsp_data_t *const wsp = cvt_slice_dst_wsp_ + ithr * dst_slice_;
        return static_cast<const void *>(&wsp[oh * jpp.ow * jpp.c_block
                + od * jpp.oh * jpp.ow * jpp.c_block]);
    }

    const void *get_indices_addr_3d(std::size_t ithr, int od, int oh,
            const jit_pool_conf_t &jpp) const {
        const char *const wsp
                = cvt_slice_ind_wsp_ + ithr * dst_slice_ * ind_dt_size_;
        return static_cast<const void *>(
                &wsp[oh * jpp.ow * jpp.c_block * ind_dt_size_
                        + od * jpp.oh * jpp.ow * jpp.c_block * ind_dt_size_]);
    }

    // Callbacks installed by the derived fwd/bwd facades; callers must
    // check should_transpose_src()/dst() before invoking.
    void execute_transpose_input(std::size_t ithr, int n, int b_c) const {
        execute_transpose_input_(ithr, n, b_c);
    }

    void execute_transpose_output(std::size_t ithr, int n, int b_c) const {
        execute_transpose_output_(ithr, n, b_c);
    }

protected:
    const dim_t src_sp_; // src spatial size: id * ih * iw
    const dim_t dst_sp_; // dst spatial size: od * oh * ow
    const dim_t src_slice_; // per-thread src scratch size: src_sp_ * c_block
    const dim_t dst_slice_; // per-thread dst scratch size: dst_sp_ * c_block

    const bool transpose_src_;
    const bool transpose_dst_;

    const memory_desc_wrapper &src_d_;
    const memory_desc_wrapper &dst_d_;
    const memory_desc_wrapper &indices_d_;
    const size_t ind_dt_size_; // 0 when there is no indices tensor

    // Scratchpad-backed conversion buffers (null when unused).
    wsp_data_t *__restrict cvt_slice_src_wsp_;
    wsp_data_t *__restrict cvt_slice_dst_wsp_;
    char *__restrict cvt_slice_ind_wsp_;

    std::function<void(std::size_t, int, int)> execute_transpose_input_;
    std::function<void(std::size_t, int, int)> execute_transpose_output_;
};
366
// Forward-pass transposition facade: installs the input callback (src:
// plain -> blocked scratch) and the output callback (dst and, if present,
// indices: blocked scratch -> plain) using the kernels owned by trans_ctx.
template <typename data_t, typename wsp_data_t, impl::data_type_t d_type>
class fwd_pooling_transpose_facade_t
    : public transpose_facade_base_t<wsp_data_t, d_type> {
public:
    fwd_pooling_transpose_facade_t(const jit_pool_conf_t &jpp,
            trans_context_t *trans_ctx, const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &dst_d,
            const memory_desc_wrapper &indices_d, const data_type_t wsp_dt,
            const data_t *src, data_t *dst, char *indices,
            const exec_ctx_t &ctx)
        : transpose_facade_base_t<wsp_data_t, d_type>(
                jpp, src_d, dst_d, indices_d, indices, wsp_dt, ctx) {

        // src is read by the kernel: convert the plain ncsp input into the
        // per-thread blocked scratch buffer.
        if (this->should_transpose_src()) {
            this->execute_transpose_input_
                    = transpose_ncsp_to_block_fmt_t<data_t, wsp_data_t>(
                            trans_ctx->src_trans_.get(),
                            trans_ctx->src_tail_trans_.get(), src, this->src_d_,
                            this->cvt_slice_src_wsp_, this->src_slice_, jpp);
        }

        // dst (and indices, when present) are produced by the kernel into
        // blocked scratch: convert them back to the plain layout. std::bind
        // constructs both functors once up front; the per-(n, b_c) call
        // then only forwards (ithr, n, b_c) via the placeholders.
        if (this->should_transpose_dst()) {
            using namespace std::placeholders;
            this->execute_transpose_output_ = std::bind(
                    [=](const transpose_block_fmt_to_ncsp_t<wsp_data_t, data_t>
                                    &trans_dst,
                            transpose_block_fmt_to_ncsp_t<char, char>
                                    &trans_indices,
                            std::size_t ithr, int n, int b_c) {
                        trans_dst(ithr, n, b_c);
                        // Indices exist only for max pooling with workspace.
                        if (indices) trans_indices(ithr, n, b_c);
                    },
                    transpose_block_fmt_to_ncsp_t<wsp_data_t, data_t>(
                            trans_ctx->dst_trans_.get(),
                            trans_ctx->dst_tail_trans_.get(),
                            this->cvt_slice_dst_wsp_, this->dst_slice_, dst,
                            this->dst_d_, jpp, 1u),
                    // Indices are copied bytewise (char), so offsets are
                    // scaled by the index element size.
                    transpose_block_fmt_to_ncsp_t<char, char>(
                            trans_ctx->ind_trans_.get(),
                            trans_ctx->ind_tail_trans_.get(),
                            this->cvt_slice_ind_wsp_, this->dst_slice_, indices,
                            this->indices_d_, jpp, this->ind_dt_size_),
                    _1, _2, _3);
        }
    }
};
413
// Backward-pass transposition facade. Note the role reversal vs forward:
// diff_dst (and indices) are *inputs* here and get the plain -> blocked
// conversion, while diff_src is the *output* and is converted blocked ->
// plain. Also provides zero-filling of the padded channel tail in the input
// scratch so the kernel does not consume stale scratchpad contents.
template <typename data_t, typename wsp_data_t, impl::data_type_t d_type>
class bwd_pooling_transpose_facade_t
    : public transpose_facade_base_t<wsp_data_t, d_type> {
public:
    bwd_pooling_transpose_facade_t(const jit_pool_conf_t &jpp,
            trans_context_t *trans_ctx, const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &dst_d,
            const memory_desc_wrapper &indices_d, const data_type_t wsp_dt,
            data_t *src, const data_t *dst, const char *indices,
            const exec_ctx_t &ctx)
        : transpose_facade_base_t<wsp_data_t, d_type>(
                jpp, src_d, dst_d, indices_d, indices, wsp_dt, ctx)
        , c_tail_(jpp.c_without_padding % jpp.c_block) {

        // diff_src is produced into blocked scratch by the kernel and then
        // transposed back to the plain layout (an "output" conversion).
        if (this->should_transpose_src())
            this->execute_transpose_output_
                    = transpose_block_fmt_to_ncsp_t<wsp_data_t, data_t>(
                            trans_ctx->src_trans_.get(),
                            trans_ctx->src_tail_trans_.get(),
                            this->cvt_slice_src_wsp_, this->src_slice_, src,
                            this->src_d_, jpp, 1u);

        if (this->should_transpose_dst()) {
            using namespace std::placeholders;

            // diff_dst (and indices) are consumed by the kernel: convert
            // plain -> blocked scratch. std::bind constructs both functors
            // once; per-(n, b_c) calls forward (ithr, n, b_c) only.
            this->execute_transpose_input_ = std::bind(
                    [=](const transpose_ncsp_to_block_fmt_t<data_t, wsp_data_t>
                                    &trans_dst,
                            transpose_ncsp_to_block_fmt_t<char, char>
                                    &trans_indices,
                            std::size_t ithr, int n, int b_c) {
                        trans_dst(ithr, n, b_c);
                        // Indices exist only for max pooling with workspace.
                        if (indices) trans_indices(ithr, n, b_c);
                    },
                    transpose_ncsp_to_block_fmt_t<data_t, wsp_data_t>(
                            trans_ctx->dst_trans_.get(),
                            trans_ctx->dst_tail_trans_.get(), dst, this->dst_d_,
                            this->cvt_slice_dst_wsp_, this->dst_slice_, jpp),
                    transpose_ncsp_to_block_fmt_t<char, char>(
                            trans_ctx->ind_trans_.get(),
                            trans_ctx->ind_tail_trans_.get(), indices,
                            this->indices_d_, this->cvt_slice_ind_wsp_,
                            this->dst_slice_, jpp, this->ind_dt_size_),
                    _1, _2, _3);
        }
    }

    // True when the transposed diff_dst scratch ends with a padded channel
    // tail that must be zeroed before the kernel runs.
    inline bool should_fill_input_c_tail_with_zeros() const noexcept {
        return this->should_transpose_dst() && c_tail_ != 0;
    }

    // Zeroes channels [c_tail_, c_block) of every spatial point in thread
    // `ithr`'s diff_dst scratch slice, and the same region of the indices
    // scratch slice byte by byte.
    void fill_input_c_tail_with_zeros(
            std::size_t ithr, const jit_pool_conf_t &jpp) const {

        wsp_data_t *__restrict wsp_ptr
                = this->cvt_slice_dst_wsp_ + ithr * this->dst_slice_;
        for_(dim_t s = 0; s < this->dst_sp_; s++)
        for (dim_t c = c_tail_; c < jpp.c_block; c++)
            wsp_ptr[s * jpp.c_block + c] = 0.f;

        char *__restrict ind_ptr = this->cvt_slice_ind_wsp_
                + ithr * this->dst_slice_ * this->ind_dt_size_;
        for_(dim_t s = 0; s < this->dst_sp_; s++)
        for_(dim_t c = c_tail_; c < jpp.c_block; c++)
        for (size_t i = 0; i < this->ind_dt_size_; i++)
            ind_ptr[(s * jpp.c_block + c) * this->ind_dt_size_ + i] = 0;
    }

private:
    const dim_t c_tail_; // jpp.c_without_padding % jpp.c_block
};
485
486} // namespace jit_uni_pooling_utils
487
// Constructor: members are created later - the kernel and the (optional)
// ncsp transposition context are both set up in init().
template <cpu_isa_t isa, impl::data_type_t d_type>
jit_uni_pooling_fwd_t<isa, d_type>::jit_uni_pooling_fwd_t(const pd_t *apd)
    : primitive_t(apd), kernel_(nullptr), trans_ctx_(nullptr) {}
491
492template <cpu_isa_t isa, impl::data_type_t d_type>
493status_t jit_uni_pooling_fwd_t<isa, d_type>::init(engine_t *engine) {
494
495 CHECK(safe_ptr_assign(kernel_,
496 new jit_uni_pool_kernel<isa>(
497 pd()->jpp_, pd()->invariant_dst_md())));
498
499 if (pd()->jpp_.tag_kind == jit_memory_tag_kind_t::ncsp)
500 CHECK(init_ncsp_trans_ctx());
501 return kernel_->create_kernel();
502}
503
// Creates the plain<->blocked transposition kernels used when src/dst are
// in ncsp layout. Full channel blocks and the channel tail get separate
// kernels; indices kernels are created only when a workspace (max-pooling
// indices) tensor exists.
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_pooling_fwd_t<isa, d_type>::init_ncsp_trans_ctx() {
    using namespace dnnl::impl;
    using namespace jit_uni_pooling_utils;

    const auto &jpp = pd()->jpp_;
    trans_ctx_ = utils::make_unique<trans_context_t>();
    const dim_t src_sp = static_cast<dim_t>(jpp.id) * jpp.ih * jpp.iw;
    const dim_t dst_sp = static_cast<dim_t>(jpp.od) * jpp.oh * jpp.ow;
    // One division yields both the full-block count and the tail size.
    const auto res = std::div(jpp.c_without_padding, jpp.c_block);
    const dim_t &nb_c = res.quot;
    const dim_t &c_tail = res.rem;
    const memory_desc_wrapper indices_d = pd()->workspace_md();
    const bool have_indices = indices_d.data_type() != data_type::undef;
    static constexpr auto wsp_dt = wsp_dt_;

    // trans_wrapper_t args: (inp_dt, inp_stride, out_dt, out_stride,
    // ysize, xsize).
    if (nb_c) {
        // src: plain (stride src_sp) -> blocked scratch (stride c_block).
        trans_ctx_->src_trans_ = utils::make_unique<trans_wrapper_t>(
                d_type, src_sp, wsp_dt, jpp.c_block, jpp.c_block, src_sp);
        // dst: blocked scratch (stride c_block) -> plain (stride dst_sp).
        trans_ctx_->dst_trans_ = utils::make_unique<trans_wrapper_t>(
                wsp_dt, jpp.c_block, d_type, dst_sp, dst_sp, jpp.c_block);
        if (have_indices)
            trans_ctx_->ind_trans_ = utils::make_unique<trans_wrapper_t>(
                    indices_d.data_type(), jpp.c_block, indices_d.data_type(),
                    dst_sp, dst_sp, jpp.c_block);
    }

    // Tail kernels cover only the last, partial channel block (c_tail
    // valid channels).
    if (c_tail) {
        trans_ctx_->src_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                d_type, src_sp, wsp_dt, jpp.c_block, c_tail, src_sp);
        trans_ctx_->dst_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                wsp_dt, jpp.c_block, d_type, dst_sp, dst_sp, c_tail);
        if (have_indices)
            trans_ctx_->ind_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                    indices_d.data_type(), jpp.c_block, indices_d.data_type(),
                    dst_sp, dst_sp, c_tail);
    }

    // Generate machine code for everything created above.
    return trans_ctx_->create_kernel();
}
544
// Defaulted out-of-line destructor (keeps destruction of the unique_ptr
// members in this translation unit).
template <cpu_isa_t isa, impl::data_type_t d_type>
jit_uni_pooling_fwd_t<isa, d_type>::~jit_uni_pooling_fwd_t() = default;
547
// Forward 2d pooling driver: prepares per-row kernel arguments and picks
// one of three parallelization schemes depending on the memory layout.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
        data_t *dst, char *indices, const exec_ctx_t &ctx) const {

    const memory_desc_wrapper src_d = pd()->src_md();
    const memory_desc_wrapper dst_d = pd()->dst_md();
    const memory_desc_wrapper indices_d = pd()->workspace_md();
    // Index (workspace) element size; 0 when there is no indices tensor.
    const auto ind_dt_size
            = indices ? types::data_type_size(indices_d.data_type()) : 0;
    const auto &jpp = pd()->jpp_;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jpp.post_ops, ctx);

    using wsp_data_t = typename prec_traits<wsp_dt_>::type;
    using namespace jit_uni_pooling_utils;

    // For ncsp layouts the facade converts src/dst (and indices) between
    // the plain layout and the blocked scratch the kernel works on.
    const auto transpose_facade
            = fwd_pooling_transpose_facade_t<data_t, wsp_data_t, d_type>(jpp,
                    trans_ctx_.get(), src_d, dst_d, indices_d, wsp_dt_, src,
                    dst, indices, ctx);

    const auto trans_src = transpose_facade.should_transpose_src();
    const auto trans_dst = transpose_facade.should_transpose_dst();

    // Runs the JIT kernel for one output row `oh` of image `n`, starting at
    // channel block `b_c` and covering `ur_bc` channel blocks.
    const auto ker = [&](std::size_t ithr, int n, int b_c, int oh, int ur_bc) {
        assert(ur_bc == jpp.ur_bc || ur_bc == jpp.ur_bc_tail);
        auto arg = jit_pool_call_s();

        // Input-row window for this output row, clipped by top/bottom pad.
        const int ij = oh * jpp.stride_h;
        const int i_t_overflow = nstl::max(0, jpp.t_pad - ij);
        const int i_b_overflow
                = nstl::max(jpp.ih, ij + jpp.kh - jpp.t_pad) - jpp.ih;
        const int ih = nstl::max(ij - jpp.t_pad, 0);
        assert(IMPLICATION(pd()->ndims() == 3, utils::everyone_is(0, ih, oh)));
        // Channel offset: nspc indexes channels by element, other layouts
        // by channel block.
        const int c_off
                = ((jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c_block
                                                                 : 1)
                * b_c;

        if (trans_src)
            arg.src = transpose_facade.get_src_addr(ithr, ih, jpp);
        else
            arg.src = static_cast<const void *>(
                    &src[src_d.blk_off(n, c_off, ih)]);

        arg.dst_orig = dst;
        if (trans_dst) {
            arg.dst = transpose_facade.get_dst_addr(ithr, oh, jpp);
            // tmp_md describes the temporary dst used with post-ops;
            // compute the matching helper address when it is set.
            if (!types::is_zero_md(&jpp.tmp_md)) {
                const memory_desc_wrapper tmp_d
                        = memory_desc_wrapper(jpp.tmp_md);
                // offset needs to be f32
                const int dt_scale
                        = sizeof(float) / types::data_type_size(d_type);
                const auto blk_off = tmp_d.blk_off(n, c_off, oh) * dt_scale;
                arg.dst_po_helper = static_cast<const void *>(&dst[blk_off]);
            }
        } else {
            arg.dst = static_cast<const void *>(
                    &dst[dst_d.blk_off(n, c_off, oh)]);
        }

        if (indices) {
            if (trans_dst)
                arg.indices = transpose_facade.get_indices_addr(ithr, oh, jpp);
            else {
                const size_t ind_off = indices_d.blk_off(n, c_off, oh);
                arg.indices = static_cast<const void *>(
                        &indices[ind_off * ind_dt_size]);
            }
        }
        // Effective kernel height after clipping, and the shift into the
        // kernel window where valid rows start.
        arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
        arg.kh_padding_shift = i_t_overflow * jpp.kw;
        // Number of kernel rows overlapping the real image (used as the
        // averaging divisor by avg pooling variants).
        arg.ker_area_h = static_cast<float>(jpp.kh
                - nstl::max(0, oh * jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih)
                - nstl::max(0, jpp.t_pad - oh * jpp.stride_h));
        arg.ur_bc = ur_bc;
        arg.b_c = b_c;
        arg.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
        (*kernel_)(&arg);
    };

    const int nthr = jpp.nthr;

    if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
        // nspc: parallelize over (mb, oh, channel super-blocks); no
        // transposition, so no per-thread scratch is needed (ithr == 0).
        const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
        parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](dim_t n, dim_t oh, dim_t b2_c) {
            const auto b_c = b2_c * jpp.ur_bc;
            const auto ur_bc = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c);
            ker(0, n, b_c, oh, ur_bc);
        });
    } else {
        if (trans_src || trans_dst) {
            // ncsp format: transpose the input into per-thread scratch, run
            // all output rows, then transpose the result back.
            parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
                    [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
                        if (trans_src)
                            transpose_facade.execute_transpose_input(
                                    ithr, n, b_c);
                        for (dim_t oh = 0; oh < jpp.oh; ++oh)
                            ker(ithr, n, b_c, oh, 1);
                        if (trans_dst)
                            transpose_facade.execute_transpose_output(
                                    ithr, n, b_c);
                    });
        } else {
            // nChw16c, nChw8c format: balance mb * nb_c * oh work items
            // across threads manually.
            parallel(nthr, [&](dim_t ithr, dim_t nthr) {
                dim_t work_amount = jpp.mb * jpp.nb_c * jpp.oh;
                if (ithr >= work_amount) return;

                dim_t start {0}, end {0};
                dim_t n {0}, b_c {0}, oh {0};

                balance211(work_amount, nthr, ithr, start, end);
                utils::nd_iterator_init(
                        start, n, jpp.mb, b_c, jpp.nb_c, oh, jpp.oh);

                for (dim_t iwork = start; iwork < end; ++iwork) {
                    ker(ithr, n, b_c, oh, 1);
                    utils::nd_iterator_step(
                            n, jpp.mb, b_c, jpp.nb_c, oh, jpp.oh);
                }
            });
        }
    }
}
675
// Forward 3d pooling driver: like execute_forward, but each kernel call
// additionally carries a depth position and its front/back pad clipping.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
        data_t *dst, char *indices, const exec_ctx_t &ctx) const {

    const auto &jpp = pd()->jpp_;
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper indices_d(pd()->workspace_md());
    // Index (workspace) element size; 0 when there is no indices tensor.
    const size_t ind_dt_size
            = indices ? types::data_type_size(indices_d.data_type()) : 0;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jpp.post_ops, ctx);

    using wsp_data_t = typename prec_traits<wsp_dt_>::type;
    using namespace jit_uni_pooling_utils;
    static constexpr int first_ithr = 0;

    // For ncsp layouts the facade converts src/dst (and indices) between
    // the plain layout and the blocked scratch the kernel works on.
    const auto transpose_facade
            = fwd_pooling_transpose_facade_t<data_t, wsp_data_t, d_type>(jpp,
                    trans_ctx_.get(), src_d, dst_d, indices_d, wsp_dt_, src,
                    dst, indices, ctx);

    const auto trans_src = transpose_facade.should_transpose_src();
    const auto trans_dst = transpose_facade.should_transpose_dst();

    // Runs the JIT kernel for one (od, oh) output position of image `n`,
    // channel block `b_c`. The caller supplies the depth window clipping
    // (id, d_t_overflow, d_b_overflow); height clipping is computed here.
    auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
                       int d_b_overflow, int ur_bc, int ithr) {
        assert(ur_bc == jpp.ur_bc || ur_bc == jpp.ur_bc_tail);
        auto arg = jit_pool_call_s();

        // Input-row window for this output row, clipped by top/bottom pad.
        const int ij = oh * jpp.stride_h;
        const int i_t_overflow = nstl::max(0, jpp.t_pad - ij);
        const int i_b_overflow
                = nstl::max(jpp.ih, ij + jpp.kh - jpp.t_pad) - jpp.ih;
        const int ih = nstl::max(ij - jpp.t_pad, 0);
        // Channel offset: nspc indexes channels by element, other layouts
        // by channel block.
        const int c_off
                = ((jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c_block
                                                                 : 1)
                * b_c;

        if (trans_src)
            arg.src = transpose_facade.get_src_addr_3d(ithr, id, ih, jpp);
        else
            arg.src = &src[src_d.blk_off(n, c_off, id, ih)];

        arg.dst_orig = dst;
        if (trans_dst) {
            arg.dst = transpose_facade.get_dst_addr_3d(ithr, od, oh, jpp);
            // tmp_md describes the temporary dst used with post-ops;
            // compute the matching helper address when it is set.
            if (!types::is_zero_md(&jpp.tmp_md)) {
                const memory_desc_wrapper tmp_d
                        = memory_desc_wrapper(jpp.tmp_md);
                // offset needs to be f32
                const int dt_scale
                        = sizeof(float) / types::data_type_size(d_type);
                const auto blk_off = tmp_d.blk_off(n, c_off, od, oh) * dt_scale;
                arg.dst_po_helper = static_cast<const void *>(&dst[blk_off]);
            }
        } else {
            arg.dst = &dst[dst_d.blk_off(n, c_off, od, oh)];
        }

        if (indices) {
            if (trans_dst) {
                arg.indices = transpose_facade.get_indices_addr_3d(
                        ithr, od, oh, jpp);
            } else {
                const size_t ind_off = indices_d.blk_off(n, c_off, od, oh);
                arg.indices = &indices[ind_off * ind_dt_size];
            }
        }

        // Effective kernel depth/height after pad clipping, plus the shifts
        // into the kernel window where valid elements start.
        arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
        arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
        arg.kh_padding_shift
                = i_t_overflow * jpp.kw + d_t_overflow * jpp.kw * jpp.kh;
        arg.kd_padding_shift = (i_t_overflow + i_b_overflow) * jpp.kw;
        // Overlap of the kernel with the real image in h and d; used as the
        // averaging divisor by avg pooling variants.
        arg.ker_area_h = (float)(jpp.kh
                                 - nstl::max(0,
                                         oh * jpp.stride_h - jpp.t_pad + jpp.kh
                                                 - jpp.ih)
                                 - nstl::max(0, jpp.t_pad - oh * jpp.stride_h))
                * (jpp.kd
                        - nstl::max(0,
                                od * jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id)
                        - nstl::max(0, jpp.f_pad - od * jpp.stride_d));

        arg.ur_bc = ur_bc;
        arg.b_c = b_c;
        arg.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
        (*kernel_)(&arg);
    };

    const int nthr = jpp.nthr;

    if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
        // nspc: parallelize over (mb, od, channel super-blocks); no
        // transposition, so no per-thread scratch is needed.
        const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
        parallel_nd(jpp.mb, jpp.od, nb2_c, [&](dim_t n, dim_t od, dim_t b2_c) {
            const dim_t b_c = b2_c * jpp.ur_bc;
            const dim_t ur_bc = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c);

            // Input-depth window for this output depth, clipped by
            // front/back pad.
            const dim_t ik = od * jpp.stride_d;
            const dim_t d_t_overflow = nstl::max(dim_t(0), jpp.f_pad - ik);
            const dim_t d_b_overflow
                    = nstl::max(dim_t(jpp.id), ik + jpp.kd - jpp.f_pad)
                    - jpp.id;
            const dim_t id = nstl::max(ik - jpp.f_pad, dim_t(0));
            for (dim_t oh = 0; oh < jpp.oh; ++oh) {
                ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, ur_bc,
                        first_ithr);
            }
        });
    } else {
        if (trans_src || trans_dst) {
            // ncsp format: transpose the input into per-thread scratch, run
            // the whole output volume, then transpose the result back.
            parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
                    [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
                        if (trans_src)
                            transpose_facade.execute_transpose_input(
                                    ithr, n, b_c);

                        for (int od = 0; od < jpp.od; ++od) {
                            const int ik = od * jpp.stride_d;
                            const int d_t_overflow
                                    = nstl::max(0, jpp.f_pad - ik);
                            const int d_b_overflow
                                    = nstl::max(jpp.id, ik + jpp.kd - jpp.f_pad)
                                    - jpp.id;
                            const int id = nstl::max(ik - jpp.f_pad, 0);
                            for (int oh = 0; oh < jpp.oh; ++oh) {
                                ker(n, b_c, od, oh, id, d_t_overflow,
                                        d_b_overflow, 1, ithr);
                            }
                        }

                        if (trans_dst)
                            transpose_facade.execute_transpose_output(
                                    ithr, n, b_c);
                    });
        } else {
            // Blocked formats: parallelize over (mb, nb_c, od) directly.
            parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
                    [&](dim_t n, dim_t b_c, dim_t od) {
                        const int ik = od * jpp.stride_d;
                        const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
                        const int d_b_overflow
                                = nstl::max(jpp.id, ik + jpp.kd - jpp.f_pad)
                                - jpp.id;
                        const int id = nstl::max(ik - jpp.f_pad, 0);
                        for (int oh = 0; oh < jpp.oh; ++oh) {
                            ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
                                    1, first_ithr);
                        }
                    });
        }
    }
}
830
// Constructor: allocates the JIT kernel object here (code generation is
// deferred to init()); the ncsp transposition context is created in init()
// only when the layout requires it.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_pooling_bwd_t<isa, d_type>::jit_uni_pooling_bwd_t(const pd_t *apd)
    : primitive_t(apd)
    , kernel_(utils::make_unique<jit_uni_pool_kernel<isa>>(
              pd()->jpp_, pd()->invariant_dst_md()))
    , trans_ctx_(nullptr) {}
837
// Defaulted out-of-line destructor (keeps destruction of the unique_ptr
// members in this translation unit).
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_pooling_bwd_t<isa, d_type>::~jit_uni_pooling_bwd_t() = default;
840
// Creates the plain<->blocked transposition kernels for the backward pass.
// Roles are reversed vs forward: diff_dst (and indices) are inputs here
// (plain -> blocked), while diff_src is the output (blocked -> plain).
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_pooling_bwd_t<isa, d_type>::init_ncsp_trans_ctx() {
    using namespace dnnl::impl;
    using namespace jit_uni_pooling_utils;

    const auto &jpp = pd()->jpp_;
    trans_ctx_ = utils::make_unique<trans_context_t>();
    const dim_t diff_src_sp = static_cast<dim_t>(jpp.id) * jpp.ih * jpp.iw;
    const dim_t diff_dst_sp = static_cast<dim_t>(jpp.od) * jpp.oh * jpp.ow;
    // One division yields both the full-block count and the tail size.
    const auto res = std::div(jpp.c_without_padding, jpp.c_block);
    const dim_t &nb_c = res.quot;
    const dim_t &c_tail = res.rem;
    const memory_desc_wrapper indices_d = pd()->workspace_md();
    const bool have_indices = indices_d.data_type() != data_type::undef;
    static constexpr auto wsp_dt = wsp_dt_;

    // trans_wrapper_t args: (inp_dt, inp_stride, out_dt, out_stride,
    // ysize, xsize).
    if (nb_c) {
        // diff_dst: plain (stride diff_dst_sp) -> blocked scratch.
        trans_ctx_->dst_trans_ = utils::make_unique<trans_wrapper_t>(d_type,
                diff_dst_sp, wsp_dt, jpp.c_block, jpp.c_block, diff_dst_sp);
        // diff_src: blocked scratch -> plain (stride diff_src_sp).
        trans_ctx_->src_trans_ = utils::make_unique<trans_wrapper_t>(wsp_dt,
                jpp.c_block, d_type, diff_src_sp, diff_src_sp, jpp.c_block);
        if (have_indices)
            trans_ctx_->ind_trans_ = utils::make_unique<trans_wrapper_t>(
                    indices_d.data_type(), diff_dst_sp, indices_d.data_type(),
                    jpp.c_block, jpp.c_block, diff_dst_sp);
    }
    // Tail kernels cover only the last, partial channel block (c_tail
    // valid channels).
    if (c_tail) {
        trans_ctx_->dst_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                d_type, diff_dst_sp, wsp_dt, jpp.c_block, c_tail, diff_dst_sp);
        trans_ctx_->src_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                wsp_dt, jpp.c_block, d_type, diff_src_sp, diff_src_sp, c_tail);
        if (have_indices)
            trans_ctx_->ind_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                    indices_d.data_type(), diff_dst_sp, indices_d.data_type(),
                    jpp.c_block, c_tail, diff_dst_sp);
    }

    // Generate machine code for everything created above.
    return trans_ctx_->create_kernel();
}
880
881template <cpu_isa_t isa, data_type_t d_type>
882status_t jit_uni_pooling_bwd_t<isa, d_type>::init(engine_t *engine) {
883 if (pd()->jpp_.tag_kind == jit_memory_tag_kind_t::ncsp)
884 CHECK(init_ncsp_trans_ctx());
885 return kernel_->create_kernel();
886}
887
888template <cpu_isa_t isa, data_type_t d_type>
889void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward(
890 const data_t *diff_dst, const char *indices, data_t *diff_src,
891 const exec_ctx_t &ctx) const {
892
893 using namespace jit_uni_pooling_utils;
894 using wsp_data_t = typename prec_traits<wsp_dt_>::type;
895
896 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
897 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
898 const memory_desc_wrapper indices_d(pd()->workspace_md());
899 const size_t ind_dt_size
900 = indices ? types::data_type_size(indices_d.data_type()) : 0;
901 const auto &jpp = pd()->jpp_;
902 const auto transpose_facade
903 = jit_uni_pooling_utils::bwd_pooling_transpose_facade_t<data_t,
904 wsp_data_t, d_type>(jpp, trans_ctx_.get(), diff_src_d,
905 diff_dst_d, indices_d, wsp_dt_, diff_src, diff_dst, indices,
906 ctx);
907
908 auto get_first_ih = [&](int oh) {
909 return nstl::min(nstl::max(oh * jpp.stride_h - jpp.t_pad, 0), jpp.ih);
910 };
911
912 auto get_last_ih = [&](int oh) {
913 return nstl::min(
914 nstl::max(oh * jpp.stride_h - jpp.t_pad + jpp.kh, 0), jpp.ih);
915 };
916 const auto ker = [&](int ithr, int n, int b_c, int oh, int ur_bc) {
917 auto arg = jit_pool_call_s();
918
919 const int ih = get_first_ih(oh);
920 assert(IMPLICATION(pd()->ndims() == 3, utils::everyone_is(0, ih, oh)));
921 assert(pd()->ndims() != 3 || utils::everyone_is(0, ih, oh));
922
923 const auto c_off = jpp.is_plain() ? b_c * jpp.c_block : b_c;
924 if (transpose_facade.should_transpose_src())
925 arg.src = transpose_facade.get_src_addr(ithr, ih, jpp);
926 else
927 arg.src = &diff_src[diff_src_d.blk_off(n, c_off, ih)];
928
929 if (transpose_facade.should_transpose_dst())
930 arg.dst = transpose_facade.get_dst_addr(ithr, oh, jpp);
931 else
932 arg.dst = &diff_dst[diff_dst_d.blk_off(n, c_off, oh)];
933
934 if (indices) {
935 if (transpose_facade.should_transpose_dst())
936 arg.indices = transpose_facade.get_indices_addr(ithr, oh, jpp);
937
938 else {
939 const size_t ind_off = indices_d.blk_off(n, c_off, oh);
940 arg.indices = &indices[ind_off * ind_dt_size];
941 }
942 }
943
944 const int zero_ih_start = (oh == 0) ? 0 : get_last_ih(oh - 1);
945 const int zero_ih_end = (oh == jpp.oh - 1) ? jpp.ih : get_last_ih(oh);
946
947 arg.zero_id = 1;
948 arg.zero_ih = zero_ih_end - zero_ih_start;
949 if (transpose_facade.should_transpose_src())
950 arg.zero_ptr
951 = transpose_facade.get_src_addr(ithr, zero_ih_start, jpp);
952 else
953 arg.zero_ptr
954 = &diff_src[diff_src_d.blk_off(n, c_off, zero_ih_start, 0)];
955
956 const int i_t_overflow = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
957 const int i_b_overflow
958 = nstl::max(jpp.ih, oh * jpp.stride_h + jpp.kh - jpp.t_pad)
959 - jpp.ih;
960 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
961 arg.kh_padding_shift = i_t_overflow * jpp.kw;
962 arg.ker_area_h = static_cast<float>(jpp.kh
963 - nstl::max(0, oh * jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih)
964 - nstl::max(0, jpp.t_pad - oh * jpp.stride_h));
965
966 arg.ur_bc = ur_bc;
967 arg.b_c = b_c;
968 (*kernel_)(&arg);
969 };
970
971 auto process_block = [&](int ithr, int n, int b_c, int ur_bc) {
972 if (transpose_facade.should_transpose_dst())
973 transpose_facade.execute_transpose_input(ithr, n, b_c);
974
975 for (int oh = 0; oh < jpp.oh; ++oh)
976 ker(ithr, n, b_c, oh, ur_bc);
977
978 if (transpose_facade.should_transpose_src())
979 transpose_facade.execute_transpose_output(ithr, n, b_c);
980 };
981
982 const int nthr = jpp.nthr;
983
984 parallel(nthr, [&](int ithr, int nthr) {
985 const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
986 const std::size_t work_amount
987 = static_cast<std::size_t>(jpp.mb) * nb2_c;
988 if (static_cast<std::size_t>(ithr) >= work_amount) return;
989
990 if (transpose_facade.should_fill_input_c_tail_with_zeros())
991 transpose_facade.fill_input_c_tail_with_zeros(ithr, jpp);
992
993 std::size_t start {0}, end {0};
994 balance211(work_amount, nthr, ithr, start, end);
995 int n {0}, b2_c {0};
996 utils::nd_iterator_init(start, n, jpp.mb, b2_c, nb2_c);
997 for (size_t iwork = start; iwork < end; ++iwork) {
998 const auto b_c = b2_c * jpp.ur_bc;
999 const auto ur_bc = nstl::min(jpp.ur_bc, jpp.nb_c - b_c);
1000
1001 process_block(ithr, n, b_c, ur_bc);
1002 utils::nd_iterator_step(n, jpp.mb, b2_c, nb2_c);
1003 }
1004 });
1005}
1006
// Backward 3D (D x H x W) pooling. Two strategies are used:
//  * jpp.simple_alg: a single kernel call per (n, b_c, od, oh) covers the
//    whole kernel depth (kd passed as 0) and zero-initializes its own slice
//    of diff_src;
//  * otherwise: diff_src is zeroed up front, then kernel calls accumulate
//    contributions one kernel-depth index (kd) at a time.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
        const data_t *diff_dst, const char *indices, data_t *diff_src,
        const exec_ctx_t &ctx) const {
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper indices_d(pd()->workspace_md());
    // Size of one workspace (indices) element; 0 when no indices are given.
    const size_t ind_dt_size
            = indices ? types::data_type_size(indices_d.data_type()) : 0;

    const auto &jpp = pd()->jpp_;

    using wsp_data_t = typename prec_traits<wsp_dt_>::type;
    using namespace jit_uni_pooling_utils;
    // Thread id used on paths that do not need per-thread scratch buffers.
    static constexpr int first_ithr = 0;

    // Facade hiding the optional ncsp<->blocked transposition of src/dst/
    // indices around the kernel calls.
    const auto transpose_facade
            = bwd_pooling_transpose_facade_t<data_t, wsp_data_t, d_type>(jpp,
                    trans_ctx_.get(), diff_src_d, diff_dst_d, indices_d,
                    wsp_dt_, diff_src, diff_dst, indices, ctx);

    const auto trans_src = transpose_facade.should_transpose_src();
    const auto trans_dst = transpose_facade.should_transpose_dst();

    // One-past-last input row covered by output row `oh`, clamped to [0, ih].
    auto get_last_ih = [&](int oh) {
        return nstl::min(
                nstl::max(oh * jpp.stride_h - jpp.t_pad + jpp.kh, 0), jpp.ih);
    };

    // One-past-last input plane covered by output plane `od`, clamped to
    // [0, id].
    auto get_last_id = [&](int od) {
        return nstl::min(
                nstl::max(od * jpp.stride_d - jpp.f_pad + jpp.kd, 0), jpp.id);
    };

    // Fill the kernel argument block for one (n, b_c, od, oh) unit and call
    // the JIT kernel. `zero_inp` selects whether this call also
    // zero-initializes the diff_src rows/planes it is responsible for.
    auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
                       int d_b_overflow, bool zero_inp, int kd, int ur_bc,
                       int ithr) {
        auto arg = jit_pool_call_s();

        // Height-dimension window overflow over the (padded) input.
        const int ij = oh * jpp.stride_h;
        const int i_t_overflow = nstl::max(0, jpp.t_pad - ij);
        const int i_b_overflow
                = nstl::max(jpp.ih, ij + jpp.kh - jpp.t_pad) - jpp.ih;
        const int ih = nstl::max(ij - jpp.t_pad, 0);
        const int c_off
                = ((jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c_block
                                                                 : 1)
                * b_c;

        if (trans_src)
            arg.src = transpose_facade.get_src_addr_3d(ithr, id + kd, ih, jpp);
        else
            arg.src = (const void *)&diff_src[diff_src_d.blk_off(
                    n, c_off, id + kd, ih)];

        if (trans_dst)
            arg.dst = transpose_facade.get_dst_addr_3d(ithr, od, oh, jpp);
        else
            arg.dst = (const void
                            *)&diff_dst[diff_dst_d.blk_off(n, c_off, od, oh)];

        if (indices) {
            if (trans_dst) {
                arg.indices = transpose_facade.get_indices_addr_3d(
                        ithr, od, oh, jpp);
            } else {
                const size_t ind_off = indices_d.blk_off(n, c_off, od, oh);
                arg.indices = (const void *)&indices[ind_off * ind_dt_size];
            }
        }

        if (zero_inp) {
            // Zero only the planes/rows not already zeroed by the previous
            // od/oh iteration; the last iteration also takes the tail.
            const int zero_id_start = (od == 0) ? 0 : get_last_id(od - 1);
            const int zero_id_end
                    = (od == jpp.od - 1) ? jpp.id : get_last_id(od);

            arg.zero_id = zero_id_end - zero_id_start;

            const int zero_ih_start = (oh == 0) ? 0 : get_last_ih(oh - 1);
            const int zero_ih_end
                    = (oh == jpp.oh - 1) ? jpp.ih : get_last_ih(oh);
            arg.zero_ih = zero_ih_end - zero_ih_start;

            if (trans_src)
                arg.zero_ptr = transpose_facade.get_src_addr_3d(
                        ithr, zero_id_start, zero_ih_start, jpp);
            else
                arg.zero_ptr = &diff_src[diff_src_d.blk_off(
                        n, c_off, zero_id_start, zero_ih_start, 0)];
        } else {
            // Accumulation path: zeroing was done up front, skip it here.
            arg.zero_id = 0;
            arg.zero_ih = 0;
        }

        arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
        arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
        arg.kh_padding_shift = i_t_overflow * jpp.kw
                + d_t_overflow * jpp.kw * jpp.kh + kd * jpp.kw * jpp.kh;
        arg.kd_padding_shift = (i_t_overflow + i_b_overflow) * jpp.kw;
        // Effective kernel area: (kh minus h-overflow) * (kd minus
        // d-overflow), i.e. padding excluded in both dimensions.
        arg.ker_area_h = (float)(jpp.kh
                - nstl::max(0,
                        oh * jpp.stride_h - jpp.t_pad + jpp.kh
                                - jpp.ih)
                - nstl::max(0, jpp.t_pad - oh * jpp.stride_h))
                * (jpp.kd
                        - nstl::max(0,
                                od * jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id)
                        - nstl::max(0, jpp.f_pad - od * jpp.stride_d));

        arg.ur_bc = ur_bc;
        arg.b_c = b_c;
        (*kernel_)(&arg);
    };

    // simple_alg path: one kernel call per (n, b_c, od, oh) covers the whole
    // kernel depth (kd == 0) and zeroes its own diff_src slice.
    auto process_simple = [&](int n, int b_c, int od, int ur_bc, int ithr) {
        const int ik = od * jpp.stride_d;
        const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
        const int d_b_overflow
                = nstl::max(jpp.id, ik + jpp.kd - jpp.f_pad) - jpp.id;
        const int id = nstl::max(ik - jpp.f_pad, 0);

        for (int oh = 0; oh < jpp.oh; ++oh) {
            ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, true, 0, ur_bc,
                    ithr);
        }
    };

    const int nthr = jpp.nthr;

    if (jpp.simple_alg) {
        if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
            // nspc: parallelize over (mb, od, channel-super-block).
            const dim_t nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
            parallel_nd(
                    jpp.mb, jpp.od, nb2_c, [&](dim_t n, dim_t od, dim_t b2_c) {
                        const dim_t b_c = b2_c * jpp.ur_bc;
                        // The last super-block may hold fewer than ur_bc
                        // blocks.
                        const dim_t ur_bc
                                = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c);
                        process_simple(n, b_c, od, ur_bc, first_ithr);
                    });
        } else {
            assert(jpp.ur_bc == 1);
            if (trans_src || trans_dst) {
                // Per-thread transposition buffers require the real thread
                // id, hence parallel_nd_ext.
                parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
                        [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
                            if (trans_src)
                                transpose_facade.execute_transpose_input(
                                        ithr, n, b_c);
                            for (int od = 0; od < jpp.od; ++od) {
                                process_simple(n, b_c, od, 1, ithr);
                            }
                            if (trans_dst)
                                transpose_facade.execute_transpose_output(
                                        ithr, n, b_c);
                        });
            } else {
                parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
                        [&](dim_t n, dim_t b_c, dim_t od) {
                            process_simple(n, b_c, od, 1, first_ithr);
                        });
            }
        }
    } else {
        // Accumulation path: zero-initialize diff_src first (unless the
        // kernel works on a transposed scratchpad, zeroed further below).
        const data_t zero_val = 0;
        if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
            const size_t chunk_size = (size_t)jpp.ih * jpp.iw * jpp.c;
            parallel_nd(jpp.mb, jpp.id, [&](dim_t n, dim_t id) {
                const size_t offset = ((size_t)n * jpp.id + id) * chunk_size;
                PRAGMA_OMP_SIMD()
                for (size_t idx = 0; idx < chunk_size; ++idx)
                    diff_src[offset + idx] = zero_val;
            });
        } else {
            if (!trans_src) {
                const size_t chunk_size
                        = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block;
                parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
                        [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
                            const size_t offset
                                    = ((size_t)n * jpp.nb_c + b_c) * chunk_size;
                            PRAGMA_OMP_SIMD()
                            for (size_t idx = 0; idx < chunk_size; ++idx)
                                diff_src[offset + idx] = zero_val;
                        });
            }
        }

        const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
        if (trans_src || trans_dst) {
            parallel_nd_ext(nthr, jpp.mb, nb2_c,
                    [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b2_c) {
                        const dim_t b_c = b2_c * jpp.ur_bc;

                        if (trans_dst) {
                            transpose_facade.execute_transpose_input(
                                    ithr, n, b_c);

                            size_t block_size = jpp.c_block * jpp.id * jpp.ih
                                    * jpp.iw * jpp.dt_size;

                            const void *src = transpose_facade.get_src_addr_3d(
                                    ithr, 0, 0, jpp);
                            // zero_val is 0, so the data_t -> int conversion
                            // memset performs is exact here.
                            std::memset((void *)src, zero_val, block_size);
                        }

                        for (dim_t kd = 0; kd < jpp.kd; ++kd) {
                            const dim_t ur_bc = nstl::min(
                                    dim_t(jpp.ur_bc), jpp.nb_c - b_c);
                            for (int od = 0; od < jpp.od; ++od) {
                                const dim_t ik = od * jpp.stride_d;
                                const dim_t d_t_overflow
                                        = nstl::max(dim_t(0), jpp.f_pad - ik);
                                const dim_t d_b_overflow
                                        = nstl::max(dim_t(jpp.id),
                                                  ik + jpp.kd - jpp.f_pad)
                                        - jpp.id;
                                // Skip kd positions fully inside the padding.
                                if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
                                    continue;
                                const dim_t id
                                        = nstl::max(ik - jpp.f_pad, dim_t(0));
                                for (dim_t oh = 0; oh < jpp.oh; ++oh) {
                                    ker(n, b_c, od, oh, id, d_t_overflow,
                                            d_b_overflow, false, kd, ur_bc,
                                            ithr);
                                }
                            }
                        }

                        if (trans_src)
                            transpose_facade.execute_transpose_output(
                                    ithr, n, b_c);
                    });
        } else {
            // No transposition: iterate kd serially, parallelize each pass
            // over (mb, channel-super-block).
            for (dim_t kd = 0; kd < jpp.kd; ++kd) {
                parallel_nd(jpp.mb, nb2_c, [&](dim_t n, dim_t b2_c) {
                    const dim_t b_c = b2_c * jpp.ur_bc;
                    const dim_t ur_bc
                            = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c);
                    for (int od = 0; od < jpp.od; ++od) {
                        const dim_t ik = od * jpp.stride_d;
                        const dim_t d_t_overflow
                                = nstl::max(dim_t(0), jpp.f_pad - ik);
                        const dim_t d_b_overflow
                                = nstl::max(dim_t(jpp.id),
                                        ik + jpp.kd - jpp.f_pad)
                                - jpp.id;
                        if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
                            continue;
                        const dim_t id = nstl::max(ik - jpp.f_pad, dim_t(0));
                        for (dim_t oh = 0; oh < jpp.oh; ++oh) {
                            ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
                                    false, kd, ur_bc, first_ithr);
                        }
                    }
                });
            }
        }
    }
}
1265
// Explicit instantiations for every supported ISA / data-type combination.
template struct jit_uni_pooling_fwd_t<sse41, data_type::f32>;
template struct jit_uni_pooling_bwd_t<sse41, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx, data_type::f32>;
template struct jit_uni_pooling_bwd_t<avx, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx2, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx2_vnni_2, data_type::bf16>;
template struct jit_uni_pooling_fwd_t<avx2_vnni_2, data_type::f16>;
template struct jit_uni_pooling_bwd_t<avx2, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx512_core, data_type::f32>;
template struct jit_uni_pooling_bwd_t<avx512_core, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx512_core, data_type::bf16>;
template struct jit_uni_pooling_bwd_t<avx512_core, data_type::bf16>;
template struct jit_uni_pooling_fwd_t<avx512_core_fp16, data_type::f16>;
template struct jit_uni_pooling_bwd_t<avx512_core_fp16, data_type::f16>;
1280
1281} // namespace x64
1282} // namespace cpu
1283} // namespace impl
1284} // namespace dnnl
1285
1286// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
1287