1 | /******************************************************************************* |
2 | * Copyright 2017 - 2022 Intel Corporation |
3 | * Licensed under the Apache License, Version 2.0 (the "License"); |
4 | * you may not use this file except in compliance with the License. |
5 | * You may obtain a copy of the License at |
6 | * |
7 | * http://www.apache.org/licenses/LICENSE-2.0 |
8 | * |
9 | * Unless required by applicable law or agreed to in writing, software |
10 | * distributed under the License is distributed on an "AS IS" BASIS, |
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | * See the License for the specific language governing permissions and |
13 | * limitations under the License. |
14 | *******************************************************************************/ |
15 | |
16 | #include <functional> |
17 | #include <new> |
18 | |
19 | #include "oneapi/dnnl/dnnl_types.h" |
20 | |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/nstl.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | |
26 | #include "cpu/x64/jit_uni_pooling.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | namespace jit_uni_pooling_utils { |
34 | |
// Owns the jit transposition kernels (tr::kernel_t) needed to move a 2D
// (ysize x xsize) slab between two layouts with different row strides.
// The slab is processed as 8x8 tiles: ker_ handles full tiles,
// ker_x_tail_ the partial rightmost columns (x_tail_ wide) and
// ker_y_tail_ the partial bottom rows (y_tail_ high, full xsize wide).
struct trans_wrapper_t {
    // @param inp_dt input data type
    // @param inp_str input row stride, in elements
    // @param out_dt output data type
    // @param out_str output row stride, in elements
    // @param ysize number of rows of the slab
    // @param xsize number of columns of the slab
    trans_wrapper_t(data_type_t inp_dt, dim_t inp_str, data_type_t out_dt,
            dim_t out_str, dim_t ysize, dim_t xsize)
        : inp_dt_size_(types::data_type_size(inp_dt))
        , out_dt_size_(types::data_type_size(out_dt))
        , inp_str_(inp_str)
        , out_str_(out_str)
        , nb_x_(xsize / 8)
        , nb_y_(ysize / 8)
        , x_tail_(xsize % 8)
        , y_tail_(ysize % 8) {
        using namespace cpu::x64::tr;

        // Builds a 2-node (2D) transposition problem descriptor for a
        // (ys x xs) tile with the given per-dimension input/output strides
        // and creates the corresponding jit kernel (not yet code-generated;
        // see create_kernel()).
        auto create_ker = [=](dim_t ys, dim_t y_inp_str, dim_t y_out_str,
                                  dim_t xs, dim_t x_inp_str, dim_t x_out_str) {
            tr::prb_t prb;
            kernel_t::desc_t desc;

            prb.ndims = 2;
            prb.ioff = 0;
            prb.ooff = 0;
            prb.src_scale_type = scale_type_t::NONE;
            prb.dst_scale_type = scale_type_t::NONE;
            prb.beta = 0;
            prb.nodes[0].ss = prb.nodes[1].ss = 1;

            prb.itype = inp_dt;
            prb.otype = out_dt;

            // Node 0: the y dimension of the tile.
            prb.nodes[0].n = ys;
            prb.nodes[0].is = y_inp_str;
            prb.nodes[0].os = y_out_str;

            // Node 1: the x dimension of the tile.
            prb.nodes[1].n = xs;
            prb.nodes[1].is = x_inp_str;
            prb.nodes[1].os = x_out_str;

            prb.full_ndims = prb.ndims;

            kernel_t::desc_init(desc, prb, 2);
            return kernel_t::create(desc);
        };

        // Full 8x8 tiles exist only when both dimensions have at least one
        // complete block.
        if (nb_x_ * nb_y_ > 0)
            ker_.reset(create_ker(8, inp_str_, 1, 8, 1, out_str_));

        if (x_tail_)
            ker_x_tail_.reset(create_ker(8, inp_str_, 1, x_tail_, 1, out_str_));

        // The y-tail kernel covers the whole width (xsize) of the bottom rows.
        if (y_tail_)
            ker_y_tail_.reset(
                    create_ker(y_tail_, inp_str_, 1, xsize, 1, out_str_));
    }

    // Generates code for every kernel that was instantiated in the ctor.
    status_t create_kernel() {
        if (ker_) CHECK(ker_->create_kernel());
        if (ker_x_tail_) CHECK(ker_x_tail_->create_kernel());
        if (ker_y_tail_) CHECK(ker_y_tail_->create_kernel());
        return status::success;
    }

    // Transposes the whole slab from inp to out, tile by tile.
    void exec(const void *inp, void *out) {
        dim_t x_blocked = nb_x_ * 8;
        dim_t y_blocked = nb_y_ * 8;

        // Runs one tile kernel at the given (row, col) element coordinates.
        // Note the callers pass swapped coordinates for the output side —
        // that is what makes the copy a transposition.
        auto call_ker = [&](tr::kernel_t &ker, dim_t inp_y, dim_t inp_x,
                                dim_t out_y, dim_t out_x) {
            tr::call_param_t cp;
            cp.src_scales = nullptr;
            cp.dst_scales = nullptr;

            dim_t inp_off = (inp_y * inp_str_ + inp_x) * inp_dt_size_;
            dim_t out_off = (out_y * out_str_ + out_x) * out_dt_size_;
            cp.in = (uint8_t *)inp + inp_off;
            cp.out = (uint8_t *)out + out_off;
            (ker)(&cp);
        };

        for (dim_t by = 0; by < nb_y_; by++) {
            // Full 8x8 tiles of this row band.
            for (dim_t bx = 0; bx < nb_x_; bx++)
                call_ker(*ker_, 8 * by, 8 * bx, 8 * bx, 8 * by);

            // Partial columns at the right edge of this row band.
            if (x_tail_)
                call_ker(*ker_x_tail_, 8 * by, x_blocked, x_blocked, 8 * by);
        }
        // Partial rows at the bottom, spanning the full width.
        if (y_tail_) call_ker(*ker_y_tail_, y_blocked, 0, 0, y_blocked);
    }

    ~trans_wrapper_t() = default;

private:
    std::unique_ptr<tr::kernel_t> ker_;
    std::unique_ptr<tr::kernel_t> ker_x_tail_;
    std::unique_ptr<tr::kernel_t> ker_y_tail_;

    const size_t inp_dt_size_;
    const size_t out_dt_size_;

    const dim_t inp_str_;
    const dim_t out_str_;
    const dim_t nb_x_;
    const dim_t nb_y_;
    const dim_t x_tail_;
    const dim_t y_tail_;
};
140 | |
141 | struct trans_context_t { |
142 | std::unique_ptr<trans_wrapper_t> src_trans_ = nullptr; |
143 | std::unique_ptr<trans_wrapper_t> src_tail_trans_ = nullptr; |
144 | std::unique_ptr<trans_wrapper_t> ind_trans_ = nullptr; |
145 | std::unique_ptr<trans_wrapper_t> ind_tail_trans_ = nullptr; |
146 | std::unique_ptr<trans_wrapper_t> dst_trans_ = nullptr; |
147 | std::unique_ptr<trans_wrapper_t> dst_tail_trans_ = nullptr; |
148 | |
149 | // NOLINTNEXTLINE(readability-make-member-function-const) |
150 | status_t create_kernel() { |
151 | if (src_trans_) CHECK(src_trans_->create_kernel()); |
152 | if (src_tail_trans_) CHECK(src_tail_trans_->create_kernel()); |
153 | if (ind_trans_) CHECK(ind_trans_->create_kernel()); |
154 | if (ind_tail_trans_) CHECK(ind_tail_trans_->create_kernel()); |
155 | if (dst_trans_) CHECK(dst_trans_->create_kernel()); |
156 | if (dst_tail_trans_) CHECK(dst_tail_trans_->create_kernel()); |
157 | return status::success; |
158 | } |
159 | }; |
160 | |
161 | static void trans_exec(trans_wrapper_t *trans, trans_wrapper_t *trans_tail, |
162 | dim_t cs, const void *inp, void *out, dim_t c_block) { |
163 | |
164 | if (cs == c_block) |
165 | trans->exec(inp, out); |
166 | else |
167 | trans_tail->exec(inp, out); |
168 | }; |
169 | |
170 | template <typename src_data_t, typename dst_data_t> |
171 | struct transpose_ncsp_to_block_fmt_t { |
172 | transpose_ncsp_to_block_fmt_t(trans_wrapper_t *transposer, |
173 | trans_wrapper_t *transposer_tail, const src_data_t *src_nscp_base, |
174 | const memory_desc_wrapper &src_nscp_desc, |
175 | dst_data_t *__restrict dst_blocked_base, dim_t block_size, |
176 | const jit_pool_conf_t &jpp, std::size_t offset_multiplier = 1u) |
177 | : transposer_(transposer) |
178 | , transposer_tail_(transposer_tail) |
179 | , c_without_padding_(jpp.c_without_padding) |
180 | , c_block_(jpp.c_block) |
181 | , src_nscp_base_(src_nscp_base) |
182 | , src_nscp_desc_(src_nscp_desc) |
183 | , dst_blocked_base_(dst_blocked_base) |
184 | , block_size_(block_size) |
185 | , offset_multiplier_(offset_multiplier) {} |
186 | |
187 | void operator()(std::size_t ithr, int n, int b_c) const { |
188 | const dim_t cs |
189 | = nstl::min(c_without_padding_ - b_c * c_block_, c_block_); |
190 | const src_data_t *src_nscp = src_nscp_base_ |
191 | + src_nscp_desc_.blk_off(n, b_c * c_block_, 0) |
192 | * offset_multiplier_; |
193 | dst_data_t *dst_blocked |
194 | = dst_blocked_base_ + ithr * block_size_ * offset_multiplier_; |
195 | trans_exec(transposer_, transposer_tail_, cs, src_nscp, dst_blocked, |
196 | c_block_); |
197 | } |
198 | |
199 | private: |
200 | trans_wrapper_t *transposer_; |
201 | trans_wrapper_t *transposer_tail_; |
202 | const int c_without_padding_; |
203 | const int c_block_; |
204 | const src_data_t *src_nscp_base_; |
205 | const memory_desc_wrapper &src_nscp_desc_; |
206 | dst_data_t *__restrict dst_blocked_base_; |
207 | const dim_t block_size_; |
208 | std::size_t offset_multiplier_; |
209 | }; |
210 | |
211 | template <typename src_data_t, typename dst_data_t> |
212 | struct transpose_block_fmt_to_ncsp_t { |
213 | |
214 | transpose_block_fmt_to_ncsp_t(trans_wrapper_t *transposer, |
215 | trans_wrapper_t *transposer_tail, |
216 | const src_data_t *__restrict src_blocked_base, dim_t block_size, |
217 | dst_data_t *dst_ncsp_base, const memory_desc_wrapper &dst_nscp_desc, |
218 | const jit_pool_conf_t &jpp, std::size_t offset_multiplier = 1u) |
219 | : transposer_(transposer) |
220 | , transposer_tail_(transposer_tail) |
221 | , c_without_padding_(jpp.c_without_padding) |
222 | , c_block_(jpp.c_block) |
223 | , src_blocked_base_(src_blocked_base) |
224 | , block_size_(block_size) |
225 | , dst_ncsp_base_(dst_ncsp_base) |
226 | , dst_nscp_desc_(dst_nscp_desc) |
227 | , offset_multiplier_(offset_multiplier) {} |
228 | |
229 | void operator()(std::size_t ithr, int n, int b_c) const { |
230 | const dim_t cs |
231 | = nstl::min(c_without_padding_ - b_c * c_block_, c_block_); |
232 | const src_data_t *src_blocked |
233 | = src_blocked_base_ + ithr * block_size_ * offset_multiplier_; |
234 | dst_data_t *dst_ncsp = dst_ncsp_base_ |
235 | + dst_nscp_desc_.blk_off(n, b_c * c_block_, 0) |
236 | * offset_multiplier_; |
237 | trans_exec(transposer_, transposer_tail_, cs, src_blocked, dst_ncsp, |
238 | c_block_); |
239 | } |
240 | |
241 | private: |
242 | trans_wrapper_t *transposer_; |
243 | trans_wrapper_t *transposer_tail_; |
244 | const int c_without_padding_; |
245 | const int c_block_; |
246 | const src_data_t *__restrict src_blocked_base_; |
247 | const dim_t block_size_; |
248 | dst_data_t *dst_ncsp_base_; |
249 | const memory_desc_wrapper &dst_nscp_desc_; |
250 | std::size_t offset_multiplier_; |
251 | }; |
252 | |
// Common state for the fwd/bwd transposition facades.
//
// For plain (ncsp) layouts the pooling kernel still operates on blocked
// data, so per-thread scratchpad slices serve as staging buffers:
//   - cvt_slice_src_wsp_: blocked staging for one source channel block
//   - cvt_slice_dst_wsp_: blocked staging for one destination channel block
//   - cvt_slice_ind_wsp_: blocked staging for the indices (workspace) block
// Derived classes install the concrete conversion functors into
// execute_transpose_input_ / execute_transpose_output_.
template <typename wsp_data_t, impl::data_type_t d_type>
class transpose_facade_base_t {
public:
    transpose_facade_base_t(const jit_pool_conf_t &jpp,
            const memory_desc_wrapper &src_d, const memory_desc_wrapper &dst_d,
            const memory_desc_wrapper &indices_d, const char *indices,
            const data_type_t wsp_dt, const exec_ctx_t &ctx)
        : src_sp_(static_cast<dim_t>(jpp.id) * jpp.ih * jpp.iw)
        , dst_sp_(static_cast<dim_t>(jpp.od) * jpp.oh * jpp.ow)
        , src_slice_(src_sp_ * jpp.c_block)
        , dst_slice_(dst_sp_ * jpp.c_block)
        , transpose_src_(jpp.tag_kind == jit_memory_tag_kind_t::ncsp)
        , transpose_dst_(jpp.tag_kind == jit_memory_tag_kind_t::ncsp)
        , src_d_(src_d)
        , dst_d_(dst_d)
        , indices_d_(indices_d)
        , ind_dt_size_(
                  indices ? types::data_type_size(indices_d_.data_type()) : 0)
        , cvt_slice_src_wsp_(nullptr)
        , cvt_slice_dst_wsp_(nullptr)
        , cvt_slice_ind_wsp_(nullptr)
        , execute_transpose_input_(nullptr)
        , execute_transpose_output_(nullptr) {

        auto scratchpad = ctx.get_scratchpad_grantor();

        // Staging buffers are only allocated/fetched when the layout
        // actually requires transposition (ncsp).
        if (transpose_src_)
            cvt_slice_src_wsp_ = scratchpad.template get<wsp_data_t>(
                    memory_tracking::names::key_pool_src_plain2blocked_cvt);

        if (transpose_dst_) {
            cvt_slice_dst_wsp_ = scratchpad.template get<wsp_data_t>(
                    memory_tracking::names::key_pool_dst_plain2blocked_cvt);
            cvt_slice_ind_wsp_ = scratchpad.template get<char>(
                    memory_tracking::names::key_pool_ind_plain2blocked_cvt);
        }
    }

    inline bool should_transpose_src() const noexcept { return transpose_src_; }
    inline bool should_transpose_dst() const noexcept { return transpose_dst_; }

    // Address of row ih inside thread ithr's blocked source staging slice.
    const void *get_src_addr(
            std::size_t ithr, int ih, const jit_pool_conf_t &jpp) const {
        const wsp_data_t *const wsp = cvt_slice_src_wsp_ + ithr * src_slice_;
        return static_cast<const void *>(&wsp[ih * jpp.iw * jpp.c_block]);
    }

    // Address of row oh inside thread ithr's blocked destination slice.
    const void *get_dst_addr(
            std::size_t ithr, int oh, const jit_pool_conf_t &jpp) const {
        const wsp_data_t *const wsp = cvt_slice_dst_wsp_ + ithr * dst_slice_;
        return static_cast<const void *>(&wsp[oh * jpp.ow * jpp.c_block]);
    }

    // Address of row oh inside thread ithr's indices staging slice
    // (byte-addressed, hence the ind_dt_size_ scaling).
    const void *get_indices_addr(
            std::size_t ithr, int oh, const jit_pool_conf_t &jpp) const {
        const char *const wsp
                = cvt_slice_ind_wsp_ + ithr * dst_slice_ * ind_dt_size_;
        return static_cast<const void *>(
                &wsp[oh * jpp.ow * jpp.c_block * ind_dt_size_]);
    }

    // 3D variant of get_src_addr: adds the depth (id) plane offset.
    const void *get_src_addr_3d(std::size_t ithr, int id, int ih,
            const jit_pool_conf_t &jpp) const {
        const wsp_data_t *const wsp = cvt_slice_src_wsp_ + ithr * src_slice_;
        return static_cast<const void *>(&wsp[ih * jpp.iw * jpp.c_block
                + id * jpp.ih * jpp.iw * jpp.c_block]);
    }

    // 3D variant of get_dst_addr: adds the depth (od) plane offset.
    const void *get_dst_addr_3d(std::size_t ithr, int od, int oh,
            const jit_pool_conf_t &jpp) const {
        const wsp_data_t *const wsp = cvt_slice_dst_wsp_ + ithr * dst_slice_;
        return static_cast<const void *>(&wsp[oh * jpp.ow * jpp.c_block
                + od * jpp.oh * jpp.ow * jpp.c_block]);
    }

    // 3D variant of get_indices_addr: adds the depth (od) plane offset.
    const void *get_indices_addr_3d(std::size_t ithr, int od, int oh,
            const jit_pool_conf_t &jpp) const {
        const char *const wsp
                = cvt_slice_ind_wsp_ + ithr * dst_slice_ * ind_dt_size_;
        return static_cast<const void *>(
                &wsp[oh * jpp.ow * jpp.c_block * ind_dt_size_
                        + od * jpp.oh * jpp.ow * jpp.c_block * ind_dt_size_]);
    }

    // Runs the derived-class-installed input conversion for (n, b_c).
    void execute_transpose_input(std::size_t ithr, int n, int b_c) const {
        execute_transpose_input_(ithr, n, b_c);
    }

    // Runs the derived-class-installed output conversion for (n, b_c).
    void execute_transpose_output(std::size_t ithr, int n, int b_c) const {
        execute_transpose_output_(ithr, n, b_c);
    }

protected:
    const dim_t src_sp_; // source spatial size (id * ih * iw)
    const dim_t dst_sp_; // destination spatial size (od * oh * ow)
    const dim_t src_slice_; // per-thread source staging size, in elements
    const dim_t dst_slice_; // per-thread destination staging size, in elements

    const bool transpose_src_;
    const bool transpose_dst_;

    const memory_desc_wrapper &src_d_;
    const memory_desc_wrapper &dst_d_;
    const memory_desc_wrapper &indices_d_;
    const size_t ind_dt_size_; // 0 when no indices buffer is present

    wsp_data_t *__restrict cvt_slice_src_wsp_;
    wsp_data_t *__restrict cvt_slice_dst_wsp_;
    char *__restrict cvt_slice_ind_wsp_;

    std::function<void(std::size_t, int, int)> execute_transpose_input_;
    std::function<void(std::size_t, int, int)> execute_transpose_output_;
};
366 | |
// Forward-pass facade: wires up
//   - input conversion: plain ncsp src -> blocked staging buffer, and
//   - output conversion: blocked staging dst (and indices, when present)
//     -> plain ncsp tensors.
// Both are only installed when the layout is ncsp (should_transpose_*).
template <typename data_t, typename wsp_data_t, impl::data_type_t d_type>
class fwd_pooling_transpose_facade_t
    : public transpose_facade_base_t<wsp_data_t, d_type> {
public:
    fwd_pooling_transpose_facade_t(const jit_pool_conf_t &jpp,
            trans_context_t *trans_ctx, const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &dst_d,
            const memory_desc_wrapper &indices_d, const data_type_t wsp_dt,
            const data_t *src, data_t *dst, char *indices,
            const exec_ctx_t &ctx)
        : transpose_facade_base_t<wsp_data_t, d_type>(
                jpp, src_d, dst_d, indices_d, indices, wsp_dt, ctx) {

        if (this->should_transpose_src()) {
            this->execute_transpose_input_
                    = transpose_ncsp_to_block_fmt_t<data_t, wsp_data_t>(
                            trans_ctx->src_trans_.get(),
                            trans_ctx->src_tail_trans_.get(), src, this->src_d_,
                            this->cvt_slice_src_wsp_, this->src_slice_, jpp);
        }

        if (this->should_transpose_dst()) {
            using namespace std::placeholders;
            // std::bind stores copies of the two conversion functors and
            // forwards the (ithr, n, b_c) call arguments to them; the
            // indices pointer is captured by value in the lambda, so the
            // indices conversion is skipped when no workspace was requested.
            this->execute_transpose_output_ = std::bind(
                    [=](const transpose_block_fmt_to_ncsp_t<wsp_data_t, data_t>
                                    &trans_dst,
                            transpose_block_fmt_to_ncsp_t<char, char>
                                    &trans_indices,
                            std::size_t ithr, int n, int b_c) {
                        trans_dst(ithr, n, b_c);
                        if (indices) trans_indices(ithr, n, b_c);
                    },
                    transpose_block_fmt_to_ncsp_t<wsp_data_t, data_t>(
                            trans_ctx->dst_trans_.get(),
                            trans_ctx->dst_tail_trans_.get(),
                            this->cvt_slice_dst_wsp_, this->dst_slice_, dst,
                            this->dst_d_, jpp, 1u),
                    // Indices are moved as raw bytes; offsets are scaled by
                    // the index element size (offset_multiplier).
                    transpose_block_fmt_to_ncsp_t<char, char>(
                            trans_ctx->ind_trans_.get(),
                            trans_ctx->ind_tail_trans_.get(),
                            this->cvt_slice_ind_wsp_, this->dst_slice_, indices,
                            this->indices_d_, jpp, this->ind_dt_size_),
                    _1, _2, _3);
        }
    }
};
413 | |
// Backward-pass facade. Note the direction swap relative to fwd:
//   - execute_transpose_input_ converts plain diff_dst (and indices) into
//     the blocked staging buffers, and
//   - execute_transpose_output_ converts the blocked diff_src staging
//     buffer back to the plain ncsp diff_src tensor.
template <typename data_t, typename wsp_data_t, impl::data_type_t d_type>
class bwd_pooling_transpose_facade_t
    : public transpose_facade_base_t<wsp_data_t, d_type> {
public:
    bwd_pooling_transpose_facade_t(const jit_pool_conf_t &jpp,
            trans_context_t *trans_ctx, const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &dst_d,
            const memory_desc_wrapper &indices_d, const data_type_t wsp_dt,
            data_t *src, const data_t *dst, const char *indices,
            const exec_ctx_t &ctx)
        : transpose_facade_base_t<wsp_data_t, d_type>(
                jpp, src_d, dst_d, indices_d, indices, wsp_dt, ctx)
        , c_tail_(jpp.c_without_padding % jpp.c_block) {

        if (this->should_transpose_src())
            this->execute_transpose_output_
                    = transpose_block_fmt_to_ncsp_t<wsp_data_t, data_t>(
                            trans_ctx->src_trans_.get(),
                            trans_ctx->src_tail_trans_.get(),
                            this->cvt_slice_src_wsp_, this->src_slice_, src,
                            this->src_d_, jpp, 1u);

        if (this->should_transpose_dst()) {
            using namespace std::placeholders;

            // std::bind stores copies of the two conversion functors and
            // forwards (ithr, n, b_c) to them; `indices` is captured by
            // value, so the indices conversion only runs when a workspace
            // buffer exists.
            this->execute_transpose_input_ = std::bind(
                    [=](const transpose_ncsp_to_block_fmt_t<data_t, wsp_data_t>
                                    &trans_dst,
                            transpose_ncsp_to_block_fmt_t<char, char>
                                    &trans_indices,
                            std::size_t ithr, int n, int b_c) {
                        trans_dst(ithr, n, b_c);
                        if (indices) trans_indices(ithr, n, b_c);
                    },
                    transpose_ncsp_to_block_fmt_t<data_t, wsp_data_t>(
                            trans_ctx->dst_trans_.get(),
                            trans_ctx->dst_tail_trans_.get(), dst, this->dst_d_,
                            this->cvt_slice_dst_wsp_, this->dst_slice_, jpp),
                    transpose_ncsp_to_block_fmt_t<char, char>(
                            trans_ctx->ind_trans_.get(),
                            trans_ctx->ind_tail_trans_.get(), indices,
                            this->indices_d_, this->cvt_slice_ind_wsp_,
                            this->dst_slice_, jpp, this->ind_dt_size_),
                    _1, _2, _3);
        }
    }

    // True when the last channel block of the blocked staging buffer is only
    // partially filled and its padded lanes must be zeroed before use.
    inline bool should_fill_input_c_tail_with_zeros() const noexcept {
        return this->should_transpose_dst() && c_tail_ != 0;
    }

    // Zero-fills the padded channel lanes [c_tail_, c_block) of thread
    // ithr's diff_dst and indices staging slices, presumably so the jit
    // kernel only reads initialized data in the tail block — confirm against
    // the kernel's tail handling.
    void fill_input_c_tail_with_zeros(
            std::size_t ithr, const jit_pool_conf_t &jpp) const {

        wsp_data_t *__restrict wsp_ptr
                = this->cvt_slice_dst_wsp_ + ithr * this->dst_slice_;
        for_(dim_t s = 0; s < this->dst_sp_; s++)
        for (dim_t c = c_tail_; c < jpp.c_block; c++)
            wsp_ptr[s * jpp.c_block + c] = 0.f;

        // Indices are byte-addressed: zero every byte of each padded element.
        char *__restrict ind_ptr = this->cvt_slice_ind_wsp_
                + ithr * this->dst_slice_ * this->ind_dt_size_;
        for_(dim_t s = 0; s < this->dst_sp_; s++)
        for_(dim_t c = c_tail_; c < jpp.c_block; c++)
        for (size_t i = 0; i < this->ind_dt_size_; i++)
            ind_ptr[(s * jpp.c_block + c) * this->ind_dt_size_ + i] = 0;
    }

private:
    // Number of valid channels in the last (partial) channel block; 0 when
    // the channel count divides evenly.
    const dim_t c_tail_;
};
485 | |
486 | } // namespace jit_uni_pooling_utils |
487 | |
// Constructor only stores the primitive descriptor; the jit kernel and the
// optional transposition context are created later in init().
template <cpu_isa_t isa, impl::data_type_t d_type>
jit_uni_pooling_fwd_t<isa, d_type>::jit_uni_pooling_fwd_t(const pd_t *apd)
    : primitive_t(apd), kernel_(nullptr), trans_ctx_(nullptr) {}
491 | |
// Instantiates the jit pooling kernel, builds the transposition context for
// plain (ncsp) layouts, and finally generates the kernel code.
template <cpu_isa_t isa, impl::data_type_t d_type>
status_t jit_uni_pooling_fwd_t<isa, d_type>::init(engine_t *engine) {

    // safe_ptr_assign takes ownership and reports out-of-memory as a status.
    CHECK(safe_ptr_assign(kernel_,
            new jit_uni_pool_kernel<isa>(
                    pd()->jpp_, pd()->invariant_dst_md())));

    // Plain layouts need the helper transposers (created + code-generated).
    if (pd()->jpp_.tag_kind == jit_memory_tag_kind_t::ncsp)
        CHECK(init_ncsp_trans_ctx());
    return kernel_->create_kernel();
}
503 | |
// Builds the transposition context used when src/dst are in plain ncsp
// layout: plain->blocked transposers for the source and blocked->plain
// transposers for the destination (and indices), each with a separate
// "tail" variant covering the last, partially filled channel block.
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_pooling_fwd_t<isa, d_type>::init_ncsp_trans_ctx() {
    using namespace dnnl::impl;
    using namespace jit_uni_pooling_utils;

    const auto &jpp = pd()->jpp_;
    trans_ctx_ = utils::make_unique<trans_context_t>();
    // Spatial sizes of a single channel plane.
    const dim_t src_sp = static_cast<dim_t>(jpp.id) * jpp.ih * jpp.iw;
    const dim_t dst_sp = static_cast<dim_t>(jpp.od) * jpp.oh * jpp.ow;
    // Full channel blocks (quot) and trailing partial-block size (rem).
    // Note: the references bind to lifetime-extended dim_t temporaries
    // converted from the int members of std::div's result.
    const auto res = std::div(jpp.c_without_padding, jpp.c_block);
    const dim_t &nb_c = res.quot;
    const dim_t &c_tail = res.rem;
    const memory_desc_wrapper indices_d = pd()->workspace_md();
    const bool have_indices = indices_d.data_type() != data_type::undef;
    static constexpr auto wsp_dt = wsp_dt_;

    if (nb_c) {
        // src: transpose a (c_block x src_sp) plain slab into blocked form.
        trans_ctx_->src_trans_ = utils::make_unique<trans_wrapper_t>(
                d_type, src_sp, wsp_dt, jpp.c_block, jpp.c_block, src_sp);
        // dst: transpose a (dst_sp x c_block) blocked slab back to plain.
        trans_ctx_->dst_trans_ = utils::make_unique<trans_wrapper_t>(
                wsp_dt, jpp.c_block, d_type, dst_sp, dst_sp, jpp.c_block);
        if (have_indices)
            trans_ctx_->ind_trans_ = utils::make_unique<trans_wrapper_t>(
                    indices_d.data_type(), jpp.c_block, indices_d.data_type(),
                    dst_sp, dst_sp, jpp.c_block);
    }

    if (c_tail) {
        // Tail variants handle the last block with only c_tail valid channels.
        trans_ctx_->src_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                d_type, src_sp, wsp_dt, jpp.c_block, c_tail, src_sp);
        trans_ctx_->dst_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                wsp_dt, jpp.c_block, d_type, dst_sp, dst_sp, c_tail);
        if (have_indices)
            trans_ctx_->ind_tail_trans_ = utils::make_unique<trans_wrapper_t>(
                    indices_d.data_type(), jpp.c_block, indices_d.data_type(),
                    dst_sp, dst_sp, c_tail);
    }

    // Generate code for every transposer instantiated above.
    return trans_ctx_->create_kernel();
}
544 | |
// Defaulted out-of-line, likely so the unique_ptr members (kernel_,
// trans_ctx_) are destroyed here where their types are complete.
template <cpu_isa_t isa, impl::data_type_t d_type>
jit_uni_pooling_fwd_t<isa, d_type>::~jit_uni_pooling_fwd_t() = default;
547 | |
// 2D forward pooling driver: prepares per-row kernel arguments and chooses a
// parallelization strategy based on the memory layout (nspc / ncsp with
// transposition / natively blocked).
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
        data_t *dst, char *indices, const exec_ctx_t &ctx) const {

    const memory_desc_wrapper src_d = pd()->src_md();
    const memory_desc_wrapper dst_d = pd()->dst_md();
    const memory_desc_wrapper indices_d = pd()->workspace_md();
    const auto ind_dt_size
            = indices ? types::data_type_size(indices_d.data_type()) : 0;
    const auto &jpp = pd()->jpp_;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jpp.post_ops, ctx);

    using wsp_data_t = typename prec_traits<wsp_dt_>::type;
    using namespace jit_uni_pooling_utils;

    // For ncsp layouts the facade routes src/dst/indices through blocked
    // per-thread staging buffers; otherwise its getters are unused.
    const auto transpose_facade
            = fwd_pooling_transpose_facade_t<data_t, wsp_data_t, d_type>(jpp,
                    trans_ctx_.get(), src_d, dst_d, indices_d, wsp_dt_, src,
                    dst, indices, ctx);

    const auto trans_src = transpose_facade.should_transpose_src();
    const auto trans_dst = transpose_facade.should_transpose_dst();

    // Runs the jit kernel for one output row oh of (n, b_c), processing
    // ur_bc channel blocks at once.
    const auto ker = [&](std::size_t ithr, int n, int b_c, int oh, int ur_bc) {
        assert(ur_bc == jpp.ur_bc || ur_bc == jpp.ur_bc_tail);
        auto arg = jit_pool_call_s();

        // How much of the kh window falls into top/bottom padding for this
        // output row, and the first valid input row.
        const int ij = oh * jpp.stride_h;
        const int i_t_overflow = nstl::max(0, jpp.t_pad - ij);
        const int i_b_overflow
                = nstl::max(jpp.ih, ij + jpp.kh - jpp.t_pad) - jpp.ih;
        const int ih = nstl::max(ij - jpp.t_pad, 0);
        assert(IMPLICATION(pd()->ndims() == 3, utils::everyone_is(0, ih, oh)));
        // Channel offset unit depends on the layout: element granularity for
        // nspc, block granularity otherwise.
        const int c_off
                = ((jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c_block
                                                                 : 1)
                * b_c;

        // Source comes either from the thread's transposed staging buffer or
        // straight from the tensor.
        if (trans_src)
            arg.src = transpose_facade.get_src_addr(ithr, ih, jpp);
        else
            arg.src = static_cast<const void *>(
                    &src[src_d.blk_off(n, c_off, ih)]);

        arg.dst_orig = dst;
        if (trans_dst) {
            arg.dst = transpose_facade.get_dst_addr(ithr, oh, jpp);
            if (!types::is_zero_md(&jpp.tmp_md)) {
                const memory_desc_wrapper tmp_d
                        = memory_desc_wrapper(jpp.tmp_md);
                // offset needs to be f32
                const int dt_scale
                        = sizeof(float) / types::data_type_size(d_type);
                const auto blk_off = tmp_d.blk_off(n, c_off, oh) * dt_scale;
                arg.dst_po_helper = static_cast<const void *>(&dst[blk_off]);
            }
        } else {
            arg.dst = static_cast<const void *>(
                    &dst[dst_d.blk_off(n, c_off, oh)]);
        }

        // Indices (max-pooling workspace) follow the same staging-vs-direct
        // choice as dst.
        if (indices) {
            if (trans_dst)
                arg.indices = transpose_facade.get_indices_addr(ithr, oh, jpp);
            else {
                const size_t ind_off = indices_d.blk_off(n, c_off, oh);
                arg.indices = static_cast<const void *>(
                        &indices[ind_off * ind_dt_size]);
            }
        }
        arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
        arg.kh_padding_shift = i_t_overflow * jpp.kw;
        // Effective (non-padded) window height used for averaging.
        arg.ker_area_h = static_cast<float>(jpp.kh
                - nstl::max(0, oh * jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih)
                - nstl::max(0, jpp.t_pad - oh * jpp.stride_h));
        arg.ur_bc = ur_bc;
        arg.b_c = b_c;
        arg.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
        (*kernel_)(&arg);
    };

    const int nthr = jpp.nthr;

    if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
        // nspc: no staging needed; parallelize over (mb, oh, ur_bc-groups of
        // channel blocks), with a possibly smaller ur_bc for the last group.
        const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
        parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](dim_t n, dim_t oh, dim_t b2_c) {
            const auto b_c = b2_c * jpp.ur_bc;
            const auto ur_bc = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c);
            ker(0, n, b_c, oh, ur_bc);
        });
    } else {
        if (trans_src || trans_dst) {
            // ncsp format
            // Each thread transposes its (n, b_c) input slice, runs all
            // output rows, then transposes the result back.
            parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
                    [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
                        if (trans_src)
                            transpose_facade.execute_transpose_input(
                                    ithr, n, b_c);
                        for (dim_t oh = 0; oh < jpp.oh; ++oh)
                            ker(ithr, n, b_c, oh, 1);
                        if (trans_dst)
                            transpose_facade.execute_transpose_output(
                                    ithr, n, b_c);
                    });
        } else {
            // nChw16c, nChw8c format
            // Manually balance the flattened (mb, nb_c, oh) space across
            // threads and walk it with an nd iterator.
            parallel(nthr, [&](dim_t ithr, dim_t nthr) {
                dim_t work_amount = jpp.mb * jpp.nb_c * jpp.oh;
                if (ithr >= work_amount) return;

                dim_t start {0}, end {0};
                dim_t n {0}, b_c {0}, oh {0};

                balance211(work_amount, nthr, ithr, start, end);
                utils::nd_iterator_init(
                        start, n, jpp.mb, b_c, jpp.nb_c, oh, jpp.oh);

                for (dim_t iwork = start; iwork < end; ++iwork) {
                    ker(ithr, n, b_c, oh, 1);
                    utils::nd_iterator_step(
                            n, jpp.mb, b_c, jpp.nb_c, oh, jpp.oh);
                }
            });
        }
    }
}
675 | |
// 3D forward pooling driver: like execute_forward() but additionally walks
// the depth dimension and accounts for front/back (f_pad) padding of the kd
// window.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
        data_t *dst, char *indices, const exec_ctx_t &ctx) const {

    const auto &jpp = pd()->jpp_;
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper indices_d(pd()->workspace_md());
    const size_t ind_dt_size
            = indices ? types::data_type_size(indices_d.data_type()) : 0;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jpp.post_ops, ctx);

    using wsp_data_t = typename prec_traits<wsp_dt_>::type;
    using namespace jit_uni_pooling_utils;
    // Thread id used on paths that do not touch the staging buffers.
    static constexpr int first_ithr = 0;

    const auto transpose_facade
            = fwd_pooling_transpose_facade_t<data_t, wsp_data_t, d_type>(jpp,
                    trans_ctx_.get(), src_d, dst_d, indices_d, wsp_dt_, src,
                    dst, indices, ctx);

    const auto trans_src = transpose_facade.should_transpose_src();
    const auto trans_dst = transpose_facade.should_transpose_dst();

    // Runs the jit kernel for one (od, oh) output row of (n, b_c). The
    // caller precomputes the depth-window overflows (d_t_overflow /
    // d_b_overflow) and the first valid input depth id.
    auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
                       int d_b_overflow, int ur_bc, int ithr) {
        assert(ur_bc == jpp.ur_bc || ur_bc == jpp.ur_bc_tail);
        auto arg = jit_pool_call_s();

        // Height-window overflow into top/bottom padding for this row.
        const int ij = oh * jpp.stride_h;
        const int i_t_overflow = nstl::max(0, jpp.t_pad - ij);
        const int i_b_overflow
                = nstl::max(jpp.ih, ij + jpp.kh - jpp.t_pad) - jpp.ih;
        const int ih = nstl::max(ij - jpp.t_pad, 0);
        // Channel offset unit depends on layout (elements for nspc, blocks
        // otherwise).
        const int c_off
                = ((jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c_block
                                                                 : 1)
                * b_c;

        if (trans_src)
            arg.src = transpose_facade.get_src_addr_3d(ithr, id, ih, jpp);
        else
            arg.src = &src[src_d.blk_off(n, c_off, id, ih)];

        arg.dst_orig = dst;
        if (trans_dst) {
            arg.dst = transpose_facade.get_dst_addr_3d(ithr, od, oh, jpp);
            if (!types::is_zero_md(&jpp.tmp_md)) {
                const memory_desc_wrapper tmp_d
                        = memory_desc_wrapper(jpp.tmp_md);
                // offset needs to be f32
                const int dt_scale
                        = sizeof(float) / types::data_type_size(d_type);
                const auto blk_off = tmp_d.blk_off(n, c_off, od, oh) * dt_scale;
                arg.dst_po_helper = static_cast<const void *>(&dst[blk_off]);
            }
        } else {
            arg.dst = &dst[dst_d.blk_off(n, c_off, od, oh)];
        }

        if (indices) {
            if (trans_dst) {
                arg.indices = transpose_facade.get_indices_addr_3d(
                        ithr, od, oh, jpp);
            } else {
                const size_t ind_off = indices_d.blk_off(n, c_off, od, oh);
                arg.indices = &indices[ind_off * ind_dt_size];
            }
        }

        arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
        arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
        arg.kh_padding_shift
                = i_t_overflow * jpp.kw + d_t_overflow * jpp.kw * jpp.kh;
        // NOTE(review): kd_padding_shift is computed from the *height*
        // overflows (rows skipped per depth step) — looks intentional, but
        // confirm against the kernel's pointer-advance logic.
        arg.kd_padding_shift = (i_t_overflow + i_b_overflow) * jpp.kw;
        // Effective (non-padded) window area factor: height term times
        // depth term, used for average-pooling normalization.
        arg.ker_area_h = (float)(jpp.kh
                                 - nstl::max(0,
                                         oh * jpp.stride_h - jpp.t_pad + jpp.kh
                                                 - jpp.ih)
                                 - nstl::max(0, jpp.t_pad - oh * jpp.stride_h))
                * (jpp.kd
                        - nstl::max(0,
                                od * jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id)
                        - nstl::max(0, jpp.f_pad - od * jpp.stride_d));

        arg.ur_bc = ur_bc;
        arg.b_c = b_c;
        arg.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
        (*kernel_)(&arg);
    };

    const int nthr = jpp.nthr;

    if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
        // nspc: parallelize over (mb, od, ur_bc-groups of channel blocks);
        // depth-window overflow is computed once per od.
        const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
        parallel_nd(jpp.mb, jpp.od, nb2_c, [&](dim_t n, dim_t od, dim_t b2_c) {
            const dim_t b_c = b2_c * jpp.ur_bc;
            const dim_t ur_bc = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c);

            const dim_t ik = od * jpp.stride_d;
            const dim_t d_t_overflow = nstl::max(dim_t(0), jpp.f_pad - ik);
            const dim_t d_b_overflow
                    = nstl::max(dim_t(jpp.id), ik + jpp.kd - jpp.f_pad)
                    - jpp.id;
            const dim_t id = nstl::max(ik - jpp.f_pad, dim_t(0));
            for (dim_t oh = 0; oh < jpp.oh; ++oh) {
                ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, ur_bc,
                        first_ithr);
            }
        });
    } else {
        if (trans_src || trans_dst) {
            // ncsp: each thread transposes its (n, b_c) input slice, runs
            // the full (od, oh) output space, then transposes back.
            parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
                    [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
                        if (trans_src)
                            transpose_facade.execute_transpose_input(
                                    ithr, n, b_c);

                        for (int od = 0; od < jpp.od; ++od) {
                            const int ik = od * jpp.stride_d;
                            const int d_t_overflow
                                    = nstl::max(0, jpp.f_pad - ik);
                            const int d_b_overflow
                                    = nstl::max(jpp.id, ik + jpp.kd - jpp.f_pad)
                                    - jpp.id;
                            const int id = nstl::max(ik - jpp.f_pad, 0);
                            for (int oh = 0; oh < jpp.oh; ++oh) {
                                ker(n, b_c, od, oh, id, d_t_overflow,
                                        d_b_overflow, 1, ithr);
                            }
                        }

                        if (trans_dst)
                            transpose_facade.execute_transpose_output(
                                    ithr, n, b_c);
                    });
        } else {
            // Natively blocked layout: no staging, parallelize directly
            // over (mb, nb_c, od).
            parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
                    [&](dim_t n, dim_t b_c, dim_t od) {
                        const int ik = od * jpp.stride_d;
                        const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
                        const int d_b_overflow
                                = nstl::max(jpp.id, ik + jpp.kd - jpp.f_pad)
                                - jpp.id;
                        const int id = nstl::max(ik - jpp.f_pad, 0);
                        for (int oh = 0; oh < jpp.oh; ++oh) {
                            ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
                                    1, first_ithr);
                        }
                    });
        }
    }
}
830 | |
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_pooling_bwd_t<isa, d_type>::jit_uni_pooling_bwd_t(const pd_t *apd)
    : primitive_t(apd)
    // The JIT kernel is parameterized by the pooling config (jpp_) and the
    // destination memory descriptor taken from the primitive descriptor.
    , kernel_(utils::make_unique<jit_uni_pool_kernel<isa>>(
              pd()->jpp_, pd()->invariant_dst_md()))
    // Transpose context starts empty; init() fills it for ncsp layouts only.
    , trans_ctx_(nullptr) {}
837 | |
// Defaulted destructor kept out-of-line in this .cpp — NOTE(review):
// presumably so the unique_ptr members (kernel_, trans_ctx_) are destroyed
// where their pointee types are complete; confirm against the header.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_pooling_bwd_t<isa, d_type>::~jit_uni_pooling_bwd_t() = default;
840 | |
841 | template <cpu_isa_t isa, data_type_t d_type> |
842 | status_t jit_uni_pooling_bwd_t<isa, d_type>::init_ncsp_trans_ctx() { |
843 | using namespace dnnl::impl; |
844 | using namespace jit_uni_pooling_utils; |
845 | |
846 | const auto &jpp = pd()->jpp_; |
847 | trans_ctx_ = utils::make_unique<trans_context_t>(); |
848 | const dim_t diff_src_sp = static_cast<dim_t>(jpp.id) * jpp.ih * jpp.iw; |
849 | const dim_t diff_dst_sp = static_cast<dim_t>(jpp.od) * jpp.oh * jpp.ow; |
850 | const auto res = std::div(jpp.c_without_padding, jpp.c_block); |
851 | const dim_t &nb_c = res.quot; |
852 | const dim_t &c_tail = res.rem; |
853 | const memory_desc_wrapper indices_d = pd()->workspace_md(); |
854 | const bool have_indices = indices_d.data_type() != data_type::undef; |
855 | static constexpr auto wsp_dt = wsp_dt_; |
856 | |
857 | if (nb_c) { |
858 | trans_ctx_->dst_trans_ = utils::make_unique<trans_wrapper_t>(d_type, |
859 | diff_dst_sp, wsp_dt, jpp.c_block, jpp.c_block, diff_dst_sp); |
860 | trans_ctx_->src_trans_ = utils::make_unique<trans_wrapper_t>(wsp_dt, |
861 | jpp.c_block, d_type, diff_src_sp, diff_src_sp, jpp.c_block); |
862 | if (have_indices) |
863 | trans_ctx_->ind_trans_ = utils::make_unique<trans_wrapper_t>( |
864 | indices_d.data_type(), diff_dst_sp, indices_d.data_type(), |
865 | jpp.c_block, jpp.c_block, diff_dst_sp); |
866 | } |
867 | if (c_tail) { |
868 | trans_ctx_->dst_tail_trans_ = utils::make_unique<trans_wrapper_t>( |
869 | d_type, diff_dst_sp, wsp_dt, jpp.c_block, c_tail, diff_dst_sp); |
870 | trans_ctx_->src_tail_trans_ = utils::make_unique<trans_wrapper_t>( |
871 | wsp_dt, jpp.c_block, d_type, diff_src_sp, diff_src_sp, c_tail); |
872 | if (have_indices) |
873 | trans_ctx_->ind_tail_trans_ = utils::make_unique<trans_wrapper_t>( |
874 | indices_d.data_type(), diff_dst_sp, indices_d.data_type(), |
875 | jpp.c_block, c_tail, diff_dst_sp); |
876 | } |
877 | |
878 | return trans_ctx_->create_kernel(); |
879 | } |
880 | |
881 | template <cpu_isa_t isa, data_type_t d_type> |
882 | status_t jit_uni_pooling_bwd_t<isa, d_type>::init(engine_t *engine) { |
883 | if (pd()->jpp_.tag_kind == jit_memory_tag_kind_t::ncsp) |
884 | CHECK(init_ncsp_trans_ctx()); |
885 | return kernel_->create_kernel(); |
886 | } |
887 | |
888 | template <cpu_isa_t isa, data_type_t d_type> |
889 | void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward( |
890 | const data_t *diff_dst, const char *indices, data_t *diff_src, |
891 | const exec_ctx_t &ctx) const { |
892 | |
893 | using namespace jit_uni_pooling_utils; |
894 | using wsp_data_t = typename prec_traits<wsp_dt_>::type; |
895 | |
896 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
897 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
898 | const memory_desc_wrapper indices_d(pd()->workspace_md()); |
899 | const size_t ind_dt_size |
900 | = indices ? types::data_type_size(indices_d.data_type()) : 0; |
901 | const auto &jpp = pd()->jpp_; |
902 | const auto transpose_facade |
903 | = jit_uni_pooling_utils::bwd_pooling_transpose_facade_t<data_t, |
904 | wsp_data_t, d_type>(jpp, trans_ctx_.get(), diff_src_d, |
905 | diff_dst_d, indices_d, wsp_dt_, diff_src, diff_dst, indices, |
906 | ctx); |
907 | |
908 | auto get_first_ih = [&](int oh) { |
909 | return nstl::min(nstl::max(oh * jpp.stride_h - jpp.t_pad, 0), jpp.ih); |
910 | }; |
911 | |
912 | auto get_last_ih = [&](int oh) { |
913 | return nstl::min( |
914 | nstl::max(oh * jpp.stride_h - jpp.t_pad + jpp.kh, 0), jpp.ih); |
915 | }; |
916 | const auto ker = [&](int ithr, int n, int b_c, int oh, int ur_bc) { |
917 | auto arg = jit_pool_call_s(); |
918 | |
919 | const int ih = get_first_ih(oh); |
920 | assert(IMPLICATION(pd()->ndims() == 3, utils::everyone_is(0, ih, oh))); |
921 | assert(pd()->ndims() != 3 || utils::everyone_is(0, ih, oh)); |
922 | |
923 | const auto c_off = jpp.is_plain() ? b_c * jpp.c_block : b_c; |
924 | if (transpose_facade.should_transpose_src()) |
925 | arg.src = transpose_facade.get_src_addr(ithr, ih, jpp); |
926 | else |
927 | arg.src = &diff_src[diff_src_d.blk_off(n, c_off, ih)]; |
928 | |
929 | if (transpose_facade.should_transpose_dst()) |
930 | arg.dst = transpose_facade.get_dst_addr(ithr, oh, jpp); |
931 | else |
932 | arg.dst = &diff_dst[diff_dst_d.blk_off(n, c_off, oh)]; |
933 | |
934 | if (indices) { |
935 | if (transpose_facade.should_transpose_dst()) |
936 | arg.indices = transpose_facade.get_indices_addr(ithr, oh, jpp); |
937 | |
938 | else { |
939 | const size_t ind_off = indices_d.blk_off(n, c_off, oh); |
940 | arg.indices = &indices[ind_off * ind_dt_size]; |
941 | } |
942 | } |
943 | |
944 | const int zero_ih_start = (oh == 0) ? 0 : get_last_ih(oh - 1); |
945 | const int zero_ih_end = (oh == jpp.oh - 1) ? jpp.ih : get_last_ih(oh); |
946 | |
947 | arg.zero_id = 1; |
948 | arg.zero_ih = zero_ih_end - zero_ih_start; |
949 | if (transpose_facade.should_transpose_src()) |
950 | arg.zero_ptr |
951 | = transpose_facade.get_src_addr(ithr, zero_ih_start, jpp); |
952 | else |
953 | arg.zero_ptr |
954 | = &diff_src[diff_src_d.blk_off(n, c_off, zero_ih_start, 0)]; |
955 | |
956 | const int i_t_overflow = nstl::max(0, jpp.t_pad - oh * jpp.stride_h); |
957 | const int i_b_overflow |
958 | = nstl::max(jpp.ih, oh * jpp.stride_h + jpp.kh - jpp.t_pad) |
959 | - jpp.ih; |
960 | arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; |
961 | arg.kh_padding_shift = i_t_overflow * jpp.kw; |
962 | arg.ker_area_h = static_cast<float>(jpp.kh |
963 | - nstl::max(0, oh * jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) |
964 | - nstl::max(0, jpp.t_pad - oh * jpp.stride_h)); |
965 | |
966 | arg.ur_bc = ur_bc; |
967 | arg.b_c = b_c; |
968 | (*kernel_)(&arg); |
969 | }; |
970 | |
971 | auto process_block = [&](int ithr, int n, int b_c, int ur_bc) { |
972 | if (transpose_facade.should_transpose_dst()) |
973 | transpose_facade.execute_transpose_input(ithr, n, b_c); |
974 | |
975 | for (int oh = 0; oh < jpp.oh; ++oh) |
976 | ker(ithr, n, b_c, oh, ur_bc); |
977 | |
978 | if (transpose_facade.should_transpose_src()) |
979 | transpose_facade.execute_transpose_output(ithr, n, b_c); |
980 | }; |
981 | |
982 | const int nthr = jpp.nthr; |
983 | |
984 | parallel(nthr, [&](int ithr, int nthr) { |
985 | const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); |
986 | const std::size_t work_amount |
987 | = static_cast<std::size_t>(jpp.mb) * nb2_c; |
988 | if (static_cast<std::size_t>(ithr) >= work_amount) return; |
989 | |
990 | if (transpose_facade.should_fill_input_c_tail_with_zeros()) |
991 | transpose_facade.fill_input_c_tail_with_zeros(ithr, jpp); |
992 | |
993 | std::size_t start {0}, end {0}; |
994 | balance211(work_amount, nthr, ithr, start, end); |
995 | int n {0}, b2_c {0}; |
996 | utils::nd_iterator_init(start, n, jpp.mb, b2_c, nb2_c); |
997 | for (size_t iwork = start; iwork < end; ++iwork) { |
998 | const auto b_c = b2_c * jpp.ur_bc; |
999 | const auto ur_bc = nstl::min(jpp.ur_bc, jpp.nb_c - b_c); |
1000 | |
1001 | process_block(ithr, n, b_c, ur_bc); |
1002 | utils::nd_iterator_step(n, jpp.mb, b2_c, nb2_c); |
1003 | } |
1004 | }); |
1005 | } |
1006 | |
// Backward pooling driver for the 3D case. Two strategies exist: the
// "simple" algorithm lets the kernel cover the whole depth window per call;
// otherwise diff_src is zero-filled up front and the kernel accumulates one
// kernel depth plane (kd) at a time.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
        const data_t *diff_dst, const char *indices, data_t *diff_src,
        const exec_ctx_t &ctx) const {
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper indices_d(pd()->workspace_md());
    // Element size of the workspace indices; 0 when none are supplied.
    const size_t ind_dt_size
            = indices ? types::data_type_size(indices_d.data_type()) : 0;

    const auto &jpp = pd()->jpp_;

    using wsp_data_t = typename prec_traits<wsp_dt_>::type;
    using namespace jit_uni_pooling_utils;
    // Thread id used on code paths that do not address per-thread transpose
    // scratch buffers.
    static constexpr int first_ithr = 0;

    // Transposes plain-layout tensors to/from the layout the kernel expects;
    // a no-op for layouts the kernel consumes natively.
    const auto transpose_facade
            = bwd_pooling_transpose_facade_t<data_t, wsp_data_t, d_type>(jpp,
                    trans_ctx_.get(), diff_src_d, diff_dst_d, indices_d,
                    wsp_dt_, diff_src, diff_dst, indices, ctx);

    const auto trans_src = transpose_facade.should_transpose_src();
    const auto trans_dst = transpose_facade.should_transpose_dst();

    // One-past-last diff_src row covered by output row `oh`, clamped to
    // [0, ih].
    auto get_last_ih = [&](int oh) {
        return nstl::min(
                nstl::max(oh * jpp.stride_h - jpp.t_pad + jpp.kh, 0), jpp.ih);
    };

    // One-past-last diff_src depth slice covered by output slice `od`,
    // clamped to [0, id].
    auto get_last_id = [&](int od) {
        return nstl::min(
                nstl::max(od * jpp.stride_d - jpp.f_pad + jpp.kd, 0), jpp.id);
    };

    // Runs the JIT kernel for one (mb, channel-block, od, oh) work item.
    // `kd` selects a single kernel depth plane (non-zero only in the
    // non-simple algorithm); `zero_inp` asks the kernel to zero-initialize
    // the diff_src region this item is the first to touch.
    auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
                       int d_b_overflow, bool zero_inp, int kd, int ur_bc,
                       int ithr) {
        auto arg = jit_pool_call_s();

        const int ij = oh * jpp.stride_h;
        // Window rows falling into the top/bottom padding for this oh.
        const int i_t_overflow = nstl::max(0, jpp.t_pad - ij);
        const int i_b_overflow
                = nstl::max(jpp.ih, ij + jpp.kh - jpp.t_pad) - jpp.ih;
        const int ih = nstl::max(ij - jpp.t_pad, 0);
        // nspc blocks are c_block channels wide; blocked layouts index by
        // block number directly.
        const int c_off
                = ((jpp.tag_kind == jit_memory_tag_kind_t::nspc) ? jpp.c_block
                                                                 : 1)
                * b_c;

        if (trans_src)
            arg.src = transpose_facade.get_src_addr_3d(ithr, id + kd, ih, jpp);
        else
            arg.src = (const void *)&diff_src[diff_src_d.blk_off(
                    n, c_off, id + kd, ih)];

        if (trans_dst)
            arg.dst = transpose_facade.get_dst_addr_3d(ithr, od, oh, jpp);
        else
            arg.dst = (const void
                            *)&diff_dst[diff_dst_d.blk_off(n, c_off, od, oh)];

        if (indices) {
            if (trans_dst) {
                arg.indices = transpose_facade.get_indices_addr_3d(
                        ithr, od, oh, jpp);
            } else {
                const size_t ind_off = indices_d.blk_off(n, c_off, od, oh);
                arg.indices = (const void *)&indices[ind_off * ind_dt_size];
            }
        }

        if (zero_inp) {
            // Zero only the diff_src region this (od, oh) item touches
            // first, so the tensor is initialized exactly once overall.
            const int zero_id_start = (od == 0) ? 0 : get_last_id(od - 1);
            const int zero_id_end
                    = (od == jpp.od - 1) ? jpp.id : get_last_id(od);

            arg.zero_id = zero_id_end - zero_id_start;

            const int zero_ih_start = (oh == 0) ? 0 : get_last_ih(oh - 1);
            const int zero_ih_end
                    = (oh == jpp.oh - 1) ? jpp.ih : get_last_ih(oh);
            arg.zero_ih = zero_ih_end - zero_ih_start;

            if (trans_src)
                arg.zero_ptr = transpose_facade.get_src_addr_3d(
                        ithr, zero_id_start, zero_ih_start, jpp);
            else
                arg.zero_ptr = &diff_src[diff_src_d.blk_off(
                        n, c_off, zero_id_start, zero_ih_start, 0)];
        } else {
            arg.zero_id = 0;
            arg.zero_ih = 0;
        }

        arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
        arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
        arg.kh_padding_shift = i_t_overflow * jpp.kw
                + d_t_overflow * jpp.kw * jpp.kh + kd * jpp.kw * jpp.kh;
        arg.kd_padding_shift = (i_t_overflow + i_b_overflow) * jpp.kw;
        // Effective (non-padded) window extent in h times the effective
        // extent in d for this (oh, od) pair.
        arg.ker_area_h = (float)(jpp.kh
                                 - nstl::max(0,
                                         oh * jpp.stride_h - jpp.t_pad + jpp.kh
                                                 - jpp.ih)
                                 - nstl::max(0, jpp.t_pad - oh * jpp.stride_h))
                * (jpp.kd
                        - nstl::max(0,
                                od * jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id)
                        - nstl::max(0, jpp.f_pad - od * jpp.stride_d));

        arg.ur_bc = ur_bc;
        arg.b_c = b_c;
        (*kernel_)(&arg);
    };

    // Whole-depth-window processing for one (n, b_c, od) when simple_alg:
    // kd is fixed at 0 and zero_inp is true.
    auto process_simple = [&](int n, int b_c, int od, int ur_bc, int ithr) {
        const int ik = od * jpp.stride_d;
        const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
        const int d_b_overflow
                = nstl::max(jpp.id, ik + jpp.kd - jpp.f_pad) - jpp.id;
        const int id = nstl::max(ik - jpp.f_pad, 0);

        for (int oh = 0; oh < jpp.oh; ++oh) {
            ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, true, 0, ur_bc,
                    ithr);
        }
    };

    const int nthr = jpp.nthr;

    if (jpp.simple_alg) {
        if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
            const dim_t nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
            parallel_nd(
                    jpp.mb, jpp.od, nb2_c, [&](dim_t n, dim_t od, dim_t b2_c) {
                        const dim_t b_c = b2_c * jpp.ur_bc;
                        // Last channel chunk may hold fewer than ur_bc blocks.
                        const dim_t ur_bc
                                = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c);
                        process_simple(n, b_c, od, ur_bc, first_ithr);
                    });
        } else {
            assert(jpp.ur_bc == 1);
            if (trans_src || trans_dst) {
                // Transposing paths need the real thread id to address the
                // per-thread scratch buffers.
                parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
                        [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
                            if (trans_src)
                                transpose_facade.execute_transpose_input(
                                        ithr, n, b_c);
                            for (int od = 0; od < jpp.od; ++od) {
                                process_simple(n, b_c, od, 1, ithr);
                            }
                            if (trans_dst)
                                transpose_facade.execute_transpose_output(
                                        ithr, n, b_c);
                        });
            } else {
                parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
                        [&](dim_t n, dim_t b_c, dim_t od) {
                            process_simple(n, b_c, od, 1, first_ithr);
                        });
            }
        }
    } else {
        // Non-simple algorithm: zero-fill diff_src first, then accumulate
        // one kernel depth plane (kd) per kernel invocation.
        const data_t zero_val = 0;
        if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
            const size_t chunk_size = (size_t)jpp.ih * jpp.iw * jpp.c;
            parallel_nd(jpp.mb, jpp.id, [&](dim_t n, dim_t id) {
                const size_t offset = ((size_t)n * jpp.id + id) * chunk_size;
                PRAGMA_OMP_SIMD()
                for (size_t idx = 0; idx < chunk_size; ++idx)
                    diff_src[offset + idx] = zero_val;
            });
        } else {
            // Blocked layouts without input transpose: zero diff_src
            // directly; transposing paths zero the scratch buffer below.
            if (!trans_src) {
                const size_t chunk_size
                        = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block;
                parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
                        [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b_c) {
                            const size_t offset
                                    = ((size_t)n * jpp.nb_c + b_c) * chunk_size;
                            PRAGMA_OMP_SIMD()
                            for (size_t idx = 0; idx < chunk_size; ++idx)
                                diff_src[offset + idx] = zero_val;
                        });
            }
        }

        const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
        if (trans_src || trans_dst) {
            parallel_nd_ext(nthr, jpp.mb, nb2_c,
                    [&](dim_t ithr, dim_t nthr, dim_t n, dim_t b2_c) {
                        const dim_t b_c = b2_c * jpp.ur_bc;

                        if (trans_dst) {
                            transpose_facade.execute_transpose_input(
                                    ithr, n, b_c);

                            size_t block_size = jpp.c_block * jpp.id * jpp.ih
                                    * jpp.iw * jpp.dt_size;

                            const void *src = transpose_facade.get_src_addr_3d(
                                    ithr, 0, 0, jpp);
                            // NOTE(review): zero_val (a data_t, e.g. float) is
                            // implicitly converted to memset's int fill byte —
                            // correct only because it is 0. The C-style cast
                            // drops constness of the scratch pointer.
                            std::memset((void *)src, zero_val, block_size);
                        }

                        for (dim_t kd = 0; kd < jpp.kd; ++kd) {
                            const dim_t ur_bc = nstl::min(
                                    dim_t(jpp.ur_bc), jpp.nb_c - b_c);
                            for (int od = 0; od < jpp.od; ++od) {
                                const dim_t ik = od * jpp.stride_d;
                                const dim_t d_t_overflow
                                        = nstl::max(dim_t(0), jpp.f_pad - ik);
                                const dim_t d_b_overflow
                                        = nstl::max(dim_t(jpp.id),
                                                  ik + jpp.kd - jpp.f_pad)
                                        - jpp.id;
                                // Skip kd planes outside the valid (unpadded)
                                // part of the window for this od.
                                if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
                                    continue;
                                const dim_t id
                                        = nstl::max(ik - jpp.f_pad, dim_t(0));
                                for (dim_t oh = 0; oh < jpp.oh; ++oh) {
                                    ker(n, b_c, od, oh, id, d_t_overflow,
                                            d_b_overflow, false, kd, ur_bc,
                                            ithr);
                                }
                            }
                        }

                        if (trans_src)
                            transpose_facade.execute_transpose_output(
                                    ithr, n, b_c);
                    });
        } else {
            // kd iterates sequentially outside the parallel region —
            // presumably so that different kernel planes, which accumulate
            // into overlapping diff_src elements, never run concurrently.
            for (dim_t kd = 0; kd < jpp.kd; ++kd) {
                parallel_nd(jpp.mb, nb2_c, [&](dim_t n, dim_t b2_c) {
                    const dim_t b_c = b2_c * jpp.ur_bc;
                    const dim_t ur_bc
                            = nstl::min(dim_t(jpp.ur_bc), jpp.nb_c - b_c);
                    for (int od = 0; od < jpp.od; ++od) {
                        const dim_t ik = od * jpp.stride_d;
                        const dim_t d_t_overflow
                                = nstl::max(dim_t(0), jpp.f_pad - ik);
                        const dim_t d_b_overflow
                                = nstl::max(dim_t(jpp.id),
                                          ik + jpp.kd - jpp.f_pad)
                                - jpp.id;
                        // Skip kd planes outside the valid window for this od.
                        if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
                            continue;
                        const dim_t id = nstl::max(ik - jpp.f_pad, dim_t(0));
                        for (dim_t oh = 0; oh < jpp.oh; ++oh) {
                            ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
                                    false, kd, ur_bc, first_ithr);
                        }
                    }
                });
            }
        }
    }
}
1265 | |
// Explicit template instantiations for every supported ISA / data-type
// combination; the member definitions live in this translation unit.
template struct jit_uni_pooling_fwd_t<sse41, data_type::f32>;
template struct jit_uni_pooling_bwd_t<sse41, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx, data_type::f32>;
template struct jit_uni_pooling_bwd_t<avx, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx2, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx2_vnni_2, data_type::bf16>;
template struct jit_uni_pooling_fwd_t<avx2_vnni_2, data_type::f16>;
template struct jit_uni_pooling_bwd_t<avx2, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx512_core, data_type::f32>;
template struct jit_uni_pooling_bwd_t<avx512_core, data_type::f32>;
template struct jit_uni_pooling_fwd_t<avx512_core, data_type::bf16>;
template struct jit_uni_pooling_bwd_t<avx512_core, data_type::bf16>;
template struct jit_uni_pooling_fwd_t<avx512_core_fp16, data_type::f16>;
template struct jit_uni_pooling_bwd_t<avx512_core_fp16, data_type::f16>;
1280 | |
1281 | } // namespace x64 |
1282 | } // namespace cpu |
1283 | } // namespace impl |
1284 | } // namespace dnnl |
1285 | |
1286 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1287 | |