/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/compiler_workarounds.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"

#include "cpu/simple_q10n.hpp"

#include "cpu/nhwc_pooling.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

// Intel's LLVM-based compiler on Windows generates incorrect code with
// PRAGMA_OMP_SIMD in some particular cases.
// TODO: This issue appears to be separate from the one covered by
// `CLANG_WA_01_SAFE_TO_USE_OMP_SIMD`. Once the latter is resolved, re-check
// whether this one goes away as well.
#if ((defined _WIN32) && (defined __INTEL_CLANG_COMPILER))
#define SAFE_TO_USE_OMP_SIMD (0 && CLANG_WA_01_SAFE_TO_USE_OMP_SIMD)
#else
#define SAFE_TO_USE_OMP_SIMD (1 && CLANG_WA_01_SAFE_TO_USE_OMP_SIMD)
#endif
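
// MEM_D(name) names the memory_desc_wrapper declared as `name##_d` in the
// kernels below. DECLARE_READ_STRIDES(name) pulls the n/d/h/w strides out of
// that wrapper, substituting a 0 stride for spatial dimensions the problem
// does not have (depth unless 3D, height for 1D) so those terms vanish in
// strided_offset().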
#define MEM_D(name) name##_d

#define DECLARE_READ_STRIDES(name) \
    const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \
    const size_t name##_d_stride \
            = is_3d ? MEM_D(name).blocking_desc().strides[ndims - 3] : 0; \
    const size_t name##_h_stride \
            = is_1d ? 0 : MEM_D(name).blocking_desc().strides[ndims - 2]; \
    const size_t name##_w_stride \
            = MEM_D(name).blocking_desc().strides[ndims - 1];
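
// For a dense NHWC 4D tensor, for example, blocking_desc().strides is
// {H * W * C, 1, W * C, C}: the per-channel stride is 1, which is why the
// kernels below can address channel `oc` simply as `base_offset + oc`.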

namespace nhwc_pooling {
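// Physical offset of element (n, d, h, w) given the per-dimension strides;
// the channel term is added by the callers.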
size_t strided_offset(const int _n, const size_t _sn, const int _d,
        const size_t _sd, const int _h, const size_t _sh, const int _w,
        const size_t _sw) {
    return _n * _sn + _d * _sd + _h * _sh + _w * _sw;
}
} // namespace nhwc_pooling

template <data_type_t d_type>
nhwc_pooling_fwd_t<d_type>::nhwc_pooling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}

template <data_type_t d_type>
void nhwc_pooling_fwd_t<d_type>::array_div_by_const(const int n,
        const ker_data_t *src, const size_t num, ker_data_t *dst) const {
    for (int i = 0; i < n; ++i) {
        const float ftmp = ((float)src[i]) / num;
        dst[i] = out_round<ker_data_t>(ftmp);
    }
}

template <data_type_t d_type>
void nhwc_pooling_fwd_t<d_type>::array_add(
        const int n, const ker_data_t *src, ker_data_t *dst) const {
    for (int i = 0; i < n; ++i) {
        dst[i] += src[i];
    }
}

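// Per-channel max update: dst[oc] = max(dst[oc], src[oc]). The workspace
// records `index` (callers pass the linear kernel position
// kd * KH * KW + kh * KW + kw) wherever a new maximum is found, stored as u8
// when the kernel volume fits into a byte and as s32 otherwise.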
template <data_type_t d_type>
void nhwc_pooling_fwd_t<d_type>::array_nhwc_max(const int n, ker_data_t *dst,
        const ker_data_t *src, unsigned char *ws, const size_t ws_offset,
        const data_type_t ws_dt, const int index) const {
    assert(ws);
#if SAFE_TO_USE_OMP_SIMD
    PRAGMA_OMP_SIMD()
#endif
    for (int oc = 0; oc < n; ++oc) {
        const auto s = src[oc];
        ker_data_t mv = dst[oc];

        // update index of maximum
#if defined __INTEL_COMPILER
        if (s > mv) {
            assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
            if (ws_dt == data_type::u8) {
                assert(0 <= index && index <= 255);
                ws[ws_offset + oc] = index;
            } else
                reinterpret_cast<int *>(ws)[ws_offset + oc] = index;
        }
#else
        // GCC needs explicit branchless predicates to vectorize this loop.
        // The resulting code is ugly, but it is still about 4x faster than
        // scalar.
        assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);

        if (ws_dt == data_type::u8) {
            assert(0 <= index && index <= 255);
            const unsigned char predicate = (s > mv) ? 0xff : 0;
            unsigned char current_value = ws[ws_offset + oc];
            current_value = (predicate & (unsigned char)index)
                    | ((~predicate) & current_value);
            ws[ws_offset + oc] = current_value;
        } else {
            auto wint = reinterpret_cast<int *>(ws);
            const unsigned int predicate = (s > mv) ? 0xffffffff : 0;
            unsigned int current_value = wint[ws_offset + oc];
            current_value = (predicate & (unsigned int)index)
                    | ((~predicate) & current_value);
            wint[ws_offset + oc] = current_value;
        }
#endif
        // update maximum
        dst[oc] = nstl::max(s, mv);
    }
}

template <data_type_t d_type>
void nhwc_pooling_fwd_t<d_type>::array_nhwc_initialize(const int n,
        ker_data_t *dst, unsigned char *ws, const size_t ws_offset,
        const data_type_t ws_dt) const {
    assert(ws && (ws_dt == data_type::u8 || ws_dt == data_type::s32));
#if SAFE_TO_USE_OMP_SIMD
    PRAGMA_OMP_SIMD()
#endif
    for (int oc = 0; oc < n; ++oc) {
        if (ws_dt == data_type::u8)
            ws[ws_offset + oc] = 0;
        else
            reinterpret_cast<int *>(ws)[ws_offset + oc] = 0;
        dst[oc] = nstl::numeric_limits<data_t>::lowest();
    }
}

using namespace nstl;
using namespace nhwc_pooling;

template <>
status_t nhwc_pooling_fwd_t<data_type::f32>::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto alg = pd()->desc()->alg_kind;

    const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);

    const memory_desc_wrapper MEM_D(src)(pd()->src_md());
    const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();
    const dim_t KD = pd()->KD();
    const dim_t KH = pd()->KH();
    const dim_t KW = pd()->KW();
    const dim_t SD = pd()->KSD();
    const dim_t SH = pd()->KSH();
    const dim_t SW = pd()->KSW();
    const dim_t padF = pd()->padFront();
    const dim_t padT = pd()->padT();
    const dim_t padL = pd()->padL();

    const bool is_1d = pd()->desc()->src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    DECLARE_READ_STRIDES(src);
    DECLARE_READ_STRIDES(dst);

    const auto apply_offset = [](dim_t index, dim_t offset) {
        return (index > offset) ? index - offset : 0;
    };

    const dim_t SP = OW * OH;
    const dim_t OSP = SP * OD;

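    // Logical (dense, NCDHW-style) linear offset of (mb, oc, od, oh, ow);
    // ref_post_ops_t works with logical offsets, and stepping `oc` by one
    // advances the offset by OSP.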
    const auto get_logical_offset
            = [&](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) -> dim_t {
        return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
    };
    const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());

    parallel_nd(MB, OD, OH, OW, [&](dim_t mb, dim_t od, dim_t oh, dim_t ow) {
        const size_t dst_offset_init = strided_offset(mb, dst_n_stride, od,
                dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
        if (alg == alg_kind::pooling_max) {
            size_t ws_offset_init = 0;
            if (ws) {
                DECLARE_READ_STRIDES(ws);
                ws_offset_init = strided_offset(mb, ws_n_stride, od,
                        ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
            }
            // Note: GCC 4.8.5 won't vectorize the simple loops below unless
            // they are split out into separate helper routines:
            // array_nhwc_initialize, array_nhwc_max
            if (!ws) {
                auto *const d = dst + dst_offset_init;
                PRAGMA_OMP_SIMD()
                for (dim_t oc = 0; oc < OC; ++oc) {
                    d[oc] = nstl::numeric_limits<data_t>::lowest();
                }
            } else {
                array_nhwc_initialize(
                        OC, dst + dst_offset_init, ws, ws_offset_init, ws_dt);
            }

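            // Walk the kernel window; input coordinates that fall into the
            // padding (negative, or past the input extent) are skipped, so
            // only valid input points take part in the max.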
            for_(dim_t kd = 0; kd < KD; ++kd)
            for_(dim_t kh = 0; kh < KH; ++kh)
            for (dim_t kw = 0; kw < KW; ++kw) {
                const dim_t id = od * SD - padF + kd;
                const dim_t ih = oh * SH - padT + kh;
                const dim_t iw = ow * SW - padL + kw;

                if (id < 0 || id >= ID) continue;
                if (ih < 0 || ih >= IH) continue;
                if (iw < 0 || iw >= IW) continue;

                const size_t src_offset_init = strided_offset(mb, src_n_stride,
                        id, src_d_stride, ih, src_h_stride, iw, src_w_stride);

                if (!ws) {
                    auto *const s = src + src_offset_init;
                    auto *const d = dst + dst_offset_init;
                    PRAGMA_OMP_SIMD()
                    for (dim_t oc = 0; oc < OC; ++oc) {
                        d[oc] = nstl::max(s[oc], d[oc]);
                    }
                } else {
                    array_nhwc_max(OC, dst + dst_offset_init,
                            src + src_offset_init, ws, ws_offset_init, ws_dt,
                            kd * KH * KW + kh * KW + kw);
                }
            }
        } else {
            // pooling_avg
            const auto d = dst + dst_offset_init;

            utils::array_set(d, 0, OC);

            const auto id_start = apply_offset(od * SD, padF);
            const auto ih_start = apply_offset(oh * SH, padT);
            const auto iw_start = apply_offset(ow * SW, padL);
            const auto id_end = min(od * SD - padF + KD, ID);
            const auto ih_end = min(oh * SH - padT + KH, IH);
            const auto iw_end = min(ow * SW - padL + KW, IW);

            // It is cheaper to count the summands in the loop, since the
            // typical kernel is small.
            size_t num_summands = 0;

            for_(dim_t id = id_start; id < id_end; ++id)
            for_(dim_t ih = ih_start; ih < ih_end; ++ih)
            for (dim_t iw = iw_start; iw < iw_end; ++iw) {
                const size_t src_offset_init = strided_offset(mb, src_n_stride,
                        id, src_d_stride, ih, src_h_stride, iw, src_w_stride);
                const auto s = src + src_offset_init;

                // The loop lives in a separate function (array_add) so that
                // GCC 4.8.5 vectorizes it.
                array_add(OC, s, d);

                num_summands++;
            }

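            // pooling_avg_include_padding divides by the full kernel volume;
            // pooling_avg_exclude_padding divides by the number of summands
            // that actually fall inside the input.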
            num_summands = (alg == alg_kind::pooling_avg_include_padding)
                    ? KW * KH * KD
                    : num_summands;

            // The loop lives in a separate function (array_div_by_const) so
            // that GCC 4.8.5 vectorizes it.
            array_div_by_const(OC, d, num_summands, d);
        }

        if (are_postops_set) {
            auto *const d = dst + dst_offset_init;
            ref_post_ops_t::args_t args;
            args.ctx = &ctx;
            args.l_offset = get_logical_offset(mb, 0, od, oh, ow);
            args.dst_md = pd()->dst_md();

            for (dim_t oc = 0; oc < OC; ++oc) {
                ref_post_ops_.execute(d[oc], args);
                args.l_offset += OSP;
            }
        }
    });
    return status::success;
}

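// Generic execute_forward, used for bf16 and f16: the source row is converted
// to f32 into per-thread scratchpad buffers, pooling and post-ops are computed
// in f32, and the result is converted back to the destination data type at
// the end.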
template <data_type_t d_type>
status_t nhwc_pooling_fwd_t<d_type>::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto alg = pd()->desc()->alg_kind;

    const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);

    auto scratchpad = ctx.get_scratchpad_grantor();
    float *const cvt_src_wsp = scratchpad.template get<float>(
            memory_tracking::names::key_pool_src_bf16cvt);
    float *const cvt_dst_wsp = scratchpad.template get<float>(
            memory_tracking::names::key_pool_dst_bf16cvt);

    const memory_desc_wrapper MEM_D(src)(pd()->src_md());
    const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();
    const dim_t KD = pd()->KD();
    const dim_t KH = pd()->KH();
    const dim_t KW = pd()->KW();
    const dim_t SD = pd()->KSD();
    const dim_t SH = pd()->KSH();
    const dim_t SW = pd()->KSW();
    const dim_t padF = pd()->padFront();
    const dim_t padT = pd()->padT();
    const dim_t padL = pd()->padL();

    const bool is_1d = pd()->desc()->src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    DECLARE_READ_STRIDES(src);
    DECLARE_READ_STRIDES(dst);

    const auto apply_offset = [&](dim_t index, dim_t offset) {
        return (index > offset) ? index - offset : 0;
    };

    const dim_t SP = OW * OH;
    const dim_t OSP = SP * OD;

    const auto get_logical_offset
            = [&](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) -> dim_t {
        return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
    };
    const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());
    const int nthr = pd()->nthr_;

    parallel_nd_ext(nthr, MB, OD, OH, OW,
            [&](int ithr, int, dim_t mb, dim_t od, dim_t oh, dim_t ow) {
                const size_t dst_offset_init = strided_offset(mb, dst_n_stride,
                        od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
                float *const dst_f32 = &cvt_dst_wsp[ithr * OC];
                float *const src_f32 = &cvt_src_wsp[ithr * OC];

                if (alg == alg_kind::pooling_max) {
                    size_t ws_offset_init = 0;
                    if (ws) {
                        DECLARE_READ_STRIDES(ws);
                        ws_offset_init = strided_offset(mb, ws_n_stride, od,
                                ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
                    }
                    // Note: GCC 4.8.5 won't vectorize the simple loops below
                    // unless they are split out into separate helper routines:
                    // array_nhwc_initialize, array_nhwc_max
                    if (!ws) {
                        PRAGMA_OMP_SIMD()
                        for (dim_t oc = 0; oc < OC; ++oc) {
                            dst_f32[oc]
                                    = nstl::numeric_limits<data_t>::lowest();
                        }
                    } else {
                        array_nhwc_initialize(
                                OC, dst_f32, ws, ws_offset_init, ws_dt);
                    }

                    for_(dim_t kd = 0; kd < KD; ++kd)
                    for_(dim_t kh = 0; kh < KH; ++kh)
                    for (dim_t kw = 0; kw < KW; ++kw) {
                        const dim_t id = od * SD - padF + kd;
                        const dim_t ih = oh * SH - padT + kh;
                        const dim_t iw = ow * SW - padL + kw;

                        if (id < 0 || id >= ID) continue;
                        if (ih < 0 || ih >= IH) continue;
                        if (iw < 0 || iw >= IW) continue;

                        const size_t src_offset_init = strided_offset(mb,
                                src_n_stride, id, src_d_stride, ih,
                                src_h_stride, iw, src_w_stride);

                        types::cvt_to_float(src_f32, &src[src_offset_init], OC);

                        if (!ws) {
                            PRAGMA_OMP_SIMD()
                            for (dim_t oc = 0; oc < OC; ++oc) {
                                dst_f32[oc]
                                        = nstl::max(src_f32[oc], dst_f32[oc]);
                            }
                        } else {
                            array_nhwc_max(OC, dst_f32, src_f32, ws,
                                    ws_offset_init, ws_dt,
                                    kd * KH * KW + kh * KW + kw);
                        }
                    }
                } else {
                    // pooling_avg
                    utils::array_set(dst_f32, 0, OC);

                    const auto id_start = apply_offset(od * SD, padF);
                    const auto ih_start = apply_offset(oh * SH, padT);
                    const auto iw_start = apply_offset(ow * SW, padL);
                    const auto id_end = min(od * SD - padF + KD, ID);
                    const auto ih_end = min(oh * SH - padT + KH, IH);
                    const auto iw_end = min(ow * SW - padL + KW, IW);

                    // It is cheaper to count the summands in the loop, since
                    // the typical kernel is small.
                    size_t num_summands = 0;

                    for_(dim_t id = id_start; id < id_end; ++id)
                    for_(dim_t ih = ih_start; ih < ih_end; ++ih)
                    for (dim_t iw = iw_start; iw < iw_end; ++iw) {
                        size_t src_offset_init = strided_offset(mb,
                                src_n_stride, id, src_d_stride, ih,
                                src_h_stride, iw, src_w_stride);
                        types::cvt_to_float(src_f32, &src[src_offset_init], OC);

                        // The loop lives in a separate function (array_add) so
                        // that GCC 4.8.5 vectorizes it.
                        array_add(OC, src_f32, dst_f32);
                        num_summands++;
                    }

                    num_summands
                            = (alg == alg_kind::pooling_avg_include_padding)
                            ? KW * KH * KD
                            : num_summands;

                    // The loop lives in a separate function
                    // (array_div_by_const) so that GCC 4.8.5 vectorizes it.
                    array_div_by_const(OC, dst_f32, num_summands, dst_f32);
                }

                if (are_postops_set) {
                    ref_post_ops_t::args_t args;
                    args.ctx = &ctx;
                    args.l_offset = get_logical_offset(mb, 0, od, oh, ow);
                    args.dst_md = pd()->dst_md();

                    for (dim_t oc = 0; oc < OC; ++oc) {
                        ref_post_ops_.execute(dst_f32[oc], args);
                        args.l_offset += OSP;
                    }
                }
                types::cvt_from_float(dst + dst_offset_init, dst_f32, OC);
            });
    return status::success;
}

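// Backward pooling iterates over *input* points and gathers contributions
// from every output point whose kernel window covers them; each thread owns
// its (mb, id, ih, iw) slice of diff_src, so no write conflicts arise.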
template <>
status_t nhwc_pooling_bwd_t<data_type::f32>::execute_backward(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md());
    const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();
    const dim_t KD = pd()->KD();
    const dim_t KH = pd()->KH();
    const dim_t KW = pd()->KW();
    const dim_t SD = pd()->KSD();
    const dim_t SH = pd()->KSH();
    const dim_t SW = pd()->KSW();
    const dim_t padF = pd()->padFront();
    const dim_t padT = pd()->padT();
    const dim_t padL = pd()->padL();

    const bool is_1d = pd()->desc()->diff_src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    auto alg = pd()->desc()->alg_kind;

    DECLARE_READ_STRIDES(diff_src);
    DECLARE_READ_STRIDES(diff_dst);

    auto apply_offset = [=](dim_t index, dim_t offset) {
        return (index > offset) ? index - offset : 0;
    };

    parallel_nd(MB, ID, IH, IW, [&](dim_t mb, dim_t id, dim_t ih, dim_t iw) {
        size_t src_offset_init
                = strided_offset(mb, diff_src_n_stride, id, diff_src_d_stride,
                        ih, diff_src_h_stride, iw, diff_src_w_stride);

        for (dim_t oc = 0; oc < OC; ++oc)
            diff_src[src_offset_init + oc] = data_t(0);

        // Find out which output cells may correspond to the current input
        // position. The current input position divided by the stride, with
        // integer division rounding down, gives the right-most output. The
        // left-most output is obtained by decrementing the input position by
        // (kernel_size - 1) before the same division by the stride.
        dim_t od_left = max((id + padF - KD + 1) / SD, dim_t(0));
        dim_t oh_left = max((ih + padT - KH + 1) / SH, dim_t(0));
        dim_t ow_left = max((iw + padL - KW + 1) / SW, dim_t(0));
        // Notice the +1 here, which preserves the "less than" condition of
        // the C for loops below.
        dim_t od_right = min((id + padF) / SD + 1, OD);
        dim_t oh_right = min((ih + padT) / SH + 1, OH);
        dim_t ow_right = min((iw + padL) / SW + 1, OW);

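        // For example, with SW = 2, KW = 3, padL = 0 and iw = 5:
        // ow_left = (5 - 3 + 1) / 2 = 1 and ow_right = 5 / 2 + 1 = 3, so
        // ow in {1, 2} is visited; ow = 1 is then rejected by the kw bound
        // check below (kw = 5 - 2 = 3 >= KW), leaving ow = 2, whose window
        // [4, 6] indeed covers iw = 5.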
        for_(dim_t od = od_left; od < od_right; ++od)
        for_(dim_t oh = oh_left; oh < oh_right; ++oh)
        for (dim_t ow = ow_left; ow < ow_right; ++ow) {
            const dim_t kd = id - od * SD + padF;
            const dim_t kh = ih - oh * SH + padT;
            const dim_t kw = iw - ow * SW + padL;

            if (kd < 0 || kd >= KD) continue;
            if (kh < 0 || kh >= KH) continue;
            if (kw < 0 || kw >= KW) continue;

            size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, od,
                    diff_dst_d_stride, oh, diff_dst_h_stride, ow,
                    diff_dst_w_stride);

            if (alg == alg_kind::pooling_max) {
                DECLARE_READ_STRIDES(ws);
                size_t ws_offset_init = strided_offset(mb, ws_n_stride, od,
                        ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
                const dim_t index = kd * KH * KW + kh * KW + kw;
                const unsigned char *ws_ = ws + ws_offset_init;
                const int *intws_ = (int *)ws + ws_offset_init;
                const bool ws_is_u8 = MEM_D(ws).data_type() == data_type::u8;

#if SAFE_TO_USE_OMP_SIMD
                PRAGMA_OMP_SIMD()
#endif
                for (dim_t oc = 0; oc < OC; ++oc) {
                    const int index_from_ws = ws_is_u8 ? ws_[oc] : intws_[oc];
                    const data_t d = diff_dst[dst_offset_init + oc];

                    // If kernel windows are disjoint (kernel == stride), each
                    // input element belongs to exactly one window, so a plain
                    // write suffices; otherwise accumulate into the contents.
                    auto value = (index_from_ws == index) ? d : data_t(0);
                    if (!(KD == SD && KH == SH && KW == SW))
                        diff_src[src_offset_init + oc] += value;
                    else
                        diff_src[src_offset_init + oc] = value;
                }
            } else {
                // pooling_avg
                auto id_start = apply_offset(od * SD, padF);
                auto ih_start = apply_offset(oh * SH, padT);
                auto iw_start = apply_offset(ow * SW, padL);
                auto id_end = min(od * SD - padF + KD, ID);
                auto ih_end = min(oh * SH - padT + KH, IH);
                auto iw_end = min(ow * SW - padL + KW, IW);

                auto num_summands
                        = (alg == alg_kind::pooling_avg_include_padding)
                        ? KW * KH * KD
                        : (ih_end - ih_start) * (iw_end - iw_start)
                                * (id_end - id_start);

                PRAGMA_OMP_SIMD()
                for (dim_t oc = 0; oc < OC; ++oc) {
                    const data_t d = diff_dst[dst_offset_init + oc];
                    // If kernel windows are disjoint (kernel == stride), each
                    // input element belongs to exactly one window, so a plain
                    // write suffices; otherwise accumulate into the contents.
                    if (!(KD == SD && KH == SH && KW == SW))
                        diff_src[src_offset_init + oc] += d / num_summands;
                    else
                        diff_src[src_offset_init + oc] = d / num_summands;
                }
            }
        }
    });
    return status::success;
}

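// Generic execute_backward, used for bf16 and f16: diff_dst is converted to
// f32 into per-thread scratchpad buffers, accumulation happens in f32, and
// the result is converted back to the diff_src data type once per input
// point.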
template <data_type_t d_type>
status_t nhwc_pooling_bwd_t<d_type>::execute_backward(
        const exec_ctx_t &ctx) const {

    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    auto scratchpad = ctx.get_scratchpad_grantor();
    float *cvt_dsrc = scratchpad.template get<float>(
            memory_tracking::names::key_pool_src_bf16cvt);
    float *cvt_ddst = scratchpad.template get<float>(
            memory_tracking::names::key_pool_dst_bf16cvt);

    const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md());
    const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();
    const dim_t KD = pd()->KD();
    const dim_t KH = pd()->KH();
    const dim_t KW = pd()->KW();
    const dim_t SD = pd()->KSD();
    const dim_t SH = pd()->KSH();
    const dim_t SW = pd()->KSW();
    const dim_t padF = pd()->padFront();
    const dim_t padT = pd()->padT();
    const dim_t padL = pd()->padL();

    const bool is_1d = pd()->desc()->diff_src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    auto alg = pd()->desc()->alg_kind;

    DECLARE_READ_STRIDES(diff_src);
    DECLARE_READ_STRIDES(diff_dst);

    auto apply_offset = [=](dim_t index, dim_t offset) {
        return (index > offset) ? index - offset : 0;
    };
    const int nthr = pd()->nthr_;

    parallel_nd_ext(nthr, MB, ID, IH, IW,
            [&](int ithr, int, dim_t mb, dim_t id, dim_t ih, dim_t iw) {
                size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
                        id, diff_src_d_stride, ih, diff_src_h_stride, iw,
                        diff_src_w_stride);

                float *diff_dst_fp32 = &cvt_ddst[ithr * OC];
                float *diff_src_fp32 = &cvt_dsrc[ithr * OC];

                for (dim_t oc = 0; oc < OC; ++oc) {
                    diff_src_fp32[oc] = 0.f;
                    diff_src[src_offset_init + oc] = data_t(0.f);
                }

                // Find out which output cells may correspond to the current
                // input position. The current input position divided by the
                // stride, with integer division rounding down, gives the
                // right-most output. The left-most output is obtained by
                // decrementing the input position by (kernel_size - 1) before
                // the same division by the stride.
                dim_t od_left = max((id + padF - KD + 1) / SD, dim_t(0));
                dim_t oh_left = max((ih + padT - KH + 1) / SH, dim_t(0));
                dim_t ow_left = max((iw + padL - KW + 1) / SW, dim_t(0));
                // Notice the +1 here, which preserves the "less than"
                // condition of the C for loops below.
                dim_t od_right = min((id + padF) / SD + 1, OD);
                dim_t oh_right = min((ih + padT) / SH + 1, OH);
                dim_t ow_right = min((iw + padL) / SW + 1, OW);

                for_(dim_t od = od_left; od < od_right; ++od)
                for_(dim_t oh = oh_left; oh < oh_right; ++oh)
                for (dim_t ow = ow_left; ow < ow_right; ++ow) {
                    const dim_t kd = id - od * SD + padF;
                    const dim_t kh = ih - oh * SH + padT;
                    const dim_t kw = iw - ow * SW + padL;

                    if (kd < 0 || kd >= KD) continue;
                    if (kh < 0 || kh >= KH) continue;
                    if (kw < 0 || kw >= KW) continue;

                    size_t dst_offset_init = strided_offset(mb,
                            diff_dst_n_stride, od, diff_dst_d_stride, oh,
                            diff_dst_h_stride, ow, diff_dst_w_stride);
                    types::cvt_to_float(
                            diff_dst_fp32, &diff_dst[dst_offset_init], OC);

                    if (alg == alg_kind::pooling_max) {
                        DECLARE_READ_STRIDES(ws);
                        size_t ws_offset_init = strided_offset(mb, ws_n_stride,
                                od, ws_d_stride, oh, ws_h_stride, ow,
                                ws_w_stride);
                        const dim_t index = kd * KH * KW + kh * KW + kw;
                        const unsigned char *ws_ = ws + ws_offset_init;
                        const int *intws_ = (int *)ws + ws_offset_init;
                        const bool ws_is_u8
                                = MEM_D(ws).data_type() == data_type::u8;

#if SAFE_TO_USE_OMP_SIMD
                        PRAGMA_OMP_SIMD()
#endif
                        for (dim_t oc = 0; oc < OC; ++oc) {
                            const int index_from_ws
                                    = ws_is_u8 ? ws_[oc] : intws_[oc];

                            // If kernel windows are disjoint (kernel ==
                            // stride), each input element belongs to exactly
                            // one window, so a plain write suffices; otherwise
                            // accumulate into the contents.
                            float value = (index_from_ws == index)
                                    ? diff_dst_fp32[oc]
                                    : 0.0f;
                            if (!(KD == SD && KH == SH && KW == SW))
                                diff_src_fp32[oc] += value;
                            else
                                diff_src_fp32[oc] = value;
                        }
                    } else {
                        // pooling_avg
                        auto id_start = apply_offset(od * SD, padF);
                        auto ih_start = apply_offset(oh * SH, padT);
                        auto iw_start = apply_offset(ow * SW, padL);
                        auto id_end = min(od * SD - padF + KD, ID);
                        auto ih_end = min(oh * SH - padT + KH, IH);
                        auto iw_end = min(ow * SW - padL + KW, IW);

                        auto num_summands
                                = (alg == alg_kind::pooling_avg_include_padding)
                                ? KW * KH * KD
                                : (ih_end - ih_start) * (iw_end - iw_start)
                                        * (id_end - id_start);

                        PRAGMA_OMP_SIMD()
                        for (dim_t oc = 0; oc < OC; ++oc) {
                            // If kernel windows are disjoint (kernel ==
                            // stride), each input element belongs to exactly
                            // one window, so a plain write suffices; otherwise
                            // accumulate into the contents.
                            if (!(KD == SD && KH == SH && KW == SW))
                                diff_src_fp32[oc]
                                        += diff_dst_fp32[oc] / num_summands;
                            else
                                diff_src_fp32[oc]
                                        = diff_dst_fp32[oc] / num_summands;
                        }
                    }
                    types::cvt_from_float(
                            &diff_src[src_offset_init], diff_src_fp32, OC);
                }
            });
    return status::success;
}

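// Explicit instantiations: f32 picks up the specialized kernels above, while
// bf16 and f16 go through the generic f32-conversion path.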
template struct nhwc_pooling_fwd_t<data_type::f32>;
template struct nhwc_pooling_bwd_t<data_type::f32>;
template struct nhwc_pooling_fwd_t<data_type::bf16>;
template struct nhwc_pooling_bwd_t<data_type::bf16>;
template struct nhwc_pooling_fwd_t<data_type::f16>;
template struct nhwc_pooling_bwd_t<data_type::f16>;

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s