/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstring> // for std::memcpy used when saving the sum post-op input
#include <functional>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/dnnl_traits.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/ref_io_helper.hpp"

#include "cpu/ref_convolution_utils.hpp"
#include "cpu/ref_deconvolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
void ref_deconvolution_fwd_t::compute_fwd_bias_common(const exec_ctx_t &ctx,
        void *dst, const float *conv_output, bool non_default_attr) const {
    const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto G = pd()->G();
    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC() / G;
    const auto ndims = pd()->desc()->src_desc.ndims;

    parallel_nd(MB, G, OC, OD, OH, OW,
            [&](dim_t mb, dim_t g, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
                const dim_t c = g * OC + oc;
                const dim_t off = ref_conv_utils::get_data_off(
                        dst_d, ndims, mb, c, od, oh, ow);
                float b = io::load_float_value(bias_d.data_type(), bias, c);
                float d = conv_output[off];
                // Store in f32 when attributes are applied after the bias so
                // that no precision is lost to an intermediate conversion to
                // the destination data type.
                auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
                io::store_float_value(dt, d + b, dst, off);
            });
}

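// Fast path for plain channel-major layouts (ncw, nchw, ncdhw): for a fixed
// (mb, oc) pair the spatial span is contiguous, so the bias value is loaded
// once and the inner loop vectorizes cleanly.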
void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const exec_ctx_t &ctx,
        void *dst, const float *conv_output, bool non_default_attr) const {
    const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto SP = pd()->OW() * pd()->OH() * pd()->OD();

    parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
        const dim_t off = (mb * OC + oc) * SP;
        float b = io::load_float_value(bias_d.data_type(), bias, oc);
        PRAGMA_OMP_SIMD()
        for (dim_t sp = 0; sp < SP; ++sp) {
            float d = conv_output[off + sp];
            // Store in f32 when attributes are applied after the bias to
            // avoid losing precision to an intermediate conversion.
            auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
            io::store_float_value(dt, d + b, dst, off + sp);
        }
    });
}

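// Fast path for channels-last layouts (nwc, nhwc, ndhwc): the channel is the
// innermost dimension, so consecutive iterations of the inner loop touch
// consecutive bias entries and consecutive destination elements.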
void ref_deconvolution_fwd_t::compute_fwd_bias_ndhwc(const exec_ctx_t &ctx,
        void *dst, const float *conv_output, bool non_default_attr) const {
    const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto SP = pd()->OW() * pd()->OH() * pd()->OD();

    parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
        const dim_t off = (mb * SP + sp) * OC;
        PRAGMA_OMP_SIMD()
        for (dim_t oc = 0; oc < OC; ++oc) {
            float b = io::load_float_value(bias_d.data_type(), bias, oc);
            float d = conv_output[off + oc];
            // Store in f32 when attributes are applied after the bias to
            // avoid losing precision to an intermediate conversion.
            auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
            io::store_float_value(dt, d + b, dst, off + oc);
        }
    });
}

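// Fast path for blocked layouts (nCw8c, nChw16c, ...): channels are grouped
// in blocks of blk_size, laid out as [mb][C/blk][sp][blk]. For nChw16c, e.g.,
// the element offset is mb * stride_mb + (oc / 16) * SP * 16 + sp * 16
// + oc % 16. The tail block is zero-padded: out-of-range lanes add a zero
// bias, and the stores stay within the padded destination buffer.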
template <dim_t blk_size>
void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx,
        void *dst, const float *conv_output, bool non_default_attr) const {
    const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
    const auto stride_mb = dst_d.blocking_desc().strides[0];

    parallel_nd(MB, utils::div_up(OC, blk_size), SP,
            [&](dim_t mb, dim_t oc_blk, dim_t sp) {
                const dim_t oc = oc_blk * blk_size;
                const dim_t off = mb * stride_mb + oc * SP + sp * blk_size;
                const dim_t blk = nstl::min(blk_size, OC - oc);

                PRAGMA_OMP_SIMD()
                for (dim_t i = 0; i < blk_size; ++i) {
                    float b = i < blk ? io::load_float_value(
                                      bias_d.data_type(), bias, oc + i)
                                      : 0;
                    float d = conv_output[off + i];
                    // Store in f32 when attributes are applied after the
                    // bias to avoid losing precision to an intermediate
                    // conversion.
                    auto dt = non_default_attr ? data_type::f32
                                               : dst_d.data_type();
                    io::store_float_value(dt, d + b, dst, off + i);
                }
            });
}

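// Dispatches the bias addition on the destination format tag resolved at pd
// creation time; anything that is not a recognized plain or blocked layout
// falls back to the generic offset-based version.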
void ref_deconvolution_fwd_t::compute_fwd_bias(const exec_ctx_t &ctx, void *dst,
        const float *conv_output, bool non_default_attr) const {
    using namespace format_tag;
    switch (pd()->dst_tag_) {
        case ncdhw:
        case nchw:
        case ncw:
            compute_fwd_bias_ncdhw(ctx, dst, conv_output, non_default_attr);
            break;
        case ndhwc:
        case nhwc:
        case nwc:
            compute_fwd_bias_ndhwc(ctx, dst, conv_output, non_default_attr);
            break;
        case nCdhw8c:
        case nChw8c:
        case nCw8c:
            compute_fwd_bias_nCdhwXc<8>(
                    ctx, dst, conv_output, non_default_attr);
            break;
        case nCdhw16c:
        case nChw16c:
        case nCw16c:
            compute_fwd_bias_nCdhwXc<16>(
                    ctx, dst, conv_output, non_default_attr);
            break;
        default:
            compute_fwd_bias_common(ctx, dst, conv_output, non_default_attr);
            break;
    }
}

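// Applies source and weights scales to the f32 convolution output in place.
// A weights scale mask of 0 means a single common scale; any non-zero mask is
// indexed per output channel.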
status_t ref_deconvolution_fwd_t::compute_oscale(
        const exec_ctx_t &ctx, float *dst) const {

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    const int wei_scale_mask
            = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;

    const memory_desc_wrapper dst_d(pd()->dst_md());

    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC();
    const auto OCP = dst_d.padded_dims()[1];
    const auto ndims = pd()->desc()->src_desc.ndims;

    const auto maybe_oscale = [](float &d, dim_t oc, const float *src_scales,
                                      const float *wei_scales, int wei_mask) {
        // scale_idx_mult = 1 for per_oc scales and 0 otherwise
        const int wei_scale_idx_mult = wei_mask != 0;
        d *= src_scales[0] * wei_scales[oc * wei_scale_idx_mult];
    };

    parallel_nd(MB, OCP, OD, OH, OW,
            [&](dim_t mb, dim_t ocp, dim_t od, dim_t oh, dim_t ow) {
                auto dst_off = ref_conv_utils::get_data_off(
                        dst_d, ndims, mb, ocp, od, oh, ow);

                if (ocp < OC) {
                    float tmp_result = dst[dst_off];
                    maybe_oscale(tmp_result, ocp, src_scales, wei_scales,
                            wei_scale_mask);
                    dst[dst_off] = tmp_result;
                }
            });

    return status::success;
}

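// Applies the remaining attributes to the f32 convolution output: post-ops
// (including sum against the saved original destination), then the
// destination scale, then the destination zero point, before the final
// conversion to the destination data type. Padded output channels are stored
// as zeros to keep the padded area defined.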
status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx,
        const float *conv_output, void *original_dst) const {
    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
    const int dst_scale_mask = pd()->attr()->scales_.get(DNNL_ARG_DST).mask_;

    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
    const bool is_dst_zp_common
            = pd()->attr()->zero_points_.common(DNNL_ARG_DST);

    const memory_desc_wrapper dst_d(pd()->dst_md());

    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC();
    const auto OCP = dst_d.padded_dims()[1];
    const auto ndims = pd()->desc()->src_desc.ndims;

    const auto maybe_dst_zero_point = [=](float &result, dim_t oc) {
        if (is_dst_zp_common)
            result += dst_zero_point[0];
        else
            result += dst_zero_point[oc];
    };

    const auto maybe_scale
            = [](float &d, dim_t oc, const float *scales, int mask) {
                  // scale_idx_mult = 1 for per_oc scales and 0 otherwise
                  const int scale_idx_mult = mask != 0;
                  d *= scales[oc * scale_idx_mult];
              };

    parallel_nd(MB, OCP, OD, OH, OW,
            [&](dim_t mb, dim_t ocp, dim_t od, dim_t oh, dim_t ow) {
                auto dst_off = ref_conv_utils::get_data_off(
                        dst_d, ndims, mb, ocp, od, oh, ow);
                float tmp_result = 0;

                if (ocp < OC) {
                    const dim_t dst_l_off = (mb * OC + ocp) * OD * OH * OW
                            + od * OH * OW + oh * OW + ow;
                    tmp_result = conv_output[dst_off];

                    ref_post_ops_t::args_t args;
                    if (pd()->attr()->post_ops_.find(primitive_kind::sum)
                            != -1)
                        args.dst_val = io::load_float_value(
                                dst_d.data_type(), original_dst, dst_off);
                    args.ctx = &ctx;
                    args.l_offset = dst_l_off;
                    args.dst_md = pd()->dst_md();
                    ref_post_ops->execute(tmp_result, args);
                    maybe_scale(tmp_result, ocp, dst_scales, dst_scale_mask);
                    maybe_dst_zero_point(tmp_result, ocp);
                }
                io::store_float_value(
                        dst_d.data_type(), tmp_result, dst, dst_off);
            });

    return status::success;
}

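// Maps a (g, oc, ic, kd, kh, kw) weights coordinate to a linear offset,
// handling both grouped and non-grouped weights descriptors for the 1D, 2D
// and 3D cases.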
dim_t get_weights_off(const memory_desc_wrapper &wei_d, bool with_groups,
        int ndims, dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
    switch (ndims) {
        case 5:
            return with_groups ? wei_d.off(g, oc, ic, kd, kh, kw)
                               : wei_d.off(oc, ic, kd, kh, kw);
        case 4:
            return with_groups ? wei_d.off(g, oc, ic, kh, kw)
                               : wei_d.off(oc, ic, kh, kw);
        case 3:
            return with_groups ? wei_d.off(g, oc, ic, kw)
                               : wei_d.off(oc, ic, kw);
        default: assert(!"unsupported ndims"); return dim_t(0);
    }
}

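// With an asymmetric source quantization the real input is src - zp_src, so
//   conv(src - zp_src, wei)[oc] = conv(src, wei)[oc]
//                                 - sum_{ic,kd,kh,kw} wei * zp_src(ic).
// The subtracted term depends only on the output channel; it is precomputed
// here into the key_deconv_zp scratchpad, one int32 value per (g, oc) pair.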
template <data_type_t wei_type>
static void compute_src_zp_compensation(const exec_ctx_t &ctx,
        const int32_t *src_zero_point, const bool is_src_zp_common,
        const typename prec_traits<wei_type>::type *wei,
        const cpu_deconvolution_fwd_pd_t *pd) {
    using namespace memory_tracking::names;

    const auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *zp_compensation = scratchpad.get<int32_t>(key_deconv_zp);
    const auto G = pd->G();
    const auto KH = pd->KH();
    const auto KW = pd->KW();
    const auto KD = pd->KD();
    const auto OC = pd->OC() / G;
    const auto IC = pd->IC() / G;
    const memory_desc_wrapper wei_d(pd->weights_md());
    const bool with_groups = pd->with_groups();
    const auto ndims = wei_d.ndims() - (with_groups ? 1 : 0);
    const auto get_wei_off
            = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
                  return get_weights_off(
                          wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
              };

    parallel_nd(G, OC, [&](const dim_t g, const dim_t oc) {
        const auto out_offset = g * OC + oc;
        int32_t acc = 0;

        for_(dim_t kd = 0; kd < KD; ++kd)
        for_(dim_t kh = 0; kh < KH; ++kh)
        for (dim_t kw = 0; kw < KW; ++kw) {
            for (dim_t ic = 0; ic < IC; ++ic) {
                const auto weights_offset = get_wei_off(g, oc, ic, kd, kh, kw);
                const int32_t wei32 = static_cast<int32_t>(wei[weights_offset]);

                // For a common zero point the multiplication by zp_src is
                // hoisted out of the loop; per-channel zero points are
                // applied inside the accumulation.
                if (is_src_zp_common)
                    acc += wei32;
                else
                    acc += wei32 * src_zero_point[g * IC + ic];
            }
        }

        zp_compensation[out_offset]
                = is_src_zp_common ? acc * src_zero_point[0] : acc;
    });
}

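// The per-channel compensation above assumes every kernel tap read a
// zero-point-shifted source value. Taps that fall into padding or between
// strides (as seen from the deconvolution, i.e. convolution backward data)
// read implicit zeros instead, so their wei * zp_src contribution must be
// added back per output point. The returned kernel computes that correction.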
template <data_type_t wei_type>
static std::function<int32_t(
        const dim_t, const dim_t, const dim_t, const dim_t, const dim_t)>
prepare_zp_pad_comp_ker(const dim_t ndims, const int32_t *src_zero_point,
        const bool is_src_zp_common,
        const typename prec_traits<wei_type>::type *wei,
        const cpu_deconvolution_fwd_pd_t *deconv_pd) {

    const auto KH = deconv_pd->KH();
    const auto KW = deconv_pd->KW();
    const auto KD = deconv_pd->KD();
    const auto KSD = deconv_pd->KSD();
    const auto KSH = deconv_pd->KSH();
    const auto KSW = deconv_pd->KSW();
    const auto KDD = deconv_pd->KDD() + 1;
    const auto KDH = deconv_pd->KDH() + 1;
    const auto KDW = deconv_pd->KDW() + 1;
    const auto IC = deconv_pd->IC() / deconv_pd->G();
    const auto IH = deconv_pd->IH();
    const auto IW = deconv_pd->IW();
    const auto ID = deconv_pd->ID();
    const auto pad_front = deconv_pd->padFront();
    const auto pad_top = deconv_pd->padT();
    const auto pad_left = deconv_pd->padL();
    const bool with_groups = deconv_pd->with_groups();
    const memory_desc_wrapper wei_d(deconv_pd->weights_md());
    const auto get_wei_off
            = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
                  return get_weights_off(
                          wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
              };

    return [=](const dim_t g, const dim_t oc, const dim_t od, const dim_t oh,
                   const dim_t ow) {
        int32_t zp_pad_compensation = 0;

        for (dim_t kd = 0; kd < KD; ++kd) {
            const dim_t id = od - kd * KDD + pad_front;
            const bool should_apply_pad_comp_d
                    = id < 0 || id % KSD != 0 || (id / KSD) >= ID;

            for (dim_t kh = 0; kh < KH; ++kh) {
                const dim_t ih = oh - kh * KDH + pad_top;
                const bool should_apply_pad_comp_h
                        = ih < 0 || ih % KSH != 0 || (ih / KSH) >= IH;

                for (dim_t kw = 0; kw < KW; ++kw) {
                    const dim_t iw = ow - kw * KDW + pad_left;
                    const bool should_apply_pad_comp_w
                            = iw < 0 || iw % KSW != 0 || (iw / KSW) >= IW;

                    if (should_apply_pad_comp_d || should_apply_pad_comp_h
                            || should_apply_pad_comp_w) {

                        for (dim_t ic = 0; ic < IC; ic++) {
                            const auto wei_off
                                    = get_wei_off(g, oc, ic, kd, kh, kw);
                            const int32_t wei32
                                    = static_cast<int32_t>(wei[wei_off]);

                            if (is_src_zp_common)
                                zp_pad_compensation += wei32;
                            else
                                zp_pad_compensation
                                        += wei32 * src_zero_point[g * IC + ic];
                        }
                    }
                }
            }
        }

        if (is_src_zp_common && zp_pad_compensation)
            zp_pad_compensation *= src_zero_point[0];

        return zp_pad_compensation;
    };
}

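// Removes the source zero-point contribution from the f32 convolution output
// in place: the full per-channel compensation is subtracted, then the padding
// correction for the given output point is added back.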
template <data_type_t wei_type>
static status_t apply_src_zero_point(const exec_ctx_t &ctx,
        const cpu_deconvolution_fwd_pd_t *deconv_pd, float *conv_output) {
    using wei_data_t = typename prec_traits<wei_type>::type;
    using namespace memory_tracking::names;
    using namespace data_type;

    // required by the DEFINE_ZERO_POINTS_BUFFER macro
    const auto pd = [&]() { return deconv_pd; };
    // Weights are an input of the primitive; read-only access is sufficient.
    const auto wei = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    const bool is_src_zp_common
            = deconv_pd->attr()->zero_points_.common(DNNL_ARG_SRC);

    const auto scratchpad = ctx.get_scratchpad_grantor();
    const int32_t *const zp_src_compensation
            = scratchpad.get<int32_t>(key_deconv_zp);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const auto ndims = dst_d.ndims();

    const auto G = pd()->G();
    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC() / G;

    compute_src_zp_compensation<wei_type>(
            ctx, src_zero_point, is_src_zp_common, wei, deconv_pd);
    const auto zp_pad_comp_ker = prepare_zp_pad_comp_ker<wei_type>(
            ndims, src_zero_point, is_src_zp_common, wei, deconv_pd);

    parallel_nd(MB, G, OC, OD, OH, OW,
            [&](const dim_t mb, const dim_t g, const dim_t oc, const dim_t od,
                    const dim_t oh, const dim_t ow) {
                const auto oc_off = g * OC + oc;
                const auto dst_off = ref_conv_utils::get_data_off(
                        dst_d, ndims, mb, oc_off, od, oh, ow);
                int32_t conv_result
                        = conv_output[dst_off] - zp_src_compensation[oc_off];

                if (const auto zp_pad_compensation
                        = zp_pad_comp_ker(g, oc, od, oh, ow)) {
                    conv_result += zp_pad_compensation;
                }

                conv_output[dst_off] = static_cast<float>(conv_result);
            });

    return status::success;
}

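// Forward deconvolution is computed as the backward-data pass of an internal
// convolution: the deconvolution src plays the role of the convolution's
// diff_dst, and the convolution's diff_src is the deconvolution dst. When the
// bias cannot be folded into the convolution or non-default attributes are
// present, the convolution writes into an f32 scratchpad buffer and the
// remaining steps (source zero points, src/wei scales, bias, post-ops,
// destination scale and zero point) are applied here in that order.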
status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;
    const auto scratchpad = ctx.get_scratchpad_grantor();
    const bool ref_bias = pd()->with_bias() && !pd()->conv_supports_bias_;
    const bool non_default_attr = !pd()->attr()->has_default_values();

    const auto &args = ctx.args();
    exec_args_t conv_args;
    conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
    conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
    if (pd()->with_bias() && pd()->conv_supports_bias_)
        conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

    // Create intermediate memory for the f32 output if needed.
    auto dst = args.at(DNNL_ARG_DST);
    memory_t tmp_memory(dst.mem->engine(), pd()->conv_pd_->diff_src_md(),
            scratchpad.get_memory_storage(key_deconv_bias));
    memory_arg_t tmp_conv_output = {&tmp_memory, false};

    conv_args[DNNL_ARG_DIFF_SRC]
            = ref_bias || non_default_attr ? tmp_conv_output : dst;

    // When a sum post-op is present, save a copy of the original destination
    // memory before the internal convolution is called.
    if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1) {
        void *original_dst = scratchpad.get<void>(key_deconv_sum);
        const memory_desc_wrapper dst_d(pd()->dst_md());
        void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
        const auto dt_size = dst_d.data_type_size();

        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start {0}, end {0};
            balance211(dst_d.nelems(true), nthr, ithr, start, end);
            auto o_dst_start = (char *)original_dst + start * dt_size;
            auto dst_start = (char *)dst + start * dt_size;
            const auto size = (end - start) * dt_size;

            std::memcpy(o_dst_start, dst_start, size);
        });
    }

    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    auto status = conv_p_->execute(conv_ctx);
    if (status != status::success) return status;

    using namespace data_type;

    if (!pd()->attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
        float *conv_output = scratchpad.get<float>(key_deconv_bias);
        const auto wei_dt = pd()->weights_md()->data_type;
        switch (wei_dt) {
            case s8: apply_src_zero_point<s8>(ctx, pd(), conv_output); break;
            case u8: apply_src_zero_point<u8>(ctx, pd(), conv_output); break;
            default: assert(!"unsupported data type");
        }
    }

    float *conv_output = scratchpad.get<float>(key_deconv_bias);

    const auto &arg_scales = pd()->attr()->scales_;
    const auto &src_scales = arg_scales.get(DNNL_ARG_SRC);
    const auto &wei_scales = arg_scales.get(DNNL_ARG_WEIGHTS);

    if (!src_scales.has_default_values() || !wei_scales.has_default_values()) {
        compute_oscale(ctx, conv_output);
    }

    if (ref_bias) {
        void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
        void *tmp_output = non_default_attr ? conv_output : dst;
        compute_fwd_bias(ctx, tmp_output, conv_output, non_default_attr);
    }

    if (non_default_attr) {
        void *original_dst = scratchpad.get<void>(key_deconv_sum);
        compute_ref_attrs(ctx, conv_output, original_dst);
    }

    return status::success;
}

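// Backward data of a deconvolution is a plain forward convolution with src
// and dst swapped, so the call is forwarded to the internal convolution
// primitive unchanged.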
status_t ref_deconvolution_bwd_data_t::execute(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;
    const auto &args = ctx.args();
    exec_args_t conv_args;
    conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
    conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
    conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    return conv_p_->execute(conv_ctx);
}

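// diff_bias[oc] is the reduction of diff_dst over the minibatch and spatial
// dimensions. As with the forward bias, a generic offset-based version is
// kept alongside specialized reductions for plain and blocked layouts.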
void ref_deconvolution_bwd_weights_t::compute_bwd_bias(
        float *diff_bias, const float *diff_dst) const {
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const auto G = pd()->G();
    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OC = pd()->OC() / G;
    const auto OD = pd()->OD();
    const auto ndims = pd()->desc()->src_desc.ndims;

    parallel_nd(G, OC, [&](dim_t g, dim_t oc) {
        float db = 0;
        for_(dim_t mb = 0; mb < MB; ++mb)
        for_(dim_t od = 0; od < OD; ++od)
        for_(dim_t oh = 0; oh < OH; ++oh)
        for (dim_t ow = 0; ow < OW; ++ow) {
            const auto d_dst_off = ref_conv_utils::get_data_off(
                    diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
            db += diff_dst[d_dst_off];
        }
        diff_bias[g * OC + oc] = db;
    });
}

template <data_type_t dbia_type, data_type_t ddst_type>
void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw(
        typename prec_traits<dbia_type>::type *diff_bias,
        const typename prec_traits<ddst_type>::type *diff_dst) const {
    const auto OC = pd()->OC();
    const auto MB = pd()->MB();
    const auto SP = pd()->OH() * pd()->OW() * pd()->OD();

    parallel_nd(OC, [&](dim_t oc) {
        float db = 0;
        for (dim_t mb = 0; mb < MB; ++mb) {
            PRAGMA_OMP_SIMD(reduction(+ : db))
            for (dim_t sp = 0; sp < SP; ++sp) {
                auto offset = (size_t)(mb * OC + oc) * SP + sp;
                db += diff_dst[offset];
            }
        }
        diff_bias[oc] = static_cast<typename prec_traits<dbia_type>::type>(db);
    });
}

template <data_type_t dbia_type, data_type_t ddst_type>
void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ndhwc(
        typename prec_traits<dbia_type>::type *diff_bias,
        const typename prec_traits<ddst_type>::type *diff_dst) const {
    const auto MB = pd()->MB();
    const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
    const auto OC = pd()->OC();

    parallel_nd(OC, [&](dim_t oc) {
        float db = 0;
        for (dim_t mb = 0; mb < MB; ++mb) {
            PRAGMA_OMP_SIMD(reduction(+ : db))
            for (dim_t sp = 0; sp < SP; ++sp) {
                const dim_t offset = (mb * SP + sp) * OC + oc;
                db += diff_dst[offset];
            }
        }
        diff_bias[oc] = static_cast<typename prec_traits<dbia_type>::type>(db);
    });
}

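// Blocked-layout reduction: a block-sized accumulator array lets the inner
// loop over the channel block vectorize; only the valid tail lanes are
// written out for the last (possibly partial) block.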
template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc(
        typename prec_traits<dbia_type>::type *diff_bias,
        const typename prec_traits<ddst_type>::type *diff_dst) const {
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const auto OC = pd()->OC();
    const auto MB = pd()->MB();
    const auto SP = pd()->OH() * pd()->OW() * pd()->OD();

    const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0];

    parallel_nd(utils::div_up(OC, blksize), [&](dim_t ocb) {
        float db[blksize] = {0};

        for (dim_t mb = 0; mb < MB; ++mb) {
            for (dim_t sp = 0; sp < SP; ++sp) {
                auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;

                PRAGMA_OMP_SIMD()
                for (dim_t i = 0; i < blksize; ++i)
                    db[i] += diff_dst[offset + i];
            }
        }

        const dim_t blk = nstl::min(blksize, OC - ocb * blksize);

        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < blk; ++i)
            diff_bias[ocb * blksize + i] = db[i];
    });
}

template <data_type_t dbia_type, data_type_t ddst_type>
void ref_deconvolution_bwd_weights_t::compute_bias(
        const exec_ctx_t &ctx) const {
    using dbia_data_t = typename prec_traits<dbia_type>::type;
    using ddst_data_t = typename prec_traits<ddst_type>::type;

    auto diff_bias = CTX_OUT_MEM(dbia_data_t *, DNNL_ARG_DIFF_BIAS);
    auto diff_dst = CTX_IN_MEM(const ddst_data_t *, DNNL_ARG_DIFF_DST);

    using namespace format_tag;
    switch (pd()->dst_tag_) {
        case ncdhw:
        case nchw:
        case ncw:
            compute_bwd_bias_ncdhw<dbia_type, ddst_type>(diff_bias, diff_dst);
            break;
        case ndhwc:
        case nhwc:
        case nwc:
            compute_bwd_bias_ndhwc<dbia_type, ddst_type>(diff_bias, diff_dst);
            break;
        case nCdhw8c:
        case nChw8c:
        case nCw8c:
            assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
            compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 8>(
                    diff_bias, diff_dst);
            break;
        case nCdhw16c:
        case nChw16c:
        case nCw16c:
            compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 16>(
                    diff_bias, diff_dst);
            break;
        default:
            assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
            compute_bwd_bias((float *)diff_bias, (const float *)diff_dst);
            break;
    }
}

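// Backward weights of a deconvolution maps to convolution backward weights
// with the src and diff_dst arguments swapped. The bias reduction is done
// here rather than in the internal convolution, since the swapped convolution
// would reduce the wrong tensor for it.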
status_t ref_deconvolution_bwd_weights_t::execute(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;
    const auto &args = ctx.args();
    exec_args_t conv_args;
    conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
    conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
    conv_args[DNNL_ARG_DIFF_WEIGHTS] = args.at(DNNL_ARG_DIFF_WEIGHTS);
    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    status_t status = conv_p_->execute(conv_ctx);
    if (status != status::success) return status;

    if (pd()->with_bias()) {
        using namespace data_type;

        auto dbia_type = pd()->diff_weights_md(1)->data_type;
        auto ddst_type = pd()->diff_dst_md()->data_type;
        if (utils::everyone_is(f32, dbia_type, ddst_type))
            compute_bias<f32, f32>(ctx);
        else if (utils::everyone_is(bf16, dbia_type, ddst_type))
            compute_bias<bf16, bf16>(ctx);
        else if (dbia_type == f32 && ddst_type == bf16)
            compute_bias<f32, bf16>(ctx);
        else if (utils::everyone_is(f16, dbia_type, ddst_type))
            compute_bias<f16, f16>(ctx);
        else if (dbia_type == f32 && ddst_type == f16)
            compute_bias<f32, f16>(ctx);
        else {
            assert(!"unsupported data type");
            return status::runtime_error;
        }
    }
    return status::success;
}

using namespace data_type;

template void ref_deconvolution_bwd_weights_t::compute_bias<f32, f32>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, bf16>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<bf16, bf16>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, f16>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<f16, f16>(
        const exec_ctx_t &ctx) const;

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s