1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/compiler_workarounds.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/math_utils.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | |
27 | #include "cpu/simple_q10n.hpp" |
28 | |
29 | #include "cpu/nhwc_pooling.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | |
35 | // Intel's LLVM-based compiler on Windows generates incorrect code with |
36 | // PRAGMA_OMP_SIMD in some particular cases. |
37 | // TODO: The issue above seems to be an additional one to the issue mentioned |
// in `CLANG_WA_01_SAFE_TO_USE_OMP_SIMD`. Once the latter is resolved,
// re-check the former one specifically; it may go away as well.
40 | #if ((defined _WIN32) && (defined __INTEL_CLANG_COMPILER)) |
41 | #define SAFE_TO_USE_OMP_SIMD (0 && CLANG_WA_01_SAFE_TO_USE_OMP_SIMD) |
42 | #else |
43 | #define SAFE_TO_USE_OMP_SIMD (1 && CLANG_WA_01_SAFE_TO_USE_OMP_SIMD) |
44 | #endif |
45 | |
46 | #define MEM_D(name) name##_d |
47 | |
48 | #define DECLARE_READ_STRIDES(name) \ |
49 | const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \ |
50 | const size_t name##_d_stride \ |
51 | = is_3d ? MEM_D(name).blocking_desc().strides[ndims - 3] : 0; \ |
52 | const size_t name##_h_stride \ |
53 | = is_1d ? 0 : MEM_D(name).blocking_desc().strides[ndims - 2]; \ |
54 | const size_t name##_w_stride \ |
55 | = MEM_D(name).blocking_desc().strides[ndims - 1]; |
56 | |
namespace nhwc_pooling {
// Compute the flat element offset for indices (n, d, h, w) given the
// corresponding per-dimension strides.
size_t strided_offset(const int _n, const size_t _sn, const int _d,
        const size_t _sd, const int _h, const size_t _sh, const int _w,
        const size_t _sw) {
    size_t offset = _n * _sn;
    offset += _d * _sd;
    offset += _h * _sh;
    offset += _w * _sw;
    return offset;
}
} // namespace nhwc_pooling
64 | |
// Construct the forward pooling primitive; the post-ops executor is built
// once here from the primitive descriptor's attributes and reused by every
// execute() call.
template <data_type_t d_type>
nhwc_pooling_fwd_t<d_type>::nhwc_pooling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
68 | |
69 | template <data_type_t d_type> |
70 | void nhwc_pooling_fwd_t<d_type>::array_div_by_const(const int n, |
71 | const ker_data_t *src, const size_t num, ker_data_t *dst) const { |
72 | for (int i = 0; i < n; ++i) { |
73 | const float ftmp = ((float)src[i]) / num; |
74 | dst[i] = out_round<ker_data_t>(ftmp); |
75 | } |
76 | } |
77 | |
78 | template <data_type_t d_type> |
79 | void nhwc_pooling_fwd_t<d_type>::array_add( |
80 | const int n, const ker_data_t *src, ker_data_t *dst) const { |
81 | for (int i = 0; i < n; ++i) { |
82 | dst[i] += src[i]; |
83 | } |
84 | } |
85 | |
// Per-channel max update for max pooling with workspace tracking: for each
// of the n channels, compare src[oc] with the running maximum dst[oc]; when
// the new value wins, record `index` (the linearized kernel-window position
// kd*KH*KW + kh*KW + kw) into the workspace at ws_offset + oc, then refresh
// the maximum. The workspace element width is chosen at runtime via ws_dt
// (u8 or s32).
template <data_type_t d_type>
void nhwc_pooling_fwd_t<d_type>::array_nhwc_max(const int n, ker_data_t *dst,
        const ker_data_t *src, unsigned char *ws, const size_t ws_offset,
        const data_type_t ws_dt, const int index) const {
    assert(ws);
#if SAFE_TO_USE_OMP_SIMD
    PRAGMA_OMP_SIMD()
#endif
    for (int oc = 0; oc < n; ++oc) {
        const auto s = src[oc];
        ker_data_t mv = dst[oc];

        // update index of maximum
#if defined __INTEL_COMPILER
        if (s > mv) {
            // if (ws && (s > mv)) {
            assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
            if (ws_dt == data_type::u8) {
                // u8 workspace can only hold kernel indices up to 255
                assert(0 <= index && index <= 255);
                ws[ws_offset + oc] = index;
            } else
                reinterpret_cast<int *>(ws)[ws_offset + oc] = index;
        }
#else
        // Need to add explicit predicates for GCC to vectorize this.
        // And although the resulting code is ugly, it is still 4 times
        // faster than scalar
        assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);

        if (ws_dt == data_type::u8) {
            assert(0 <= index && index <= 255);
            // Branchless select: keep the old workspace value unless s > mv.
            const unsigned char predicate = (s > mv) ? 0xff : 0;
            unsigned char current_value = ws[ws_offset + oc];
            current_value = (predicate & (unsigned char)index)
                    | ((~predicate) & current_value);
            ws[ws_offset + oc] = current_value;
        } else {
            // Same branchless select, but on 32-bit workspace entries.
            auto wint = reinterpret_cast<int *>(ws);
            const unsigned int predicate = (s > mv) ? 0xffffffff : 0;
            unsigned int current_value = wint[ws_offset + oc];
            current_value = (predicate & (unsigned int)index)
                    | ((~predicate) & current_value);
            wint[ws_offset + oc] = current_value;
        }
#endif
        // update maximum
        dst[oc] = nstl::max(s, mv);
    }
}
135 | |
136 | template <data_type_t d_type> |
137 | void nhwc_pooling_fwd_t<d_type>::array_nhwc_initialize(const int n, |
138 | ker_data_t *dst, unsigned char *ws, const size_t ws_offset, |
139 | const data_type_t ws_dt) const { |
140 | assert(ws && (ws_dt == data_type::u8 || ws_dt == data_type::s32)); |
141 | #if SAFE_TO_USE_OMP_SIMD |
142 | PRAGMA_OMP_SIMD() |
143 | #endif |
144 | for (int oc = 0; oc < n; ++oc) { |
145 | if (ws_dt == data_type::u8) |
146 | ws[ws_offset + oc] = 0; |
147 | else |
148 | reinterpret_cast<int *>(ws)[ws_offset + oc] = 0; |
149 | dst[oc] = nstl::numeric_limits<data_t>::lowest(); |
150 | } |
151 | } |
152 | |
153 | using namespace nstl; |
154 | using namespace nhwc_pooling; |
155 | |
156 | template <> |
157 | status_t nhwc_pooling_fwd_t<data_type::f32>::execute_forward( |
158 | const exec_ctx_t &ctx) const { |
159 | |
160 | const auto alg = pd()->desc()->alg_kind; |
161 | |
162 | const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
163 | auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); |
164 | auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE); |
165 | |
166 | const memory_desc_wrapper MEM_D(src)(pd()->src_md()); |
167 | const memory_desc_wrapper MEM_D(dst)(pd()->dst_md()); |
168 | const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); |
169 | |
170 | const dim_t MB = pd()->MB(); |
171 | const dim_t OC = pd()->OC(); |
172 | const dim_t OD = pd()->OD(); |
173 | const dim_t OH = pd()->OH(); |
174 | const dim_t OW = pd()->OW(); |
175 | const dim_t ID = pd()->ID(); |
176 | const dim_t IH = pd()->IH(); |
177 | const dim_t IW = pd()->IW(); |
178 | const dim_t KD = pd()->KD(); |
179 | const dim_t KH = pd()->KH(); |
180 | const dim_t KW = pd()->KW(); |
181 | const dim_t SD = pd()->KSD(); |
182 | const dim_t SH = pd()->KSH(); |
183 | const dim_t SW = pd()->KSW(); |
184 | const dim_t padF = pd()->padFront(); |
185 | const dim_t padT = pd()->padT(); |
186 | const dim_t padL = pd()->padL(); |
187 | |
188 | const bool is_1d = pd()->desc()->src_desc.ndims == 3; |
189 | const bool is_3d = pd()->desc()->src_desc.ndims == 5; |
190 | const int ndims = pd()->ndims(); |
191 | const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; |
192 | |
193 | DECLARE_READ_STRIDES(src); |
194 | DECLARE_READ_STRIDES(dst); |
195 | |
196 | const auto apply_offset = [](int index, int offset) { |
197 | return (index > offset) ? index - offset : 0; |
198 | }; |
199 | |
200 | const dim_t SP = OW * OH; |
201 | const dim_t OSP = SP * OD; |
202 | |
203 | const auto get_logical_offset |
204 | = [&](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) -> dim_t { |
205 | return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow; |
206 | }; |
207 | const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty()); |
208 | |
209 | parallel_nd(MB, OD, OH, OW, [&](dim_t mb, dim_t od, dim_t oh, dim_t ow) { |
210 | const size_t dst_offset_init = strided_offset(mb, dst_n_stride, od, |
211 | dst_d_stride, oh, dst_h_stride, ow, dst_w_stride); |
212 | if (alg == alg_kind::pooling_max) { |
213 | size_t ws_offset_init = 0; |
214 | if (ws) { |
215 | DECLARE_READ_STRIDES(ws); |
216 | ws_offset_init = strided_offset(mb, ws_n_stride, od, |
217 | ws_d_stride, oh, ws_h_stride, ow, ws_w_stride); |
218 | } |
219 | // Note: GCC 4.8.5 won't vectorize below |
220 | // simple loops unless they are singled out |
221 | // into separate helper routines: |
222 | // array_nhwc_initialize, array_nhwc_max |
223 | if (!ws) { |
224 | auto *const d = dst + dst_offset_init; |
225 | PRAGMA_OMP_SIMD() |
226 | for (dim_t oc = 0; oc < OC; ++oc) { |
227 | d[oc] = nstl::numeric_limits<data_t>::lowest(); |
228 | } |
229 | } else { |
230 | array_nhwc_initialize( |
231 | OC, dst + dst_offset_init, ws, ws_offset_init, ws_dt); |
232 | } |
233 | |
234 | for_(dim_t kd = 0; kd < KD; ++kd) |
235 | for_(dim_t kh = 0; kh < KH; ++kh) |
236 | for (dim_t kw = 0; kw < KW; ++kw) { |
237 | const dim_t id = od * SD - padF + kd; |
238 | const dim_t ih = oh * SH - padT + kh; |
239 | const dim_t iw = ow * SW - padL + kw; |
240 | |
241 | if (id < 0 || id >= ID) continue; |
242 | if (ih < 0 || ih >= IH) continue; |
243 | if (iw < 0 || iw >= IW) continue; |
244 | |
245 | const size_t src_offset_init = strided_offset(mb, src_n_stride, |
246 | id, src_d_stride, ih, src_h_stride, iw, src_w_stride); |
247 | |
248 | if (!ws) { |
249 | auto *const s = src + src_offset_init; |
250 | auto *const d = dst + dst_offset_init; |
251 | PRAGMA_OMP_SIMD() |
252 | for (dim_t oc = 0; oc < OC; ++oc) { |
253 | d[oc] = nstl::max(s[oc], d[oc]); |
254 | } |
255 | } else { |
256 | array_nhwc_max(OC, dst + dst_offset_init, |
257 | src + src_offset_init, ws, ws_offset_init, ws_dt, |
258 | kd * KH * KW + kh * KW + kw); |
259 | } |
260 | } |
261 | } else { |
262 | // pooling_avg |
263 | const auto d = dst + dst_offset_init; |
264 | |
265 | utils::array_set(d, 0, OC); |
266 | |
267 | const auto id_start = apply_offset(od * SD, padF); |
268 | const auto ih_start = apply_offset(oh * SH, padT); |
269 | const auto iw_start = apply_offset(ow * SW, padL); |
270 | const auto id_end = min(od * SD - padF + KD, ID); |
271 | const auto ih_end = min(oh * SH - padT + KH, IH); |
272 | const auto iw_end = min(ow * SW - padL + KW, IW); |
273 | |
274 | // it is cheaper to actually count this in a loop |
275 | // as the typical kernel is small |
276 | size_t num_summands = 0; |
277 | |
278 | for_(dim_t id = id_start; id < id_end; ++id) |
279 | for_(dim_t ih = ih_start; ih < ih_end; ++ih) |
280 | for (dim_t iw = iw_start; iw < iw_end; ++iw) { |
281 | const size_t src_offset_init = strided_offset(mb, src_n_stride, |
282 | id, src_d_stride, ih, src_h_stride, iw, src_w_stride); |
283 | const auto s = src + src_offset_init; |
284 | |
285 | // need to move the loop to separate function |
286 | // for GCC 4.8.5 to vectorize |
287 | array_add(OC, s, d); |
288 | |
289 | num_summands++; |
290 | } |
291 | |
292 | num_summands = (alg == alg_kind::pooling_avg_include_padding) |
293 | ? KW * KH * KD |
294 | : num_summands; |
295 | |
296 | // need to move the loop to separate function |
297 | // for GCC 4.8.5 to vectorize |
298 | array_div_by_const(OC, d, num_summands, d); |
299 | } |
300 | |
301 | if (are_postops_set) { |
302 | auto *const d = dst + dst_offset_init; |
303 | ref_post_ops_t::args_t args; |
304 | args.ctx = &ctx; |
305 | args.l_offset = get_logical_offset(mb, 0, od, oh, ow); |
306 | args.dst_md = pd()->dst_md(); |
307 | |
308 | for (dim_t oc = 0; oc < OC; ++oc) { |
309 | ref_post_ops_.execute(d[oc], args); |
310 | args.l_offset += OSP; |
311 | } |
312 | } |
313 | }); |
314 | return status::success; |
315 | } |
316 | |
// Generic forward pooling for 16-bit data types (bf16/f16): each thread
// converts the OC channels it works on into f32 scratchpad rows, pools in
// f32, applies post-ops, then converts the result back to the destination
// type. Mirrors the f32 specialization above otherwise.
template <data_type_t d_type>
status_t nhwc_pooling_fwd_t<d_type>::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto alg = pd()->desc()->alg_kind;

    const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);

    // Per-thread f32 conversion buffers, OC floats per thread.
    // NOTE(review): the key names say "bf16cvt" but this path serves f16
    // too — confirm against the pd scratchpad bookkeeping.
    auto scratchpad = ctx.get_scratchpad_grantor();
    float *const cvt_src_wsp = scratchpad.template get<float>(
            memory_tracking::names::key_pool_src_bf16cvt);
    float *const cvt_dst_wsp = scratchpad.template get<float>(
            memory_tracking::names::key_pool_dst_bf16cvt);

    const memory_desc_wrapper MEM_D(src)(pd()->src_md());
    const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();
    const dim_t KD = pd()->KD();
    const dim_t KH = pd()->KH();
    const dim_t KW = pd()->KW();
    const dim_t SD = pd()->KSD();
    const dim_t SH = pd()->KSH();
    const dim_t SW = pd()->KSW();
    const dim_t padF = pd()->padFront();
    const dim_t padT = pd()->padT();
    const dim_t padL = pd()->padL();

    const bool is_1d = pd()->desc()->src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    DECLARE_READ_STRIDES(src);
    DECLARE_READ_STRIDES(dst);

    // Clamp (index - offset) at zero.
    const auto apply_offset = [&](dim_t index, dim_t offset) {
        return (index > offset) ? index - offset : 0;
    };

    const dim_t SP = OW * OH;
    const dim_t OSP = SP * OD;

    // Logical (dense, NCDHW-ordered) offset used by post-ops to locate the
    // current element independently of the physical nhwc layout.
    const auto get_logical_offset
            = [&](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) -> dim_t {
        return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
    };
    const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());
    const int nthr = pd()->nthr_;

    parallel_nd_ext(nthr, MB, OD, OH, OW,
            [&](int ithr, int, dim_t mb, dim_t od, dim_t oh, dim_t ow) {
                const size_t dst_offset_init = strided_offset(mb, dst_n_stride,
                        od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
                // This thread's private f32 rows of OC channels.
                float *const dst_f32 = &cvt_dst_wsp[ithr * OC];
                float *const src_f32 = &cvt_src_wsp[ithr * OC];

                if (alg == alg_kind::pooling_max) {
                    size_t ws_offset_init = 0;
                    if (ws) {
                        DECLARE_READ_STRIDES(ws);
                        ws_offset_init = strided_offset(mb, ws_n_stride, od,
                                ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
                    };
                    // Note: GCC 4.8.5 won't vectorize below
                    // simple loops unless they are singled out
                    // into separate helper routines:
                    // array_nhwc_initialize, array_nhwc_max
                    if (!ws) {
                        PRAGMA_OMP_SIMD()
                        for (dim_t oc = 0; oc < OC; ++oc) {
                            // Init with the lowest value representable in
                            // the (16-bit) destination type, not FLT_MIN.
                            dst_f32[oc]
                                    = nstl::numeric_limits<data_t>::lowest();
                        }
                    } else {
                        array_nhwc_initialize(
                                OC, dst_f32, ws, ws_offset_init, ws_dt);
                    }

                    for_(dim_t kd = 0; kd < KD; ++kd)
                    for_(dim_t kh = 0; kh < KH; ++kh)
                    for (dim_t kw = 0; kw < KW; ++kw) {
                        const dim_t id = od * SD - padF + kd;
                        const dim_t ih = oh * SH - padT + kh;
                        const dim_t iw = ow * SW - padL + kw;

                        // Skip kernel taps that fall into the padding area.
                        if (id < 0 || id >= ID) continue;
                        if (ih < 0 || ih >= IH) continue;
                        if (iw < 0 || iw >= IW) continue;

                        const size_t src_offset_init = strided_offset(mb,
                                src_n_stride, id, src_d_stride, ih,
                                src_h_stride, iw, src_w_stride);

                        // Upconvert the OC source channels to f32 first.
                        types::cvt_to_float(src_f32, &src[src_offset_init], OC);

                        if (!ws) {
                            PRAGMA_OMP_SIMD()
                            for (dim_t oc = 0; oc < OC; ++oc) {
                                dst_f32[oc]
                                        = nstl::max(src_f32[oc], dst_f32[oc]);
                            }
                        } else {
                            array_nhwc_max(OC, dst_f32, src_f32, ws,
                                    ws_offset_init, ws_dt,
                                    kd * KH * KW + kh * KW + kw);
                        }
                    }
                } else {
                    // pooling_avg
                    utils::array_set(dst_f32, 0, OC);

                    const auto id_start = apply_offset(od * SD, padF);
                    const auto ih_start = apply_offset(oh * SH, padT);
                    const auto iw_start = apply_offset(ow * SW, padL);
                    const auto id_end = min(od * SD - padF + KD, ID);
                    const auto ih_end = min(oh * SH - padT + KH, IH);
                    const auto iw_end = min(ow * SW - padL + KW, IW);

                    // it is cheaper to actually count this in a loop
                    // as the typical kernel is small
                    size_t num_summands = 0;

                    for_(dim_t id = id_start; id < id_end; ++id)
                    for_(dim_t ih = ih_start; ih < ih_end; ++ih)
                    for (dim_t iw = iw_start; iw < iw_end; ++iw) {
                        size_t src_offset_init = strided_offset(mb,
                                src_n_stride, id, src_d_stride, ih,
                                src_h_stride, iw, src_w_stride);
                        types::cvt_to_float(src_f32, &src[src_offset_init], OC);

                        // need to move the loop to separate function
                        // for GCC 4.8.5 to vectorize
                        array_add(OC, src_f32, dst_f32);
                        num_summands++;
                    }

                    // "include_padding" averages over the full kernel volume
                    // even where it overlaps the padding.
                    num_summands
                            = (alg == alg_kind::pooling_avg_include_padding)
                            ? KW * KH * KD
                            : num_summands;

                    // need to move the loop to separate function
                    // for GCC 4.8.5 to vectorize
                    array_div_by_const(OC, dst_f32, num_summands, dst_f32);
                }

                if (are_postops_set) {
                    ref_post_ops_t::args_t args;
                    args.ctx = &ctx;
                    args.l_offset = get_logical_offset(mb, 0, od, oh, ow);
                    args.dst_md = pd()->dst_md();

                    // Consecutive channels are OSP apart in logical layout.
                    for (dim_t oc = 0; oc < OC; ++oc) {
                        ref_post_ops_.execute(dst_f32[oc], args);
                        args.l_offset += OSP;
                    }
                }
                // Downconvert the finished OC channels back to dst's type.
                types::cvt_from_float(dst + dst_offset_init, dst_f32, OC);
            });
    return status::success;
}
489 | |
490 | template <> |
491 | status_t nhwc_pooling_bwd_t<data_type::f32>::execute_backward( |
492 | const exec_ctx_t &ctx) const { |
493 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
494 | auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE); |
495 | auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); |
496 | |
497 | const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md()); |
498 | const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md()); |
499 | const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); |
500 | |
501 | const dim_t MB = pd()->MB(); |
502 | const dim_t OC = pd()->OC(); |
503 | const dim_t OD = pd()->OD(); |
504 | const dim_t OH = pd()->OH(); |
505 | const dim_t OW = pd()->OW(); |
506 | const dim_t ID = pd()->ID(); |
507 | const dim_t IH = pd()->IH(); |
508 | const dim_t IW = pd()->IW(); |
509 | const dim_t KD = pd()->KD(); |
510 | const dim_t KH = pd()->KH(); |
511 | const dim_t KW = pd()->KW(); |
512 | const dim_t SD = pd()->KSD(); |
513 | const dim_t SH = pd()->KSH(); |
514 | const dim_t SW = pd()->KSW(); |
515 | const dim_t padF = pd()->padFront(); |
516 | const dim_t padT = pd()->padT(); |
517 | const dim_t padL = pd()->padL(); |
518 | |
519 | const bool is_1d = pd()->desc()->diff_src_desc.ndims == 3; |
520 | const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; |
521 | const int ndims = pd()->ndims(); |
522 | auto alg = pd()->desc()->alg_kind; |
523 | |
524 | DECLARE_READ_STRIDES(diff_src); |
525 | DECLARE_READ_STRIDES(diff_dst); |
526 | |
527 | auto apply_offset = [=](dim_t index, dim_t offset) { |
528 | return (index > offset) ? index - offset : 0; |
529 | }; |
530 | |
531 | parallel_nd(MB, ID, IH, IW, [&](dim_t mb, dim_t id, dim_t ih, dim_t iw) { |
532 | size_t src_offset_init |
533 | = strided_offset(mb, diff_src_n_stride, id, diff_src_d_stride, |
534 | ih, diff_src_h_stride, iw, diff_src_w_stride); |
535 | |
536 | for (dim_t oc = 0; oc < OC; ++oc) |
537 | diff_src[src_offset_init + oc] = data_type_t(0); |
538 | |
539 | // Find out which output cells may correspond to current |
540 | // input position. Current input postition divided by |
541 | // stride, with integer divide rounding down, is the |
542 | // right-most output. |
543 | // Left-most output may be computed if we decrement input |
544 | // by (kernel_size - 1) and then do the same division by |
545 | // stride. |
546 | dim_t od_left = max((id + padF - KD + 1) / SD, dim_t(0)); |
547 | dim_t oh_left = max((ih + padT - KH + 1) / SH, dim_t(0)); |
548 | dim_t ow_left = max((iw + padL - KW + 1) / SW, dim_t(0)); |
549 | // Notice +1 here to preserve the C loop "less than" |
550 | // condition for continuing the for loop. |
551 | dim_t od_right = min((id + padF) / SD + 1, OD); |
552 | dim_t oh_right = min((ih + padT) / SH + 1, OH); |
553 | dim_t ow_right = min((iw + padL) / SW + 1, OW); |
554 | |
555 | for_(dim_t od = od_left; od < od_right; ++od) |
556 | for_(dim_t oh = oh_left; oh < oh_right; ++oh) |
557 | for (dim_t ow = ow_left; ow < ow_right; ++ow) { |
558 | const dim_t kd = id - od * SD + padF; |
559 | const dim_t kh = ih - oh * SH + padT; |
560 | const dim_t kw = iw - ow * SW + padL; |
561 | |
562 | if (kd < 0 || kd >= KD) continue; |
563 | if (kh < 0 || kh >= KH) continue; |
564 | if (kw < 0 || kw >= KW) continue; |
565 | |
566 | size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, od, |
567 | diff_dst_d_stride, oh, diff_dst_h_stride, ow, |
568 | diff_dst_w_stride); |
569 | |
570 | if (alg == alg_kind::pooling_max) { |
571 | DECLARE_READ_STRIDES(ws); |
572 | size_t ws_offset_init = strided_offset(mb, ws_n_stride, od, |
573 | ws_d_stride, oh, ws_h_stride, ow, ws_w_stride); |
574 | const dim_t index = kd * KH * KW + kh * KW + kw; |
575 | const unsigned char *ws_ = ws + ws_offset_init; |
576 | const int *intws_ = (int *)ws + ws_offset_init; |
577 | const bool ws_is_u8 = MEM_D(ws).data_type() == data_type::u8; |
578 | |
579 | #if SAFE_TO_USE_OMP_SIMD |
580 | PRAGMA_OMP_SIMD() |
581 | #endif |
582 | for (dim_t oc = 0; oc < OC; ++oc) { |
583 | const int index_from_ws = ws_is_u8 ? ws_[oc] : intws_[oc]; |
584 | const data_t d = diff_dst[dst_offset_init + oc]; |
585 | |
586 | // Check if kernel windows are disjoint, in this case |
587 | // there's no update needed and we just write there once |
588 | // otherwise we add value to the contents. |
589 | auto value = (index_from_ws == index) ? d : data_type_t(0); |
590 | if (!(KD == SD && KH == SH && KW == SW)) |
591 | diff_src[src_offset_init + oc] += value; |
592 | else |
593 | diff_src[src_offset_init + oc] = value; |
594 | } |
595 | } else { |
596 | // pooling_avg |
597 | auto id_start = apply_offset(od * SD, padF); |
598 | auto ih_start = apply_offset(oh * SH, padT); |
599 | auto iw_start = apply_offset(ow * SW, padL); |
600 | auto id_end = min(od * SD - padF + KD, ID); |
601 | auto ih_end = min(oh * SH - padT + KH, IH); |
602 | auto iw_end = min(ow * SW - padL + KW, IW); |
603 | |
604 | auto num_summands |
605 | = (alg == alg_kind::pooling_avg_include_padding) |
606 | ? KW * KH * KD |
607 | : (ih_end - ih_start) * (iw_end - iw_start) |
608 | * (id_end - id_start); |
609 | |
610 | PRAGMA_OMP_SIMD() |
611 | for (dim_t oc = 0; oc < OC; ++oc) { |
612 | const data_t d = diff_dst[dst_offset_init + oc]; |
613 | // Check if kernel windows are disjoint, in this case |
614 | // there's no update needed and we just write there once |
615 | // otherwise we add value to the contents. |
616 | if (!(KD == SD && KH == SH && KW == SW)) |
617 | diff_src[src_offset_init + oc] += d / num_summands; |
618 | else |
619 | diff_src[src_offset_init + oc] = d / num_summands; |
620 | } |
621 | } |
622 | } |
623 | }); |
624 | return status::success; |
625 | } |
626 | |
627 | template <data_type_t d_type> |
628 | status_t nhwc_pooling_bwd_t<d_type>::execute_backward( |
629 | const exec_ctx_t &ctx) const { |
630 | |
631 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
632 | auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE); |
633 | auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); |
634 | |
635 | auto scratchpad = ctx.get_scratchpad_grantor(); |
636 | float *cvt_dsrc = scratchpad.template get<float>( |
637 | memory_tracking::names::key_pool_src_bf16cvt); |
638 | float *cvt_ddst = scratchpad.template get<float>( |
639 | memory_tracking::names::key_pool_dst_bf16cvt); |
640 | |
641 | const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md()); |
642 | const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md()); |
643 | const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); |
644 | |
645 | const dim_t MB = pd()->MB(); |
646 | const dim_t OC = pd()->OC(); |
647 | const dim_t OD = pd()->OD(); |
648 | const dim_t OH = pd()->OH(); |
649 | const dim_t OW = pd()->OW(); |
650 | const dim_t ID = pd()->ID(); |
651 | const dim_t IH = pd()->IH(); |
652 | const dim_t IW = pd()->IW(); |
653 | const dim_t KD = pd()->KD(); |
654 | const dim_t KH = pd()->KH(); |
655 | const dim_t KW = pd()->KW(); |
656 | const dim_t SD = pd()->KSD(); |
657 | const dim_t SH = pd()->KSH(); |
658 | const dim_t SW = pd()->KSW(); |
659 | const dim_t padF = pd()->padFront(); |
660 | const dim_t padT = pd()->padT(); |
661 | const dim_t padL = pd()->padL(); |
662 | |
663 | const bool is_1d = pd()->desc()->diff_src_desc.ndims == 3; |
664 | const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; |
665 | const int ndims = pd()->ndims(); |
666 | auto alg = pd()->desc()->alg_kind; |
667 | |
668 | DECLARE_READ_STRIDES(diff_src); |
669 | DECLARE_READ_STRIDES(diff_dst); |
670 | |
671 | auto apply_offset = [=](dim_t index, dim_t offset) { |
672 | return (index > offset) ? index - offset : 0; |
673 | }; |
674 | const int nthr = pd()->nthr_; |
675 | |
676 | parallel_nd_ext(nthr, MB, ID, IH, IW, |
677 | [&](int ithr, int, dim_t mb, dim_t id, dim_t ih, dim_t iw) { |
678 | size_t src_offset_init = strided_offset(mb, diff_src_n_stride, |
679 | id, diff_src_d_stride, ih, diff_src_h_stride, iw, |
680 | diff_src_w_stride); |
681 | |
682 | float *diff_dst_fp32 = &cvt_ddst[ithr * OC]; |
683 | float *diff_src_fp32 = &cvt_dsrc[ithr * OC]; |
684 | |
685 | for (dim_t oc = 0; oc < OC; ++oc) { |
686 | diff_src_fp32[oc] = 0.f; |
687 | diff_src[src_offset_init + oc] = (bfloat16_t)0.f; |
688 | } |
689 | |
690 | // Find out which output cells may correspond to current |
691 | // input position. Current input postition divided by |
692 | // stride, with integer divide rounding down, is the |
693 | // right-most output. |
694 | // Left-most output may be computed if we decrement input |
695 | // by (kernel_size - 1) and then do the same division by |
696 | // stride. |
697 | dim_t od_left = max((id + padF - KD + 1) / SD, dim_t(0)); |
698 | dim_t oh_left = max((ih + padT - KH + 1) / SH, dim_t(0)); |
699 | dim_t ow_left = max((iw + padL - KW + 1) / SW, dim_t(0)); |
700 | // Notice +1 here to preserve the C loop "less than" |
701 | // condition for continuing the for loop. |
702 | dim_t od_right = min((id + padF) / SD + 1, OD); |
703 | dim_t oh_right = min((ih + padT) / SH + 1, OH); |
704 | dim_t ow_right = min((iw + padL) / SW + 1, OW); |
705 | |
706 | for_(dim_t od = od_left; od < od_right; ++od) |
707 | for_(dim_t oh = oh_left; oh < oh_right; ++oh) |
708 | for (dim_t ow = ow_left; ow < ow_right; ++ow) { |
709 | const dim_t kd = id - od * SD + padF; |
710 | const dim_t kh = ih - oh * SH + padT; |
711 | const dim_t kw = iw - ow * SW + padL; |
712 | |
713 | if (kd < 0 || kd >= KD) continue; |
714 | if (kh < 0 || kh >= KH) continue; |
715 | if (kw < 0 || kw >= KW) continue; |
716 | |
717 | size_t dst_offset_init = strided_offset(mb, |
718 | diff_dst_n_stride, od, diff_dst_d_stride, oh, |
719 | diff_dst_h_stride, ow, diff_dst_w_stride); |
720 | types::cvt_to_float( |
721 | diff_dst_fp32, &diff_dst[dst_offset_init], OC); |
722 | |
723 | if (alg == alg_kind::pooling_max) { |
724 | DECLARE_READ_STRIDES(ws); |
725 | size_t ws_offset_init = strided_offset(mb, ws_n_stride, |
726 | od, ws_d_stride, oh, ws_h_stride, ow, |
727 | ws_w_stride); |
728 | const dim_t index = kd * KH * KW + kh * KW + kw; |
729 | const unsigned char *ws_ = ws + ws_offset_init; |
730 | const int *intws_ = (int *)ws + ws_offset_init; |
731 | const bool ws_is_u8 |
732 | = MEM_D(ws).data_type() == data_type::u8; |
733 | |
734 | #if SAFE_TO_USE_OMP_SIMD |
735 | PRAGMA_OMP_SIMD() |
736 | #endif |
737 | for (dim_t oc = 0; oc < OC; ++oc) { |
738 | const int index_from_ws |
739 | = ws_is_u8 ? ws_[oc] : intws_[oc]; |
740 | |
741 | // Check if kernel windows are disjoint, in this case |
742 | // there's no update needed and we just write there once |
743 | // otherwise we add value to the contents. |
744 | float value = (index_from_ws == index) |
745 | ? diff_dst_fp32[oc] |
746 | : 0.0f; |
747 | if (!(KD == SD && KH == SH && KW == SW)) |
748 | diff_src_fp32[oc] += value; |
749 | else |
750 | diff_src_fp32[oc] = value; |
751 | } |
752 | } else { |
753 | // pooling_avg |
754 | auto id_start = apply_offset(od * SD, padF); |
755 | auto ih_start = apply_offset(oh * SH, padT); |
756 | auto iw_start = apply_offset(ow * SW, padL); |
757 | auto id_end = min(od * SD - padF + KD, ID); |
758 | auto ih_end = min(oh * SH - padT + KH, IH); |
759 | auto iw_end = min(ow * SW - padL + KW, IW); |
760 | |
761 | auto num_summands |
762 | = (alg == alg_kind::pooling_avg_include_padding) |
763 | ? KW * KH * KD |
764 | : (ih_end - ih_start) * (iw_end - iw_start) |
765 | * (id_end - id_start); |
766 | |
767 | PRAGMA_OMP_SIMD() |
768 | for (dim_t oc = 0; oc < OC; ++oc) { |
769 | // Check if kernel windows are disjoint, in this case |
770 | // there's no update needed and we just write there once |
771 | // otherwise we add value to the contents. |
772 | if (!(KD == SD && KH == SH && KW == SW)) |
773 | diff_src_fp32[oc] |
774 | += diff_dst_fp32[oc] / num_summands; |
775 | else |
776 | diff_src_fp32[oc] |
777 | = diff_dst_fp32[oc] / num_summands; |
778 | } |
779 | } |
780 | types::cvt_from_float( |
781 | &diff_src[src_offset_init], diff_src_fp32, OC); |
782 | } |
783 | }); |
784 | return status::success; |
785 | } |
786 | |
787 | template struct nhwc_pooling_fwd_t<data_type::f32>; |
788 | template struct nhwc_pooling_bwd_t<data_type::f32>; |
789 | template struct nhwc_pooling_fwd_t<data_type::bf16>; |
790 | template struct nhwc_pooling_bwd_t<data_type::bf16>; |
791 | template struct nhwc_pooling_fwd_t<data_type::f16>; |
792 | template struct nhwc_pooling_bwd_t<data_type::f16>; |
793 | |
794 | } // namespace cpu |
795 | } // namespace impl |
796 | } // namespace dnnl |
797 | |
798 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
799 | |