1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <math.h>
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/nstl.hpp"
23#include "common/type_helpers.hpp"
24
25#include "cpu/simple_q10n.hpp"
26
27#include "cpu/nchw_pooling.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32
33using namespace nstl;
34
35template <data_type_t d_type>
36nchw_pooling_fwd_t<d_type>::nchw_pooling_fwd_t(const pd_t *apd)
37 : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
38
39template <>
40status_t nchw_pooling_fwd_t<data_type::f32>::execute_forward(
41 const exec_ctx_t &ctx) const {
42 const auto alg = pd()->desc()->alg_kind;
43 const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
44 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
45 auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
46
47 const memory_desc_wrapper ws_d(pd()->workspace_md());
48 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
49
50 const dim_t MB = pd()->MB();
51 const dim_t C = pd()->OC();
52 const dim_t OD = pd()->OD();
53 const dim_t OH = pd()->OH();
54 const dim_t OW = pd()->OW();
55 const dim_t ID = pd()->ID();
56 const dim_t IH = pd()->IH();
57 const dim_t IW = pd()->IW();
58 const dim_t KD = pd()->KD();
59 const dim_t KH = pd()->KH();
60 const dim_t KW = pd()->KW();
61 const dim_t SD = pd()->KSD();
62 const dim_t SH = pd()->KSH();
63 const dim_t SW = pd()->KSW();
64 const dim_t padF = pd()->padFront();
65 const dim_t padT = pd()->padT();
66 const dim_t padL = pd()->padL();
67
68 const auto apply_offset = [](int index, int offset) {
69 return (index > offset) ? index - offset : 0;
70 };
71
72 const auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow,
73 dim_t value) {
74 if (ws) {
75 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
76 const size_t ws_offset = (size_t)OW * OH * OD * C * mb
77 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
78 + (size_t)OW * oh + (size_t)ow;
79 if (ws_dt == data_type::u8) {
80 assert(0 <= value
81 && value <= numeric_limits<typename prec_traits<
82 data_type::u8>::type>::max());
83 ws[ws_offset] = value;
84 } else
85 reinterpret_cast<int *>(ws)[ws_offset] = value;
86 }
87 };
88
89 const auto ker_max = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
90 dim_t ow) {
91 const auto src_off = IW * IH * ID * C * mb + IW * IH * ID * c;
92 const auto *src_loc = &src[src_off];
93
94 for_(dim_t kd = 0; kd < KD; ++kd)
95 for_(dim_t kh = 0; kh < KH; ++kh)
96 for (dim_t kw = 0; kw < KW; ++kw) {
97 const dim_t id = od * SD - padF + kd;
98 if (id < 0 || id >= ID) continue;
99 const dim_t ih = oh * SH - padT + kh;
100 if (ih < 0 || ih >= IH) continue;
101 const dim_t iw = ow * SW - padL + kw;
102 if (iw < 0 || iw >= IW) continue;
103
104 const auto src_off_loc = IW * IH * id + IW * ih + iw;
105 const auto &s = src_loc[src_off_loc];
106 if (s > d[0]) {
107 d[0] = s;
108 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
109 }
110 }
111 };
112
113 const auto ker_avg = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
114 dim_t ow) {
115 const auto id_start = apply_offset(od * SD, padF);
116 const auto ih_start = apply_offset(oh * SH, padT);
117 const auto iw_start = apply_offset(ow * SW, padL);
118 const auto id_end = min(od * SD - padF + KD, ID);
119 const auto ih_end = min(oh * SH - padT + KH, IH);
120 const auto iw_end = min(ow * SW - padL + KW, IW);
121
122 const auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
123 ? KD * KW * KH
124 : (id_end - id_start) * (ih_end - ih_start)
125 * (iw_end - iw_start);
126
127 const auto src_off
128 = IW * IH * ID * C * mb + IW * IH * ID * c + iw_start;
129
130 float d_val = 0;
131 for_(dim_t id = id_start; id < id_end; ++id)
132 for (dim_t ih = ih_start; ih < ih_end; ++ih) {
133 const auto src_off_loc = src_off + IW * IH * id + IW * ih;
134 const auto *src_loc = &src[src_off_loc];
135 for (dim_t iw = 0; iw < iw_end - iw_start; ++iw)
136 d_val += src_loc[iw];
137 }
138
139 return d_val / num_summands;
140 };
141
142 // Keep branches for post-ops since reference post-ops execution brings
143 // noticeable overhead.
144 const bool has_post_ops = pd()->attr()->post_ops_.len() > 0;
145
146 if (alg == alg_kind::pooling_max) {
147 if (has_post_ops) {
148 parallel_nd(MB, C, OD, OH, OW,
149 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
150 const size_t dst_offset = (size_t)OW * OH * OD * C * mb
151 + (size_t)OW * OH * OD * c
152 + (size_t)OW * OH * od + (size_t)OW * oh
153 + (size_t)ow;
154 data_t *d = &dst[dst_offset];
155 d[0] = numeric_limits<data_t>::lowest();
156 set_ws(mb, c, od, oh, ow, 0);
157 ker_max(d, mb, c, od, oh, ow);
158
159 ref_post_ops_t::args_t args;
160 args.ctx = &ctx;
161 args.l_offset = dst_offset;
162 args.dst_md = pd()->dst_md();
163 ref_post_ops_.execute(dst[dst_offset], args);
164 dst[dst_offset]
165 = saturate_and_round<data_t>(dst[dst_offset]);
166 });
167 } else {
168 parallel_nd(MB, C, OD, OH, OW,
169 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
170 const size_t dst_offset = (size_t)OW * OH * OD * C * mb
171 + (size_t)OW * OH * OD * c
172 + (size_t)OW * OH * od + (size_t)OW * oh
173 + (size_t)ow;
174 data_t *d = &dst[dst_offset];
175 d[0] = numeric_limits<data_t>::lowest();
176 set_ws(mb, c, od, oh, ow, 0);
177 ker_max(d, mb, c, od, oh, ow);
178
179 dst[dst_offset]
180 = saturate_and_round<data_t>(dst[dst_offset]);
181 });
182 }
183 } else {
184 if (has_post_ops) {
185 parallel_nd(MB, C, OD, OH, OW,
186 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
187 const size_t dst_offset = (size_t)OW * OH * OD * C * mb
188 + (size_t)OW * OH * OD * c
189 + (size_t)OW * OH * od + (size_t)OW * oh
190 + (size_t)ow;
191 data_t *d = &dst[dst_offset];
192 d[0] = 0;
193 auto res = ker_avg(d, mb, c, od, oh, ow);
194
195 ref_post_ops_t::args_t args;
196 args.ctx = &ctx;
197 args.l_offset = dst_offset;
198 args.dst_md = pd()->dst_md();
199 ref_post_ops_.execute(res, args);
200 d[0] = saturate_and_round<data_t>(res);
201 });
202 } else {
203 parallel_nd(MB, C, OD, OH, OW,
204 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
205 const size_t dst_offset = (size_t)OW * OH * OD * C * mb
206 + (size_t)OW * OH * OD * c
207 + (size_t)OW * OH * od + (size_t)OW * oh
208 + (size_t)ow;
209 data_t *d = &dst[dst_offset];
210 d[0] = 0;
211 auto res = ker_avg(d, mb, c, od, oh, ow);
212
213 d[0] = saturate_and_round<data_t>(res);
214 });
215 }
216 }
217
218 return status::success;
219}
220
221template <data_type_t d_type>
222status_t nchw_pooling_fwd_t<d_type>::execute_forward(
223 const exec_ctx_t &ctx) const {
224
225 auto alg = pd()->desc()->alg_kind;
226
227 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
228 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
229 auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
230
231 auto scratchpad = ctx.get_scratchpad_grantor();
232 float *cvt_wsp = scratchpad.template get<float>(
233 memory_tracking::names::key_pool_src_bf16cvt);
234
235 const memory_desc_wrapper ws_d(pd()->workspace_md());
236 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
237
238 const dim_t MB = pd()->MB();
239 const dim_t C = pd()->OC();
240 const dim_t OD = pd()->OD();
241 const dim_t OH = pd()->OH();
242 const dim_t OW = pd()->OW();
243 const dim_t ID = pd()->ID();
244 const dim_t IH = pd()->IH();
245 const dim_t IW = pd()->IW();
246 const dim_t KD = pd()->KD();
247 const dim_t KH = pd()->KH();
248 const dim_t KW = pd()->KW();
249 const dim_t SD = pd()->KSD();
250 const dim_t SH = pd()->KSH();
251 const dim_t SW = pd()->KSW();
252 const dim_t padF = pd()->padFront();
253 const dim_t padT = pd()->padT();
254 const dim_t padL = pd()->padL();
255
256 const size_t simd_w = 16;
257 const size_t src_size = MB * C * ID * IH * IW;
258 const size_t blocked_size = src_size / simd_w;
259 const size_t tail_size = src_size % simd_w;
260
261 auto apply_offset = [=](int index, int offset) {
262 return (index > offset) ? index - offset : 0;
263 };
264
265 auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow,
266 dim_t value) {
267 if (ws) {
268 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
269 size_t ws_offset = (size_t)OW * OH * OD * C * mb
270 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
271 + (size_t)OW * oh + (size_t)ow;
272 if (ws_dt == data_type::u8) {
273 assert(0 <= value
274 && value <= numeric_limits<typename prec_traits<
275 data_type::u8>::type>::max());
276 ws[ws_offset] = value;
277 } else
278 reinterpret_cast<int *>(ws)[ws_offset] = value;
279 }
280 };
281
282 auto ker_max = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
283 dim_t ow) {
284 const auto src_off = IW * IH * ID * C * mb + IW * IH * ID * c;
285 const auto *src_loc = &cvt_wsp[src_off];
286
287 for_(dim_t kd = 0; kd < KD; ++kd)
288 for_(dim_t kh = 0; kh < KH; ++kh)
289 for (dim_t kw = 0; kw < KW; ++kw) {
290 const dim_t id = od * SD - padF + kd;
291 if (id < 0 || id >= ID) continue;
292 const dim_t ih = oh * SH - padT + kh;
293 if (ih < 0 || ih >= IH) continue;
294 const dim_t iw = ow * SW - padL + kw;
295 if (iw < 0 || iw >= IW) continue;
296
297 const auto src_off_loc = IW * IH * id + IW * ih + iw;
298 const auto &s = src_loc[src_off_loc];
299 if (s > d[0]) {
300 d[0] = s;
301 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
302 }
303 }
304 };
305
306 auto ker_avg = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
307 dim_t ow) {
308 auto id_start = apply_offset(od * SD, padF);
309 auto ih_start = apply_offset(oh * SH, padT);
310 auto iw_start = apply_offset(ow * SW, padL);
311 auto id_end = min(od * SD - padF + KD, ID);
312 auto ih_end = min(oh * SH - padT + KH, IH);
313 auto iw_end = min(ow * SW - padL + KW, IW);
314
315 auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
316 ? KD * KW * KH
317 : (id_end - id_start) * (ih_end - ih_start)
318 * (iw_end - iw_start);
319
320 const auto src_off
321 = IW * IH * ID * C * mb + IW * IH * ID * c + iw_start;
322
323 for_(dim_t id = id_start; id < id_end; ++id)
324 for (dim_t ih = ih_start; ih < ih_end; ++ih) {
325 const auto src_off_loc = src_off + IW * IH * id + IW * ih;
326 const auto *src_loc = &cvt_wsp[src_off_loc];
327 for (dim_t iw = 0; iw < iw_end - iw_start; ++iw)
328 d[0] += src_loc[iw];
329 }
330
331 d[0] = out_round<float>((float)d[0] / num_summands);
332 };
333
334 parallel_nd(blocked_size, [&](size_t i) {
335 types::cvt_to_float(&cvt_wsp[i * simd_w], &src[i * simd_w], simd_w);
336 });
337 if (tail_size)
338 types::cvt_to_float(&cvt_wsp[blocked_size * simd_w],
339 &src[blocked_size * simd_w], tail_size);
340
341 // Keep branches for post-ops since reference post-ops execution brings
342 // noticeable overhead.
343 const bool has_post_ops = pd()->attr()->post_ops_.len() > 0;
344
345 if (alg == alg_kind::pooling_max) {
346 if (has_post_ops) {
347 parallel_nd(MB, C, OD, OH, OW,
348 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
349 size_t dst_offset = (size_t)OW * OH * OD * C * mb
350 + (size_t)OW * OH * OD * c
351 + (size_t)OW * OH * od + (size_t)OW * oh
352 + (size_t)ow;
353 float d_fp32 = numeric_limits<data_t>::lowest();
354
355 set_ws(mb, c, od, oh, ow, 0);
356
357 ker_max(&d_fp32, mb, c, od, oh, ow);
358
359 ref_post_ops_t::args_t args;
360 args.ctx = &ctx;
361 args.l_offset = dst_offset;
362 args.dst_md = pd()->dst_md();
363 ref_post_ops_.execute(d_fp32, args);
364
365 dst[dst_offset] = static_cast<data_t>(d_fp32);
366 });
367 } else {
368 parallel_nd(MB, C, OD, OH, OW,
369 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
370 size_t dst_offset = (size_t)OW * OH * OD * C * mb
371 + (size_t)OW * OH * OD * c
372 + (size_t)OW * OH * od + (size_t)OW * oh
373 + (size_t)ow;
374 float d_fp32 = numeric_limits<data_t>::lowest();
375
376 set_ws(mb, c, od, oh, ow, 0);
377
378 ker_max(&d_fp32, mb, c, od, oh, ow);
379
380 dst[dst_offset] = static_cast<data_t>(d_fp32);
381 });
382 }
383 } else {
384 if (has_post_ops) {
385 parallel_nd(MB, C, OD, OH, OW,
386 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
387 size_t dst_offset = (size_t)OW * OH * OD * C * mb
388 + (size_t)OW * OH * OD * c
389 + (size_t)OW * OH * od + (size_t)OW * oh
390 + (size_t)ow;
391 float d_fp32 = 0.0f;
392 ker_avg(&d_fp32, mb, c, od, oh, ow);
393 ref_post_ops_t::args_t args;
394 args.ctx = &ctx;
395 args.l_offset = dst_offset;
396 args.dst_md = pd()->dst_md();
397 ref_post_ops_.execute(d_fp32, args);
398 dst[dst_offset] = static_cast<data_t>(d_fp32);
399 });
400 } else {
401 parallel_nd(MB, C, OD, OH, OW,
402 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
403 size_t dst_offset = (size_t)OW * OH * OD * C * mb
404 + (size_t)OW * OH * OD * c
405 + (size_t)OW * OH * od + (size_t)OW * oh
406 + (size_t)ow;
407 float d_fp32 = 0.0f;
408 ker_avg(&d_fp32, mb, c, od, oh, ow);
409
410 dst[dst_offset] = static_cast<data_t>(d_fp32);
411 });
412 }
413 }
414
415 return status::success;
416}
417
418template <>
419status_t nchw_pooling_bwd_t<data_type::f32>::execute_backward(
420 const exec_ctx_t &ctx) const {
421 auto alg = pd()->desc()->alg_kind;
422 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
423 const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
424
425 auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
426 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
427 auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
428
429 const memory_desc_wrapper ws_d(pd()->workspace_md());
430
431 const dim_t MB = pd()->MB();
432 const dim_t C = pd()->OC();
433 const dim_t OD = pd()->OD();
434 const dim_t OH = pd()->OH();
435 const dim_t OW = pd()->OW();
436 const dim_t ID = pd()->ID();
437 const dim_t IH = pd()->IH();
438 const dim_t IW = pd()->IW();
439 const dim_t KD = pd()->KD();
440 const dim_t KH = pd()->KH();
441 const dim_t KW = pd()->KW();
442 const dim_t SD = pd()->KSD();
443 const dim_t SH = pd()->KSH();
444 const dim_t SW = pd()->KSW();
445 const dim_t padF = pd()->padFront();
446 const dim_t padT = pd()->padT();
447 const dim_t padL = pd()->padL();
448
449 auto apply_offset = [=](int index, int offset) {
450 return (index > offset) ? index - offset : 0;
451 };
452
453 auto ker_zero = [=](dim_t mb, dim_t c) {
454 size_t diff_src_offset
455 = (size_t)mb * C * ID * IH * IW + (size_t)c * ID * IH * IW;
456 for_(dim_t id = 0; id < ID; ++id)
457 for_(dim_t ih = 0; ih < IH; ++ih)
458 for (dim_t iw = 0; iw < IW; ++iw) {
459 diff_src[diff_src_offset++] = 0;
460 }
461 };
462
463 auto ker_max = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
464 dim_t ow) {
465 auto b_c = ws_d.blocking_desc().inner_nblks == 0
466 ? 1
467 : ws_d.blocking_desc().inner_blks[0];
468 auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
469 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
470 : ws_d.blk_off(mb, c / b_c, ow))
471 + c % b_c;
472
473 const int index = ws_d.data_type() == data_type::u8
474 ? (int)ws[ws_offset]
475 : ((const int *)ws)[ws_offset];
476 const dim_t kw = index % KW;
477 const dim_t kh = (index / KW) % KH;
478 const dim_t kd = (index / KW) / KH;
479
480 const dim_t id = od * SD - padF + kd;
481 const dim_t ih = oh * SH - padT + kh;
482 const dim_t iw = ow * SW - padL + kw;
483
484 // If padding area could fit the kernel,
485 // then input displacement would be out of bounds.
486 // No need to back propagate there as padding is
487 // virtual in pooling_max case.
488 if (id < 0 || id >= ID) return;
489 if (ih < 0 || ih >= IH) return;
490 if (iw < 0 || iw >= IW) return;
491
492 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
493 + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
494 + (size_t)ih * IW + (size_t)iw;
495 diff_src[diff_src_offset] += d[0];
496 };
497
498 auto ker_avg = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
499 dim_t ow) {
500 dim_t id_start = apply_offset(od * SD, padF);
501 dim_t ih_start = apply_offset(oh * SH, padT);
502 dim_t iw_start = apply_offset(ow * SW, padL);
503 dim_t id_end = min(od * SD - padF + KD, ID);
504 dim_t ih_end = min(oh * SH - padT + KH, IH);
505 dim_t iw_end = min(ow * SW - padL + KW, IW);
506
507 size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
508 ? (size_t)KW * KH * KD
509 : (size_t)(id_end - id_start) * (ih_end - ih_start)
510 * (iw_end - iw_start);
511
512 for_(dim_t id = id_start; id < id_end; ++id)
513 for_(dim_t ih = ih_start; ih < ih_end; ++ih)
514 for (dim_t iw = iw_start; iw < iw_end; ++iw) {
515 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
516 + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
517 + (size_t)ih * IW + (size_t)iw;
518 diff_src[diff_src_offset] += d[0] / num_summands;
519 }
520 };
521
522 dim_t ow_start = max(dim_t(0), utils::div_up(padL - KW + 1, SW));
523 dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW);
524
525 dim_t oh_start = max(dim_t(0), utils::div_up(padT - KH + 1, SH));
526 dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH);
527
528 dim_t od_start = max(dim_t(0), utils::div_up(padF - KD + 1, SD));
529 dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD);
530
531 if (alg == alg_kind::pooling_max) {
532 parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
533 size_t diff_dst_offset_b
534 = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
535 ker_zero(mb, c);
536 for_(dim_t od = od_start; od < od_end; ++od)
537 for (dim_t oh = oh_start; oh < oh_end; ++oh) {
538 size_t diff_dst_offset = diff_dst_offset_b
539 + (size_t)od * OH * OW + (size_t)oh * OW;
540 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
541 const data_t *d = &diff_dst[diff_dst_offset + ow];
542 ker_max(d, mb, c, od, oh, ow);
543 }
544 }
545 });
546 } else {
547 parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
548 size_t diff_dst_offset_b
549 = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
550 ker_zero(mb, c);
551 for_(dim_t od = od_start; od < od_end; ++od)
552 for (dim_t oh = oh_start; oh < oh_end; ++oh) {
553 size_t diff_dst_offset = diff_dst_offset_b
554 + (size_t)od * OH * OW + (size_t)oh * OW;
555 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
556 const data_t *d = &diff_dst[diff_dst_offset + ow];
557 ker_avg(d, mb, c, od, oh, ow);
558 }
559 }
560 });
561 }
562
563 return status::success;
564}
565
566template <data_type_t d_type>
567status_t nchw_pooling_bwd_t<d_type>::execute_backward(
568 const exec_ctx_t &ctx) const {
569
570 auto alg = pd()->desc()->alg_kind;
571 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
572 const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
573
574 auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
575 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
576 auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
577
578 auto scratchpad = ctx.get_scratchpad_grantor();
579 float *cvt_src = scratchpad.template get<float>(
580 memory_tracking::names::key_pool_src_bf16cvt);
581 float *cvt_dst = scratchpad.template get<float>(
582 memory_tracking::names::key_pool_dst_bf16cvt);
583
584 const memory_desc_wrapper ws_d(pd()->workspace_md());
585
586 const dim_t MB = pd()->MB();
587 const dim_t C = pd()->OC();
588 const dim_t OD = pd()->OD();
589 const dim_t OH = pd()->OH();
590 const dim_t OW = pd()->OW();
591 const dim_t ID = pd()->ID();
592 const dim_t IH = pd()->IH();
593 const dim_t IW = pd()->IW();
594 const dim_t KD = pd()->KD();
595 const dim_t KH = pd()->KH();
596 const dim_t KW = pd()->KW();
597 const dim_t SD = pd()->KSD();
598 const dim_t SH = pd()->KSH();
599 const dim_t SW = pd()->KSW();
600 const dim_t padF = pd()->padFront();
601 const dim_t padT = pd()->padT();
602 const dim_t padL = pd()->padL();
603
604 const size_t dst_sp_size = pd()->OD() * pd()->OH() * pd()->OW();
605 const size_t src_sp_size = pd()->ID() * pd()->IH() * pd()->IW();
606
607 auto apply_offset = [=](int index, int offset) {
608 return (index > offset) ? index - offset : 0;
609 };
610
611 auto ker_zero = [=](float *diff_src, dim_t c_block_size) {
612 size_t diff_src_offset = 0;
613 for_(dim_t c = 0; c < c_block_size; ++c)
614 for_(dim_t id = 0; id < ID; ++id)
615 for_(dim_t ih = 0; ih < IH; ++ih)
616 for (dim_t iw = 0; iw < IW; ++iw) {
617 diff_src[diff_src_offset++] = 0.0f;
618 }
619 };
620
621 auto ker_max = [=](const float *d, float *diff_src, dim_t mb, dim_t c,
622 dim_t od, dim_t oh, dim_t ow) {
623 auto b_c = ws_d.blocking_desc().inner_nblks == 0
624 ? 1
625 : ws_d.blocking_desc().inner_blks[0];
626 auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
627 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
628 : ws_d.blk_off(mb, c / b_c, ow))
629 + c % b_c;
630
631 const int index = ws_d.data_type() == data_type::u8
632 ? (int)ws[ws_offset]
633 : ((const int *)ws)[ws_offset];
634 const dim_t kw = index % KW;
635 const dim_t kh = (index / KW) % KH;
636 const dim_t kd = (index / KW) / KH;
637
638 const dim_t id = od * SD - padF + kd;
639 const dim_t ih = oh * SH - padT + kh;
640 const dim_t iw = ow * SW - padL + kw;
641
642 // If padding area could fit the kernel,
643 // then input displacement would be out of bounds.
644 // No need to back propagate there as padding is
645 // virtual in pooling_max case.
646 if (id < 0 || id >= ID) return;
647 if (ih < 0 || ih >= IH) return;
648 if (iw < 0 || iw >= IW) return;
649
650 size_t diff_src_offset
651 = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
652 diff_src[diff_src_offset] += d[0];
653 };
654
655 auto ker_avg = [=](const float *d, float *diff_src, dim_t mb, dim_t c,
656 dim_t od, dim_t oh, dim_t ow) {
657 auto id_start = apply_offset(od * SD, padF);
658 auto ih_start = apply_offset(oh * SH, padT);
659 auto iw_start = apply_offset(ow * SW, padL);
660 auto id_end = min(od * SD - padF + KD, ID);
661 auto ih_end = min(oh * SH - padT + KH, IH);
662 auto iw_end = min(ow * SW - padL + KW, IW);
663
664 size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
665 ? (size_t)KW * KH * KD
666 : (size_t)(id_end - id_start) * (ih_end - ih_start)
667 * (iw_end - iw_start);
668
669 for_(dim_t id = id_start; id < id_end; ++id)
670 for_(dim_t ih = ih_start; ih < ih_end; ++ih)
671 for (dim_t iw = iw_start; iw < iw_end; ++iw) {
672 size_t diff_src_offset
673 = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
674 diff_src[diff_src_offset] += d[0] / num_summands;
675 }
676 };
677
678 dim_t ow_start = max(dim_t(0), utils::div_up(padL - KW + 1, SW));
679 dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW);
680
681 dim_t oh_start = max(dim_t(0), utils::div_up(padT - KH + 1, SH));
682 dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH);
683
684 dim_t od_start = max(dim_t(0), utils::div_up(padF - KD + 1, SD));
685 dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD);
686
687 dim_t c_blk = pd()->channel_block_size_;
688 dim_t c_blk_tail = C % c_blk;
689 const int nthr = pd()->nthr_;
690
691 if (alg == alg_kind::pooling_max) {
692 parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
693 [&](int ithr, int, dim_t mb, dim_t cb) {
694 bool is_last_c_block
695 = c_blk_tail > 0 && (cb + 1) * c_blk > C;
696 dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
697 size_t diff_dst_offset_b
698 = ((size_t)mb * C + (size_t)cb * c_blk) * OD * OH
699 * OW;
700 size_t diff_src_offset
701 = ((size_t)mb * C + (size_t)cb * c_blk) * ID * IH
702 * IW;
703 float *diff_dst_fp32 = &cvt_dst[ithr * dst_sp_size * c_blk];
704 float *diff_src_fp32 = &cvt_src[ithr * src_sp_size * c_blk];
705
706 ker_zero(diff_src_fp32, curr_c_block);
707
708 types::cvt_to_float(diff_dst_fp32,
709 &diff_dst[diff_dst_offset_b],
710 dst_sp_size * curr_c_block);
711
712 for_(dim_t c = 0; c < curr_c_block; ++c)
713 for_(dim_t od = od_start; od < od_end; ++od)
714 for (dim_t oh = oh_start; oh < oh_end; ++oh) {
715 size_t diff_dst_offset = (size_t)c * OD * OH * OW
716 + (size_t)od * OH * OW + (size_t)oh * OW;
717 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
718 const float *d
719 = &diff_dst_fp32[diff_dst_offset + ow];
720 ker_max(d, &diff_src_fp32[c * ID * IH * IW], mb,
721 cb * c_blk + c, od, oh, ow);
722 }
723 }
724 types::cvt_from_float(&diff_src[diff_src_offset],
725 diff_src_fp32, src_sp_size * curr_c_block);
726 });
727 } else {
728 parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
729 [&](int ithr, int, dim_t mb, dim_t cb) {
730 bool is_last_c_block
731 = c_blk_tail > 0 && (cb + 1) * c_blk > C;
732 dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
733 size_t diff_dst_offset_b = (size_t)mb * C * OD * OH * OW
734 + (size_t)cb * c_blk * OD * OH * OW;
735 float *diff_dst_fp32 = &cvt_dst[ithr * dst_sp_size * c_blk];
736 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
737 + (size_t)cb * c_blk * ID * IH * IW;
738 float *diff_src_fp32 = &cvt_src[ithr * src_sp_size * c_blk];
739
740 ker_zero(diff_src_fp32, curr_c_block);
741
742 types::cvt_to_float(diff_dst_fp32,
743 &diff_dst[diff_dst_offset_b],
744 dst_sp_size * curr_c_block);
745 for_(dim_t c = 0; c < curr_c_block; ++c)
746 for_(dim_t od = od_start; od < od_end; ++od)
747 for (dim_t oh = oh_start; oh < oh_end; ++oh) {
748 size_t diff_dst_offset = (size_t)c * OD * OH * OW
749 + (size_t)od * OH * OW + (size_t)oh * OW;
750 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
751 const float *d
752 = &diff_dst_fp32[diff_dst_offset + ow];
753 ker_avg(d, &diff_src_fp32[c * ID * IH * IW], mb,
754 cb * c_blk + c, od, oh, ow);
755 }
756 }
757 types::cvt_from_float(&diff_src[diff_src_offset],
758 diff_src_fp32, src_sp_size * curr_c_block);
759 });
760 }
761
762 return status::success;
763}
764template struct nchw_pooling_fwd_t<data_type::f32>;
765template struct nchw_pooling_bwd_t<data_type::f32>;
766template struct nchw_pooling_fwd_t<data_type::bf16>;
767template struct nchw_pooling_bwd_t<data_type::bf16>;
768template struct nchw_pooling_fwd_t<data_type::f16>;
769template struct nchw_pooling_bwd_t<data_type::f16>;
770} // namespace cpu
771} // namespace impl
772} // namespace dnnl
773
774// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
775