1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/nstl.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | #include "cpu/simple_q10n.hpp" |
26 | |
27 | #include "cpu/nchw_pooling.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | |
33 | using namespace nstl; |
34 | |
// Constructor: stores the primitive descriptor and builds the reference
// post-ops executor from the attribute's post-ops chain.
template <data_type_t d_type>
nchw_pooling_fwd_t<d_type>::nchw_pooling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
38 | |
39 | template <> |
40 | status_t nchw_pooling_fwd_t<data_type::f32>::execute_forward( |
41 | const exec_ctx_t &ctx) const { |
42 | const auto alg = pd()->desc()->alg_kind; |
43 | const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
44 | auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); |
45 | auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE); |
46 | |
47 | const memory_desc_wrapper ws_d(pd()->workspace_md()); |
48 | const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; |
49 | |
50 | const dim_t MB = pd()->MB(); |
51 | const dim_t C = pd()->OC(); |
52 | const dim_t OD = pd()->OD(); |
53 | const dim_t OH = pd()->OH(); |
54 | const dim_t OW = pd()->OW(); |
55 | const dim_t ID = pd()->ID(); |
56 | const dim_t IH = pd()->IH(); |
57 | const dim_t IW = pd()->IW(); |
58 | const dim_t KD = pd()->KD(); |
59 | const dim_t KH = pd()->KH(); |
60 | const dim_t KW = pd()->KW(); |
61 | const dim_t SD = pd()->KSD(); |
62 | const dim_t SH = pd()->KSH(); |
63 | const dim_t SW = pd()->KSW(); |
64 | const dim_t padF = pd()->padFront(); |
65 | const dim_t padT = pd()->padT(); |
66 | const dim_t padL = pd()->padL(); |
67 | |
68 | const auto apply_offset = [](int index, int offset) { |
69 | return (index > offset) ? index - offset : 0; |
70 | }; |
71 | |
72 | const auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow, |
73 | dim_t value) { |
74 | if (ws) { |
75 | assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); |
76 | const size_t ws_offset = (size_t)OW * OH * OD * C * mb |
77 | + (size_t)OW * OH * OD * c + (size_t)OW * OH * od |
78 | + (size_t)OW * oh + (size_t)ow; |
79 | if (ws_dt == data_type::u8) { |
80 | assert(0 <= value |
81 | && value <= numeric_limits<typename prec_traits< |
82 | data_type::u8>::type>::max()); |
83 | ws[ws_offset] = value; |
84 | } else |
85 | reinterpret_cast<int *>(ws)[ws_offset] = value; |
86 | } |
87 | }; |
88 | |
89 | const auto ker_max = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh, |
90 | dim_t ow) { |
91 | const auto src_off = IW * IH * ID * C * mb + IW * IH * ID * c; |
92 | const auto *src_loc = &src[src_off]; |
93 | |
94 | for_(dim_t kd = 0; kd < KD; ++kd) |
95 | for_(dim_t kh = 0; kh < KH; ++kh) |
96 | for (dim_t kw = 0; kw < KW; ++kw) { |
97 | const dim_t id = od * SD - padF + kd; |
98 | if (id < 0 || id >= ID) continue; |
99 | const dim_t ih = oh * SH - padT + kh; |
100 | if (ih < 0 || ih >= IH) continue; |
101 | const dim_t iw = ow * SW - padL + kw; |
102 | if (iw < 0 || iw >= IW) continue; |
103 | |
104 | const auto src_off_loc = IW * IH * id + IW * ih + iw; |
105 | const auto &s = src_loc[src_off_loc]; |
106 | if (s > d[0]) { |
107 | d[0] = s; |
108 | set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw); |
109 | } |
110 | } |
111 | }; |
112 | |
113 | const auto ker_avg = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh, |
114 | dim_t ow) { |
115 | const auto id_start = apply_offset(od * SD, padF); |
116 | const auto ih_start = apply_offset(oh * SH, padT); |
117 | const auto iw_start = apply_offset(ow * SW, padL); |
118 | const auto id_end = min(od * SD - padF + KD, ID); |
119 | const auto ih_end = min(oh * SH - padT + KH, IH); |
120 | const auto iw_end = min(ow * SW - padL + KW, IW); |
121 | |
122 | const auto num_summands = (alg == alg_kind::pooling_avg_include_padding) |
123 | ? KD * KW * KH |
124 | : (id_end - id_start) * (ih_end - ih_start) |
125 | * (iw_end - iw_start); |
126 | |
127 | const auto src_off |
128 | = IW * IH * ID * C * mb + IW * IH * ID * c + iw_start; |
129 | |
130 | float d_val = 0; |
131 | for_(dim_t id = id_start; id < id_end; ++id) |
132 | for (dim_t ih = ih_start; ih < ih_end; ++ih) { |
133 | const auto src_off_loc = src_off + IW * IH * id + IW * ih; |
134 | const auto *src_loc = &src[src_off_loc]; |
135 | for (dim_t iw = 0; iw < iw_end - iw_start; ++iw) |
136 | d_val += src_loc[iw]; |
137 | } |
138 | |
139 | return d_val / num_summands; |
140 | }; |
141 | |
142 | // Keep branches for post-ops since reference post-ops execution brings |
143 | // noticeable overhead. |
144 | const bool has_post_ops = pd()->attr()->post_ops_.len() > 0; |
145 | |
146 | if (alg == alg_kind::pooling_max) { |
147 | if (has_post_ops) { |
148 | parallel_nd(MB, C, OD, OH, OW, |
149 | [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { |
150 | const size_t dst_offset = (size_t)OW * OH * OD * C * mb |
151 | + (size_t)OW * OH * OD * c |
152 | + (size_t)OW * OH * od + (size_t)OW * oh |
153 | + (size_t)ow; |
154 | data_t *d = &dst[dst_offset]; |
155 | d[0] = numeric_limits<data_t>::lowest(); |
156 | set_ws(mb, c, od, oh, ow, 0); |
157 | ker_max(d, mb, c, od, oh, ow); |
158 | |
159 | ref_post_ops_t::args_t args; |
160 | args.ctx = &ctx; |
161 | args.l_offset = dst_offset; |
162 | args.dst_md = pd()->dst_md(); |
163 | ref_post_ops_.execute(dst[dst_offset], args); |
164 | dst[dst_offset] |
165 | = saturate_and_round<data_t>(dst[dst_offset]); |
166 | }); |
167 | } else { |
168 | parallel_nd(MB, C, OD, OH, OW, |
169 | [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { |
170 | const size_t dst_offset = (size_t)OW * OH * OD * C * mb |
171 | + (size_t)OW * OH * OD * c |
172 | + (size_t)OW * OH * od + (size_t)OW * oh |
173 | + (size_t)ow; |
174 | data_t *d = &dst[dst_offset]; |
175 | d[0] = numeric_limits<data_t>::lowest(); |
176 | set_ws(mb, c, od, oh, ow, 0); |
177 | ker_max(d, mb, c, od, oh, ow); |
178 | |
179 | dst[dst_offset] |
180 | = saturate_and_round<data_t>(dst[dst_offset]); |
181 | }); |
182 | } |
183 | } else { |
184 | if (has_post_ops) { |
185 | parallel_nd(MB, C, OD, OH, OW, |
186 | [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { |
187 | const size_t dst_offset = (size_t)OW * OH * OD * C * mb |
188 | + (size_t)OW * OH * OD * c |
189 | + (size_t)OW * OH * od + (size_t)OW * oh |
190 | + (size_t)ow; |
191 | data_t *d = &dst[dst_offset]; |
192 | d[0] = 0; |
193 | auto res = ker_avg(d, mb, c, od, oh, ow); |
194 | |
195 | ref_post_ops_t::args_t args; |
196 | args.ctx = &ctx; |
197 | args.l_offset = dst_offset; |
198 | args.dst_md = pd()->dst_md(); |
199 | ref_post_ops_.execute(res, args); |
200 | d[0] = saturate_and_round<data_t>(res); |
201 | }); |
202 | } else { |
203 | parallel_nd(MB, C, OD, OH, OW, |
204 | [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { |
205 | const size_t dst_offset = (size_t)OW * OH * OD * C * mb |
206 | + (size_t)OW * OH * OD * c |
207 | + (size_t)OW * OH * od + (size_t)OW * oh |
208 | + (size_t)ow; |
209 | data_t *d = &dst[dst_offset]; |
210 | d[0] = 0; |
211 | auto res = ker_avg(d, mb, c, od, oh, ow); |
212 | |
213 | d[0] = saturate_and_round<data_t>(res); |
214 | }); |
215 | } |
216 | } |
217 | |
218 | return status::success; |
219 | } |
220 | |
221 | template <data_type_t d_type> |
222 | status_t nchw_pooling_fwd_t<d_type>::execute_forward( |
223 | const exec_ctx_t &ctx) const { |
224 | |
225 | auto alg = pd()->desc()->alg_kind; |
226 | |
227 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
228 | auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); |
229 | auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE); |
230 | |
231 | auto scratchpad = ctx.get_scratchpad_grantor(); |
232 | float *cvt_wsp = scratchpad.template get<float>( |
233 | memory_tracking::names::key_pool_src_bf16cvt); |
234 | |
235 | const memory_desc_wrapper ws_d(pd()->workspace_md()); |
236 | const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; |
237 | |
238 | const dim_t MB = pd()->MB(); |
239 | const dim_t C = pd()->OC(); |
240 | const dim_t OD = pd()->OD(); |
241 | const dim_t OH = pd()->OH(); |
242 | const dim_t OW = pd()->OW(); |
243 | const dim_t ID = pd()->ID(); |
244 | const dim_t IH = pd()->IH(); |
245 | const dim_t IW = pd()->IW(); |
246 | const dim_t KD = pd()->KD(); |
247 | const dim_t KH = pd()->KH(); |
248 | const dim_t KW = pd()->KW(); |
249 | const dim_t SD = pd()->KSD(); |
250 | const dim_t SH = pd()->KSH(); |
251 | const dim_t SW = pd()->KSW(); |
252 | const dim_t padF = pd()->padFront(); |
253 | const dim_t padT = pd()->padT(); |
254 | const dim_t padL = pd()->padL(); |
255 | |
256 | const size_t simd_w = 16; |
257 | const size_t src_size = MB * C * ID * IH * IW; |
258 | const size_t blocked_size = src_size / simd_w; |
259 | const size_t tail_size = src_size % simd_w; |
260 | |
261 | auto apply_offset = [=](int index, int offset) { |
262 | return (index > offset) ? index - offset : 0; |
263 | }; |
264 | |
265 | auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow, |
266 | dim_t value) { |
267 | if (ws) { |
268 | assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); |
269 | size_t ws_offset = (size_t)OW * OH * OD * C * mb |
270 | + (size_t)OW * OH * OD * c + (size_t)OW * OH * od |
271 | + (size_t)OW * oh + (size_t)ow; |
272 | if (ws_dt == data_type::u8) { |
273 | assert(0 <= value |
274 | && value <= numeric_limits<typename prec_traits< |
275 | data_type::u8>::type>::max()); |
276 | ws[ws_offset] = value; |
277 | } else |
278 | reinterpret_cast<int *>(ws)[ws_offset] = value; |
279 | } |
280 | }; |
281 | |
282 | auto ker_max = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh, |
283 | dim_t ow) { |
284 | const auto src_off = IW * IH * ID * C * mb + IW * IH * ID * c; |
285 | const auto *src_loc = &cvt_wsp[src_off]; |
286 | |
287 | for_(dim_t kd = 0; kd < KD; ++kd) |
288 | for_(dim_t kh = 0; kh < KH; ++kh) |
289 | for (dim_t kw = 0; kw < KW; ++kw) { |
290 | const dim_t id = od * SD - padF + kd; |
291 | if (id < 0 || id >= ID) continue; |
292 | const dim_t ih = oh * SH - padT + kh; |
293 | if (ih < 0 || ih >= IH) continue; |
294 | const dim_t iw = ow * SW - padL + kw; |
295 | if (iw < 0 || iw >= IW) continue; |
296 | |
297 | const auto src_off_loc = IW * IH * id + IW * ih + iw; |
298 | const auto &s = src_loc[src_off_loc]; |
299 | if (s > d[0]) { |
300 | d[0] = s; |
301 | set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw); |
302 | } |
303 | } |
304 | }; |
305 | |
306 | auto ker_avg = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh, |
307 | dim_t ow) { |
308 | auto id_start = apply_offset(od * SD, padF); |
309 | auto ih_start = apply_offset(oh * SH, padT); |
310 | auto iw_start = apply_offset(ow * SW, padL); |
311 | auto id_end = min(od * SD - padF + KD, ID); |
312 | auto ih_end = min(oh * SH - padT + KH, IH); |
313 | auto iw_end = min(ow * SW - padL + KW, IW); |
314 | |
315 | auto num_summands = (alg == alg_kind::pooling_avg_include_padding) |
316 | ? KD * KW * KH |
317 | : (id_end - id_start) * (ih_end - ih_start) |
318 | * (iw_end - iw_start); |
319 | |
320 | const auto src_off |
321 | = IW * IH * ID * C * mb + IW * IH * ID * c + iw_start; |
322 | |
323 | for_(dim_t id = id_start; id < id_end; ++id) |
324 | for (dim_t ih = ih_start; ih < ih_end; ++ih) { |
325 | const auto src_off_loc = src_off + IW * IH * id + IW * ih; |
326 | const auto *src_loc = &cvt_wsp[src_off_loc]; |
327 | for (dim_t iw = 0; iw < iw_end - iw_start; ++iw) |
328 | d[0] += src_loc[iw]; |
329 | } |
330 | |
331 | d[0] = out_round<float>((float)d[0] / num_summands); |
332 | }; |
333 | |
334 | parallel_nd(blocked_size, [&](size_t i) { |
335 | types::cvt_to_float(&cvt_wsp[i * simd_w], &src[i * simd_w], simd_w); |
336 | }); |
337 | if (tail_size) |
338 | types::cvt_to_float(&cvt_wsp[blocked_size * simd_w], |
339 | &src[blocked_size * simd_w], tail_size); |
340 | |
341 | // Keep branches for post-ops since reference post-ops execution brings |
342 | // noticeable overhead. |
343 | const bool has_post_ops = pd()->attr()->post_ops_.len() > 0; |
344 | |
345 | if (alg == alg_kind::pooling_max) { |
346 | if (has_post_ops) { |
347 | parallel_nd(MB, C, OD, OH, OW, |
348 | [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { |
349 | size_t dst_offset = (size_t)OW * OH * OD * C * mb |
350 | + (size_t)OW * OH * OD * c |
351 | + (size_t)OW * OH * od + (size_t)OW * oh |
352 | + (size_t)ow; |
353 | float d_fp32 = numeric_limits<data_t>::lowest(); |
354 | |
355 | set_ws(mb, c, od, oh, ow, 0); |
356 | |
357 | ker_max(&d_fp32, mb, c, od, oh, ow); |
358 | |
359 | ref_post_ops_t::args_t args; |
360 | args.ctx = &ctx; |
361 | args.l_offset = dst_offset; |
362 | args.dst_md = pd()->dst_md(); |
363 | ref_post_ops_.execute(d_fp32, args); |
364 | |
365 | dst[dst_offset] = static_cast<data_t>(d_fp32); |
366 | }); |
367 | } else { |
368 | parallel_nd(MB, C, OD, OH, OW, |
369 | [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { |
370 | size_t dst_offset = (size_t)OW * OH * OD * C * mb |
371 | + (size_t)OW * OH * OD * c |
372 | + (size_t)OW * OH * od + (size_t)OW * oh |
373 | + (size_t)ow; |
374 | float d_fp32 = numeric_limits<data_t>::lowest(); |
375 | |
376 | set_ws(mb, c, od, oh, ow, 0); |
377 | |
378 | ker_max(&d_fp32, mb, c, od, oh, ow); |
379 | |
380 | dst[dst_offset] = static_cast<data_t>(d_fp32); |
381 | }); |
382 | } |
383 | } else { |
384 | if (has_post_ops) { |
385 | parallel_nd(MB, C, OD, OH, OW, |
386 | [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { |
387 | size_t dst_offset = (size_t)OW * OH * OD * C * mb |
388 | + (size_t)OW * OH * OD * c |
389 | + (size_t)OW * OH * od + (size_t)OW * oh |
390 | + (size_t)ow; |
391 | float d_fp32 = 0.0f; |
392 | ker_avg(&d_fp32, mb, c, od, oh, ow); |
393 | ref_post_ops_t::args_t args; |
394 | args.ctx = &ctx; |
395 | args.l_offset = dst_offset; |
396 | args.dst_md = pd()->dst_md(); |
397 | ref_post_ops_.execute(d_fp32, args); |
398 | dst[dst_offset] = static_cast<data_t>(d_fp32); |
399 | }); |
400 | } else { |
401 | parallel_nd(MB, C, OD, OH, OW, |
402 | [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) { |
403 | size_t dst_offset = (size_t)OW * OH * OD * C * mb |
404 | + (size_t)OW * OH * OD * c |
405 | + (size_t)OW * OH * od + (size_t)OW * oh |
406 | + (size_t)ow; |
407 | float d_fp32 = 0.0f; |
408 | ker_avg(&d_fp32, mb, c, od, oh, ow); |
409 | |
410 | dst[dst_offset] = static_cast<data_t>(d_fp32); |
411 | }); |
412 | } |
413 | } |
414 | |
415 | return status::success; |
416 | } |
417 | |
418 | template <> |
419 | status_t nchw_pooling_bwd_t<data_type::f32>::execute_backward( |
420 | const exec_ctx_t &ctx) const { |
421 | auto alg = pd()->desc()->alg_kind; |
422 | const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; |
423 | const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4; |
424 | |
425 | auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); |
426 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
427 | auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE); |
428 | |
429 | const memory_desc_wrapper ws_d(pd()->workspace_md()); |
430 | |
431 | const dim_t MB = pd()->MB(); |
432 | const dim_t C = pd()->OC(); |
433 | const dim_t OD = pd()->OD(); |
434 | const dim_t OH = pd()->OH(); |
435 | const dim_t OW = pd()->OW(); |
436 | const dim_t ID = pd()->ID(); |
437 | const dim_t IH = pd()->IH(); |
438 | const dim_t IW = pd()->IW(); |
439 | const dim_t KD = pd()->KD(); |
440 | const dim_t KH = pd()->KH(); |
441 | const dim_t KW = pd()->KW(); |
442 | const dim_t SD = pd()->KSD(); |
443 | const dim_t SH = pd()->KSH(); |
444 | const dim_t SW = pd()->KSW(); |
445 | const dim_t padF = pd()->padFront(); |
446 | const dim_t padT = pd()->padT(); |
447 | const dim_t padL = pd()->padL(); |
448 | |
449 | auto apply_offset = [=](int index, int offset) { |
450 | return (index > offset) ? index - offset : 0; |
451 | }; |
452 | |
453 | auto ker_zero = [=](dim_t mb, dim_t c) { |
454 | size_t diff_src_offset |
455 | = (size_t)mb * C * ID * IH * IW + (size_t)c * ID * IH * IW; |
456 | for_(dim_t id = 0; id < ID; ++id) |
457 | for_(dim_t ih = 0; ih < IH; ++ih) |
458 | for (dim_t iw = 0; iw < IW; ++iw) { |
459 | diff_src[diff_src_offset++] = 0; |
460 | } |
461 | }; |
462 | |
463 | auto ker_max = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh, |
464 | dim_t ow) { |
465 | auto b_c = ws_d.blocking_desc().inner_nblks == 0 |
466 | ? 1 |
467 | : ws_d.blocking_desc().inner_blks[0]; |
468 | auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow) |
469 | : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow) |
470 | : ws_d.blk_off(mb, c / b_c, ow)) |
471 | + c % b_c; |
472 | |
473 | const int index = ws_d.data_type() == data_type::u8 |
474 | ? (int)ws[ws_offset] |
475 | : ((const int *)ws)[ws_offset]; |
476 | const dim_t kw = index % KW; |
477 | const dim_t kh = (index / KW) % KH; |
478 | const dim_t kd = (index / KW) / KH; |
479 | |
480 | const dim_t id = od * SD - padF + kd; |
481 | const dim_t ih = oh * SH - padT + kh; |
482 | const dim_t iw = ow * SW - padL + kw; |
483 | |
484 | // If padding area could fit the kernel, |
485 | // then input displacement would be out of bounds. |
486 | // No need to back propagate there as padding is |
487 | // virtual in pooling_max case. |
488 | if (id < 0 || id >= ID) return; |
489 | if (ih < 0 || ih >= IH) return; |
490 | if (iw < 0 || iw >= IW) return; |
491 | |
492 | size_t diff_src_offset = (size_t)mb * C * ID * IH * IW |
493 | + (size_t)c * ID * IH * IW + (size_t)id * IH * IW |
494 | + (size_t)ih * IW + (size_t)iw; |
495 | diff_src[diff_src_offset] += d[0]; |
496 | }; |
497 | |
498 | auto ker_avg = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh, |
499 | dim_t ow) { |
500 | dim_t id_start = apply_offset(od * SD, padF); |
501 | dim_t ih_start = apply_offset(oh * SH, padT); |
502 | dim_t iw_start = apply_offset(ow * SW, padL); |
503 | dim_t id_end = min(od * SD - padF + KD, ID); |
504 | dim_t ih_end = min(oh * SH - padT + KH, IH); |
505 | dim_t iw_end = min(ow * SW - padL + KW, IW); |
506 | |
507 | size_t num_summands = (alg == alg_kind::pooling_avg_include_padding) |
508 | ? (size_t)KW * KH * KD |
509 | : (size_t)(id_end - id_start) * (ih_end - ih_start) |
510 | * (iw_end - iw_start); |
511 | |
512 | for_(dim_t id = id_start; id < id_end; ++id) |
513 | for_(dim_t ih = ih_start; ih < ih_end; ++ih) |
514 | for (dim_t iw = iw_start; iw < iw_end; ++iw) { |
515 | size_t diff_src_offset = (size_t)mb * C * ID * IH * IW |
516 | + (size_t)c * ID * IH * IW + (size_t)id * IH * IW |
517 | + (size_t)ih * IW + (size_t)iw; |
518 | diff_src[diff_src_offset] += d[0] / num_summands; |
519 | } |
520 | }; |
521 | |
522 | dim_t ow_start = max(dim_t(0), utils::div_up(padL - KW + 1, SW)); |
523 | dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW); |
524 | |
525 | dim_t oh_start = max(dim_t(0), utils::div_up(padT - KH + 1, SH)); |
526 | dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH); |
527 | |
528 | dim_t od_start = max(dim_t(0), utils::div_up(padF - KD + 1, SD)); |
529 | dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD); |
530 | |
531 | if (alg == alg_kind::pooling_max) { |
532 | parallel_nd(MB, C, [&](dim_t mb, dim_t c) { |
533 | size_t diff_dst_offset_b |
534 | = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW; |
535 | ker_zero(mb, c); |
536 | for_(dim_t od = od_start; od < od_end; ++od) |
537 | for (dim_t oh = oh_start; oh < oh_end; ++oh) { |
538 | size_t diff_dst_offset = diff_dst_offset_b |
539 | + (size_t)od * OH * OW + (size_t)oh * OW; |
540 | for (dim_t ow = ow_start; ow < ow_end; ++ow) { |
541 | const data_t *d = &diff_dst[diff_dst_offset + ow]; |
542 | ker_max(d, mb, c, od, oh, ow); |
543 | } |
544 | } |
545 | }); |
546 | } else { |
547 | parallel_nd(MB, C, [&](dim_t mb, dim_t c) { |
548 | size_t diff_dst_offset_b |
549 | = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW; |
550 | ker_zero(mb, c); |
551 | for_(dim_t od = od_start; od < od_end; ++od) |
552 | for (dim_t oh = oh_start; oh < oh_end; ++oh) { |
553 | size_t diff_dst_offset = diff_dst_offset_b |
554 | + (size_t)od * OH * OW + (size_t)oh * OW; |
555 | for (dim_t ow = ow_start; ow < ow_end; ++ow) { |
556 | const data_t *d = &diff_dst[diff_dst_offset + ow]; |
557 | ker_avg(d, mb, c, od, oh, ow); |
558 | } |
559 | } |
560 | }); |
561 | } |
562 | |
563 | return status::success; |
564 | } |
565 | |
566 | template <data_type_t d_type> |
567 | status_t nchw_pooling_bwd_t<d_type>::execute_backward( |
568 | const exec_ctx_t &ctx) const { |
569 | |
570 | auto alg = pd()->desc()->alg_kind; |
571 | const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; |
572 | const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4; |
573 | |
574 | auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); |
575 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
576 | auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE); |
577 | |
578 | auto scratchpad = ctx.get_scratchpad_grantor(); |
579 | float *cvt_src = scratchpad.template get<float>( |
580 | memory_tracking::names::key_pool_src_bf16cvt); |
581 | float *cvt_dst = scratchpad.template get<float>( |
582 | memory_tracking::names::key_pool_dst_bf16cvt); |
583 | |
584 | const memory_desc_wrapper ws_d(pd()->workspace_md()); |
585 | |
586 | const dim_t MB = pd()->MB(); |
587 | const dim_t C = pd()->OC(); |
588 | const dim_t OD = pd()->OD(); |
589 | const dim_t OH = pd()->OH(); |
590 | const dim_t OW = pd()->OW(); |
591 | const dim_t ID = pd()->ID(); |
592 | const dim_t IH = pd()->IH(); |
593 | const dim_t IW = pd()->IW(); |
594 | const dim_t KD = pd()->KD(); |
595 | const dim_t KH = pd()->KH(); |
596 | const dim_t KW = pd()->KW(); |
597 | const dim_t SD = pd()->KSD(); |
598 | const dim_t SH = pd()->KSH(); |
599 | const dim_t SW = pd()->KSW(); |
600 | const dim_t padF = pd()->padFront(); |
601 | const dim_t padT = pd()->padT(); |
602 | const dim_t padL = pd()->padL(); |
603 | |
604 | const size_t dst_sp_size = pd()->OD() * pd()->OH() * pd()->OW(); |
605 | const size_t src_sp_size = pd()->ID() * pd()->IH() * pd()->IW(); |
606 | |
607 | auto apply_offset = [=](int index, int offset) { |
608 | return (index > offset) ? index - offset : 0; |
609 | }; |
610 | |
611 | auto ker_zero = [=](float *diff_src, dim_t c_block_size) { |
612 | size_t diff_src_offset = 0; |
613 | for_(dim_t c = 0; c < c_block_size; ++c) |
614 | for_(dim_t id = 0; id < ID; ++id) |
615 | for_(dim_t ih = 0; ih < IH; ++ih) |
616 | for (dim_t iw = 0; iw < IW; ++iw) { |
617 | diff_src[diff_src_offset++] = 0.0f; |
618 | } |
619 | }; |
620 | |
621 | auto ker_max = [=](const float *d, float *diff_src, dim_t mb, dim_t c, |
622 | dim_t od, dim_t oh, dim_t ow) { |
623 | auto b_c = ws_d.blocking_desc().inner_nblks == 0 |
624 | ? 1 |
625 | : ws_d.blocking_desc().inner_blks[0]; |
626 | auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow) |
627 | : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow) |
628 | : ws_d.blk_off(mb, c / b_c, ow)) |
629 | + c % b_c; |
630 | |
631 | const int index = ws_d.data_type() == data_type::u8 |
632 | ? (int)ws[ws_offset] |
633 | : ((const int *)ws)[ws_offset]; |
634 | const dim_t kw = index % KW; |
635 | const dim_t kh = (index / KW) % KH; |
636 | const dim_t kd = (index / KW) / KH; |
637 | |
638 | const dim_t id = od * SD - padF + kd; |
639 | const dim_t ih = oh * SH - padT + kh; |
640 | const dim_t iw = ow * SW - padL + kw; |
641 | |
642 | // If padding area could fit the kernel, |
643 | // then input displacement would be out of bounds. |
644 | // No need to back propagate there as padding is |
645 | // virtual in pooling_max case. |
646 | if (id < 0 || id >= ID) return; |
647 | if (ih < 0 || ih >= IH) return; |
648 | if (iw < 0 || iw >= IW) return; |
649 | |
650 | size_t diff_src_offset |
651 | = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw; |
652 | diff_src[diff_src_offset] += d[0]; |
653 | }; |
654 | |
655 | auto ker_avg = [=](const float *d, float *diff_src, dim_t mb, dim_t c, |
656 | dim_t od, dim_t oh, dim_t ow) { |
657 | auto id_start = apply_offset(od * SD, padF); |
658 | auto ih_start = apply_offset(oh * SH, padT); |
659 | auto iw_start = apply_offset(ow * SW, padL); |
660 | auto id_end = min(od * SD - padF + KD, ID); |
661 | auto ih_end = min(oh * SH - padT + KH, IH); |
662 | auto iw_end = min(ow * SW - padL + KW, IW); |
663 | |
664 | size_t num_summands = (alg == alg_kind::pooling_avg_include_padding) |
665 | ? (size_t)KW * KH * KD |
666 | : (size_t)(id_end - id_start) * (ih_end - ih_start) |
667 | * (iw_end - iw_start); |
668 | |
669 | for_(dim_t id = id_start; id < id_end; ++id) |
670 | for_(dim_t ih = ih_start; ih < ih_end; ++ih) |
671 | for (dim_t iw = iw_start; iw < iw_end; ++iw) { |
672 | size_t diff_src_offset |
673 | = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw; |
674 | diff_src[diff_src_offset] += d[0] / num_summands; |
675 | } |
676 | }; |
677 | |
678 | dim_t ow_start = max(dim_t(0), utils::div_up(padL - KW + 1, SW)); |
679 | dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW); |
680 | |
681 | dim_t oh_start = max(dim_t(0), utils::div_up(padT - KH + 1, SH)); |
682 | dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH); |
683 | |
684 | dim_t od_start = max(dim_t(0), utils::div_up(padF - KD + 1, SD)); |
685 | dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD); |
686 | |
687 | dim_t c_blk = pd()->channel_block_size_; |
688 | dim_t c_blk_tail = C % c_blk; |
689 | const int nthr = pd()->nthr_; |
690 | |
691 | if (alg == alg_kind::pooling_max) { |
692 | parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), |
693 | [&](int ithr, int, dim_t mb, dim_t cb) { |
694 | bool is_last_c_block |
695 | = c_blk_tail > 0 && (cb + 1) * c_blk > C; |
696 | dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk; |
697 | size_t diff_dst_offset_b |
698 | = ((size_t)mb * C + (size_t)cb * c_blk) * OD * OH |
699 | * OW; |
700 | size_t diff_src_offset |
701 | = ((size_t)mb * C + (size_t)cb * c_blk) * ID * IH |
702 | * IW; |
703 | float *diff_dst_fp32 = &cvt_dst[ithr * dst_sp_size * c_blk]; |
704 | float *diff_src_fp32 = &cvt_src[ithr * src_sp_size * c_blk]; |
705 | |
706 | ker_zero(diff_src_fp32, curr_c_block); |
707 | |
708 | types::cvt_to_float(diff_dst_fp32, |
709 | &diff_dst[diff_dst_offset_b], |
710 | dst_sp_size * curr_c_block); |
711 | |
712 | for_(dim_t c = 0; c < curr_c_block; ++c) |
713 | for_(dim_t od = od_start; od < od_end; ++od) |
714 | for (dim_t oh = oh_start; oh < oh_end; ++oh) { |
715 | size_t diff_dst_offset = (size_t)c * OD * OH * OW |
716 | + (size_t)od * OH * OW + (size_t)oh * OW; |
717 | for (dim_t ow = ow_start; ow < ow_end; ++ow) { |
718 | const float *d |
719 | = &diff_dst_fp32[diff_dst_offset + ow]; |
720 | ker_max(d, &diff_src_fp32[c * ID * IH * IW], mb, |
721 | cb * c_blk + c, od, oh, ow); |
722 | } |
723 | } |
724 | types::cvt_from_float(&diff_src[diff_src_offset], |
725 | diff_src_fp32, src_sp_size * curr_c_block); |
726 | }); |
727 | } else { |
728 | parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), |
729 | [&](int ithr, int, dim_t mb, dim_t cb) { |
730 | bool is_last_c_block |
731 | = c_blk_tail > 0 && (cb + 1) * c_blk > C; |
732 | dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk; |
733 | size_t diff_dst_offset_b = (size_t)mb * C * OD * OH * OW |
734 | + (size_t)cb * c_blk * OD * OH * OW; |
735 | float *diff_dst_fp32 = &cvt_dst[ithr * dst_sp_size * c_blk]; |
736 | size_t diff_src_offset = (size_t)mb * C * ID * IH * IW |
737 | + (size_t)cb * c_blk * ID * IH * IW; |
738 | float *diff_src_fp32 = &cvt_src[ithr * src_sp_size * c_blk]; |
739 | |
740 | ker_zero(diff_src_fp32, curr_c_block); |
741 | |
742 | types::cvt_to_float(diff_dst_fp32, |
743 | &diff_dst[diff_dst_offset_b], |
744 | dst_sp_size * curr_c_block); |
745 | for_(dim_t c = 0; c < curr_c_block; ++c) |
746 | for_(dim_t od = od_start; od < od_end; ++od) |
747 | for (dim_t oh = oh_start; oh < oh_end; ++oh) { |
748 | size_t diff_dst_offset = (size_t)c * OD * OH * OW |
749 | + (size_t)od * OH * OW + (size_t)oh * OW; |
750 | for (dim_t ow = ow_start; ow < ow_end; ++ow) { |
751 | const float *d |
752 | = &diff_dst_fp32[diff_dst_offset + ow]; |
753 | ker_avg(d, &diff_src_fp32[c * ID * IH * IW], mb, |
754 | cb * c_blk + c, od, oh, ow); |
755 | } |
756 | } |
757 | types::cvt_from_float(&diff_src[diff_src_offset], |
758 | diff_src_fp32, src_sp_size * curr_c_block); |
759 | }); |
760 | } |
761 | |
762 | return status::success; |
763 | } |
// Explicit instantiations for the data types this implementation supports.
template struct nchw_pooling_fwd_t<data_type::f32>;
template struct nchw_pooling_bwd_t<data_type::f32>;
template struct nchw_pooling_fwd_t<data_type::bf16>;
template struct nchw_pooling_bwd_t<data_type::bf16>;
template struct nchw_pooling_fwd_t<data_type::f16>;
template struct nchw_pooling_bwd_t<data_type::f16>;
770 | } // namespace cpu |
771 | } // namespace impl |
772 | } // namespace dnnl |
773 | |
774 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
775 | |