1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <functional> |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/dnnl_traits.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/cpu_primitive.hpp" |
25 | #include "cpu/ref_io_helper.hpp" |
26 | |
27 | #include "cpu/ref_convolution_utils.hpp" |
28 | #include "cpu/ref_deconvolution.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | void ref_deconvolution_fwd_t::compute_fwd_bias_common(const exec_ctx_t &ctx, |
35 | void *dst, const float *conv_output, bool non_default_attr) const { |
36 | const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
37 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
38 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
39 | |
40 | const auto G = pd()->G(); |
41 | const auto MB = pd()->MB(); |
42 | const auto OH = pd()->OH(); |
43 | const auto OW = pd()->OW(); |
44 | const auto OD = pd()->OD(); |
45 | const auto OC = pd()->OC() / G; |
46 | const auto ndims = pd()->desc()->src_desc.ndims; |
47 | |
48 | parallel_nd(MB, G, OC, OD, OH, OW, |
49 | [&](dim_t mb, dim_t g, dim_t oc, dim_t od, dim_t oh, dim_t ow) { |
50 | const dim_t c = g * OC + oc; |
51 | const dim_t off = ref_conv_utils::get_data_off( |
52 | dst_d, ndims, mb, c, od, oh, ow); |
53 | float b = io::load_float_value(bias_d.data_type(), bias, c); |
54 | float d = conv_output[off]; |
55 | // Use f32 if attributes happen after bias to get precise answer |
56 | auto dt = non_default_attr ? data_type::f32 : dst_d.data_type(); |
57 | io::store_float_value(dt, d + b, dst, off); |
58 | }); |
59 | } |
60 | |
61 | void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const exec_ctx_t &ctx, |
62 | void *dst, const float *conv_output, bool non_default_attr) const { |
63 | const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
64 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
65 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
66 | |
67 | const auto MB = pd()->MB(); |
68 | const auto OC = pd()->OC(); |
69 | const auto SP = pd()->OW() * pd()->OH() * pd()->OD(); |
70 | |
71 | parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) { |
72 | const dim_t off = (mb * OC + oc) * SP; |
73 | float b = io::load_float_value(bias_d.data_type(), bias, oc); |
74 | PRAGMA_OMP_SIMD() |
75 | for (dim_t sp = 0; sp < SP; ++sp) { |
76 | float d = conv_output[off + sp]; |
77 | // Use f32 if attributes happen after bias to get precise answer. |
78 | auto dt = non_default_attr ? data_type::f32 : dst_d.data_type(); |
79 | io::store_float_value(dt, d + b, dst, off + sp); |
80 | } |
81 | }); |
82 | } |
83 | |
84 | void ref_deconvolution_fwd_t::compute_fwd_bias_ndhwc(const exec_ctx_t &ctx, |
85 | void *dst, const float *conv_output, bool non_default_attr) const { |
86 | const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
87 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
88 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
89 | |
90 | const auto MB = pd()->MB(); |
91 | const auto OC = pd()->OC(); |
92 | const auto SP = pd()->OW() * pd()->OH() * pd()->OD(); |
93 | |
94 | parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) { |
95 | const dim_t off = (mb * SP + sp) * OC; |
96 | PRAGMA_OMP_SIMD() |
97 | for (dim_t oc = 0; oc < OC; ++oc) { |
98 | float b = io::load_float_value(bias_d.data_type(), bias, oc); |
99 | float d = conv_output[off + oc]; |
100 | // Use f32 if attributes happen after bias to get precise answer. |
101 | auto dt = non_default_attr ? data_type::f32 : dst_d.data_type(); |
102 | io::store_float_value(dt, d + b, dst, off + oc); |
103 | } |
104 | }); |
105 | } |
106 | |
107 | template <dim_t blk_size> |
108 | void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx, |
109 | void *dst, const float *conv_output, bool non_default_attr) const { |
110 | const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
111 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
112 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
113 | |
114 | const auto MB = pd()->MB(); |
115 | const auto OC = pd()->OC(); |
116 | const auto SP = pd()->OW() * pd()->OH() * pd()->OD(); |
117 | const auto stride_mb = dst_d.blocking_desc().strides[0]; |
118 | |
119 | parallel_nd(MB, utils::div_up(OC, blk_size), SP, |
120 | [&](dim_t mb, dim_t oc_blk, dim_t sp) { |
121 | const dim_t oc = oc_blk * blk_size; |
122 | const dim_t off = mb * stride_mb + oc * SP + sp * blk_size; |
123 | const dim_t blk = nstl::min(blk_size, OC - oc); |
124 | |
125 | PRAGMA_OMP_SIMD() |
126 | for (dim_t i = 0; i < blk_size; ++i) { |
127 | float b = i < blk ? io::load_float_value( |
128 | bias_d.data_type(), bias, oc + i) |
129 | : 0; |
130 | float d = conv_output[off + i]; |
131 | // Use f32 if attributes happen after bias to get precise |
132 | // answer. |
133 | auto dt = non_default_attr ? data_type::f32 |
134 | : dst_d.data_type(); |
135 | io::store_float_value(dt, d + b, dst, off + i); |
136 | } |
137 | }); |
138 | } |
139 | |
140 | void ref_deconvolution_fwd_t::compute_fwd_bias(const exec_ctx_t &ctx, void *dst, |
141 | const float *conv_output, bool non_default_attr) const { |
142 | using namespace format_tag; |
143 | switch (pd()->dst_tag_) { |
144 | case ncdhw: |
145 | case nchw: |
146 | case ncw: |
147 | compute_fwd_bias_ncdhw(ctx, dst, conv_output, non_default_attr); |
148 | break; |
149 | case ndhwc: |
150 | case nhwc: |
151 | case nwc: |
152 | compute_fwd_bias_ndhwc(ctx, dst, conv_output, non_default_attr); |
153 | break; |
154 | case nCdhw8c: |
155 | case nChw8c: |
156 | case nCw8c: |
157 | compute_fwd_bias_nCdhwXc<8>( |
158 | ctx, dst, conv_output, non_default_attr); |
159 | break; |
160 | case nCdhw16c: |
161 | case nChw16c: |
162 | case nCw16c: |
163 | compute_fwd_bias_nCdhwXc<16>( |
164 | ctx, dst, conv_output, non_default_attr); |
165 | break; |
166 | default: |
167 | compute_fwd_bias_common(ctx, dst, conv_output, non_default_attr); |
168 | break; |
169 | } |
170 | } |
171 | |
172 | status_t ref_deconvolution_fwd_t::compute_oscale( |
173 | const exec_ctx_t &ctx, float *dst) const { |
174 | |
175 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
176 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
177 | const int wei_scale_mask |
178 | = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
179 | |
180 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
181 | |
182 | const auto MB = pd()->MB(); |
183 | const auto OH = pd()->OH(); |
184 | const auto OW = pd()->OW(); |
185 | const auto OD = pd()->OD(); |
186 | const auto OC = pd()->OC(); |
187 | const auto OCP = dst_d.padded_dims()[1]; |
188 | const auto ndims = pd()->desc()->src_desc.ndims; |
189 | |
190 | const auto maybe_oscale = [](float &d, dim_t oc, const float *src_scales, |
191 | const float *wei_scales, int wei_mask) { |
192 | // scale_idx_mult = 1 for per_oc scales and 0, otherwise |
193 | const int wei_scale_idx_mult = wei_mask != 0; |
194 | d *= src_scales[0] * wei_scales[oc * wei_scale_idx_mult]; |
195 | }; |
196 | |
197 | parallel_nd(MB, OCP, OD, OH, OW, |
198 | [&](dim_t mb, int ocp, dim_t od, dim_t oh, dim_t ow) { |
199 | auto dst_off = ref_conv_utils::get_data_off( |
200 | dst_d, ndims, mb, ocp, od, oh, ow); |
201 | float tmp_result = 0; |
202 | |
203 | if (ocp < OC) { |
204 | tmp_result = dst[dst_off]; |
205 | maybe_oscale(tmp_result, ocp, src_scales, wei_scales, |
206 | wei_scale_mask); |
207 | dst[dst_off] = tmp_result; |
208 | } |
209 | }); |
210 | |
211 | return status_t::dnnl_success; |
212 | } |
213 | |
214 | status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx, |
215 | const float *conv_output, void *original_dst) const { |
216 | auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); |
217 | |
218 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
219 | const int dst_scale_mask = pd()->attr()->scales_.get(DNNL_ARG_DST).mask_; |
220 | |
221 | DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); |
222 | const bool is_dst_zp_common |
223 | = pd()->attr()->zero_points_.common(DNNL_ARG_DST); |
224 | |
225 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
226 | |
227 | const auto MB = pd()->MB(); |
228 | const auto OH = pd()->OH(); |
229 | const auto OW = pd()->OW(); |
230 | const auto OD = pd()->OD(); |
231 | const auto OC = pd()->OC(); |
232 | const auto OCP = dst_d.padded_dims()[1]; |
233 | const auto ndims = pd()->desc()->src_desc.ndims; |
234 | |
235 | const auto maybe_dst_zero_point = [=](float &result, dim_t oc) { |
236 | if (is_dst_zp_common) |
237 | result += dst_zero_point[0]; |
238 | else |
239 | result += dst_zero_point[oc]; |
240 | }; |
241 | |
242 | const auto maybe_scale |
243 | = [](float &d, dim_t oc, const float *scales, int mask) { |
244 | // scale_idx_mult = 1 for per_oc scales and 0, otherwise |
245 | const int scale_idx_mult = mask != 0; |
246 | d *= scales[oc * scale_idx_mult]; |
247 | }; |
248 | |
249 | parallel_nd(MB, OCP, OD, OH, OW, |
250 | [&](dim_t mb, int ocp, dim_t od, dim_t oh, dim_t ow) { |
251 | auto dst_off = ref_conv_utils::get_data_off( |
252 | dst_d, ndims, mb, ocp, od, oh, ow); |
253 | float tmp_result = 0; |
254 | |
255 | if (ocp < OC) { |
256 | dim_t dst_l_off = (mb * OC + ocp) * OD * OH * OW |
257 | + od * OH * OW + oh * OW + ow; |
258 | tmp_result = conv_output[dst_off]; |
259 | |
260 | ref_post_ops_t::args_t args; |
261 | if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1) |
262 | args.dst_val = io::load_float_value( |
263 | dst_d.data_type(), original_dst, dst_off); |
264 | args.ctx = &ctx; |
265 | args.l_offset = dst_l_off; |
266 | args.dst_md = pd()->dst_md(); |
267 | ref_post_ops->execute(tmp_result, args); |
268 | maybe_scale(tmp_result, ocp, dst_scales, dst_scale_mask); |
269 | maybe_dst_zero_point(tmp_result, ocp); |
270 | } |
271 | io::store_float_value( |
272 | dst_d.data_type(), tmp_result, dst, dst_off); |
273 | }); |
274 | |
275 | return status_t::dnnl_success; |
276 | } |
277 | |
278 | dim_t get_weights_off(const memory_desc_wrapper &wei_d, bool with_groups, |
279 | int ndims, dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) { |
280 | switch (ndims) { |
281 | case 5: |
282 | return with_groups ? wei_d.off(g, oc, ic, kd, kh, kw) |
283 | : wei_d.off(oc, ic, kd, kh, kw); |
284 | case 4: |
285 | return with_groups ? wei_d.off(g, oc, ic, kh, kw) |
286 | : wei_d.off(oc, ic, kh, kw); |
287 | case 3: |
288 | return with_groups ? wei_d.off(g, oc, ic, kw) |
289 | : wei_d.off(oc, ic, kw); |
290 | default: assert(!"unsupported ndims" ); return dim_t(0); |
291 | } |
292 | |
293 | return 0; |
294 | }; |
295 | |
296 | template <data_type_t wei_type> |
297 | static void compute_src_zp_compensation(const exec_ctx_t &ctx, |
298 | const int32_t *src_zero_point, const bool is_src_zp_common, |
299 | typename prec_traits<wei_type>::type *wei, |
300 | const cpu_deconvolution_fwd_pd_t *pd) { |
301 | using namespace memory_tracking::names; |
302 | |
303 | const auto scratchpad = ctx.get_scratchpad_grantor(); |
304 | int32_t *zp_compensation = scratchpad.get<int32_t>(key_deconv_zp); |
305 | const auto G = pd->G(); |
306 | const auto KH = pd->KH(); |
307 | const auto KW = pd->KW(); |
308 | const auto KD = pd->KD(); |
309 | const auto OC = pd->OC() / G; |
310 | const auto IC = pd->IC() / G; |
311 | const memory_desc_wrapper wei_d(pd->weights_md()); |
312 | const bool with_groups = pd->with_groups(); |
313 | const auto ndims = wei_d.ndims() - (with_groups ? 1 : 0); |
314 | const auto get_wei_off |
315 | = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) { |
316 | return get_weights_off( |
317 | wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw); |
318 | }; |
319 | |
320 | parallel_nd(G, OC, [&](const dim_t g, const dim_t oc) { |
321 | const auto out_offset = g * OC + oc; |
322 | int32_t acc = 0; |
323 | |
324 | for_(dim_t kd = 0; kd < KD; ++kd) |
325 | for_(dim_t kh = 0; kh < KH; ++kh) |
326 | for (dim_t kw = 0; kw < KW; ++kw) { |
327 | for (dim_t ic = 0; ic < IC; ++ic) { |
328 | const auto weights_offset = get_wei_off(g, oc, ic, kd, kh, kw); |
329 | const int32_t wei32 = static_cast<int32_t>(wei[weights_offset]); |
330 | |
331 | if (is_src_zp_common) |
332 | acc += wei32; |
333 | else |
334 | acc += wei32 * src_zero_point[g * IC + ic]; |
335 | } |
336 | } |
337 | |
338 | zp_compensation[out_offset] = acc * src_zero_point[0]; |
339 | }); |
340 | } |
341 | |
// Builds and returns a kernel that, for a given output point (g, oc, od,
// oh, ow), accumulates the zero-point compensation contributed by kernel
// taps that do NOT overlap a valid input element (i.e. taps that fall into
// padding or between strided input positions). The returned lambda captures
// all problem dimensions by value so it can outlive this function.
template <data_type_t wei_type>
static std::function<int32_t(
        const dim_t, const dim_t, const dim_t, const dim_t, const dim_t)>
prepare_zp_pad_comp_ker(const dim_t ndims, const int32_t *src_zero_point,
        const bool is_src_zp_common, typename prec_traits<wei_type>::type *wei,
        const cpu_deconvolution_fwd_pd_t *deconv_pd) {

    const auto KH = deconv_pd->KH();
    const auto KW = deconv_pd->KW();
    const auto KD = deconv_pd->KD();
    // Strides.
    const auto KSD = deconv_pd->KSD();
    const auto KSH = deconv_pd->KSH();
    const auto KSW = deconv_pd->KSW();
    // Effective dilation step (KDD() is zero-based; +1 turns it into the
    // distance between consecutive kernel taps).
    const auto KDD = deconv_pd->KDD() + 1;
    const auto KDH = deconv_pd->KDH() + 1;
    const auto KDW = deconv_pd->KDW() + 1;
    const auto IC = deconv_pd->IC() / deconv_pd->G();
    const auto IH = deconv_pd->IH();
    const auto IW = deconv_pd->IW();
    const auto ID = deconv_pd->ID();
    const auto pad_front = deconv_pd->padFront();
    const auto pad_top = deconv_pd->padT();
    const auto pad_left = deconv_pd->padL();
    const bool with_groups = deconv_pd->with_groups();
    const memory_desc_wrapper wei_d(deconv_pd->weights_md());
    const auto get_wei_off
            = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
                  return get_weights_off(
                          wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
              };

    return [=](const dim_t g, const dim_t oc, const dim_t od, const dim_t oh,
                   const dim_t ow) {
        int32_t zp_pad_compensation = 0;

        for (dim_t kd = 0; kd < KD; ++kd) {
            // Map the output coordinate back to the (deconv) input space; a
            // tap needs compensation when it lands outside the input, or
            // between strided positions (non-zero remainder).
            const dim_t id = od - kd * KDD + pad_front;
            const bool should_apply_pad_comp_d
                    = id < 0 || id % KSD != 0 || (id / KSD) >= ID;

            for (dim_t kh = 0; kh < KH; ++kh) {
                const dim_t ih = oh - kh * KDH + pad_top;
                const bool should_apply_pad_comp_h
                        = ih < 0 || ih % KSH != 0 || (ih / KSH) >= IH;

                for (dim_t kw = 0; kw < KW; ++kw) {
                    const dim_t iw = ow - kw * KDW + pad_left;
                    const bool should_apply_pad_comp_w
                            = iw < 0 || iw % KSW != 0 || (iw / KSW) >= IW;

                    if (should_apply_pad_comp_d || should_apply_pad_comp_h
                            || should_apply_pad_comp_w) {

                        // Sum the weights of this out-of-range tap over all
                        // input channels, weighted by the per-channel src
                        // zero-point unless it is common.
                        for (dim_t ic = 0; ic < IC; ic++) {
                            const auto wei_off
                                    = get_wei_off(g, oc, ic, kd, kh, kw);
                            const int32_t wei32
                                    = static_cast<int32_t>(wei[wei_off]);

                            if (is_src_zp_common)
                                zp_pad_compensation += wei32;
                            else
                                zp_pad_compensation
                                        += wei32 * src_zero_point[g * IC + ic];
                        }
                    }
                }
            }
        }

        // Common zero-point is factored out of the loops above and applied
        // once at the end.
        if (is_src_zp_common && zp_pad_compensation)
            zp_pad_compensation *= src_zero_point[0];

        return zp_pad_compensation;
    };
}
418 | |
// Adjusts the f32 convolution output in place for non-zero src zero-points:
// subtracts the precomputed per-channel compensation and adds back the
// padding compensation for kernel taps that never saw real input data.
template <data_type_t wei_type>
static status_t apply_src_zero_point(const exec_ctx_t &ctx,
        const cpu_deconvolution_fwd_pd_t *deconv_pd, float *conv_output) {
    using wei_data_t = typename prec_traits<wei_type>::type;
    using namespace memory_tracking::names;
    using namespace data_type;

    // required by DEFINE_ZERO_POINTS_BUFFER macro
    const auto pd = [&]() { return deconv_pd; };
    // NOTE(review): weights are an input to forward deconvolution, yet they
    // are fetched via CTX_OUT_MEM as a non-const pointer; CTX_IN_MEM with a
    // const pointer looks intended -- confirm (would also require the two
    // helpers above to take `const wei_data_t *`).
    const auto wei = CTX_OUT_MEM(wei_data_t *, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    const bool is_src_zp_common
            = deconv_pd->attr()->zero_points_.common(DNNL_ARG_SRC);

    const auto scratchpad = ctx.get_scratchpad_grantor();
    // Per-channel compensation filled by compute_src_zp_compensation().
    const int32_t *const zp_src_compensation
            = scratchpad.get<int32_t>(key_deconv_zp);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const auto ndims = dst_d.ndims();

    const auto G = pd()->G();
    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    // Per-group output channels.
    const auto OC = pd()->OC() / G;

    compute_src_zp_compensation<wei_type>(
            ctx, src_zero_point, is_src_zp_common, wei, deconv_pd);
    const auto zp_pad_comp_ker = prepare_zp_pad_comp_ker<wei_type>(
            ndims, src_zero_point, is_src_zp_common, wei, deconv_pd);

    parallel_nd(MB, G, OC, OD, OH, OW,
            [&](const dim_t mb, const dim_t g, const dim_t oc, const dim_t od,
                    const dim_t oh, const dim_t ow) {
                const auto oc_off = g * OC + oc;
                const auto dst_off = ref_conv_utils::get_data_off(
                        dst_d, ndims, mb, oc_off, od, oh, ow);
                // Subtract the full compensation, then re-add the part that
                // corresponds to padded/out-of-range kernel taps.
                int32_t conv_result
                        = conv_output[dst_off] - zp_src_compensation[oc_off];

                if (const auto zp_pad_compensation
                        = zp_pad_comp_ker(g, oc, od, oh, ow)) {
                    conv_result += zp_pad_compensation;
                }

                conv_output[dst_off] = static_cast<float>(conv_result);
            });

    return status::success;
}
470 | |
// Forward deconvolution implemented via a nested convolution backward-data
// primitive, followed (when needed) by reference zero-point handling, scale
// application, bias addition, and attribute (post-ops) processing.
status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;
    const auto scratchpad = ctx.get_scratchpad_grantor();
    // Bias must be added here when the nested conv cannot apply it itself.
    const bool ref_bias = pd()->with_bias() && !pd()->conv_supports_bias_;
    const bool non_default_attr = !pd()->attr()->has_default_values();

    const auto &args = ctx.args();
    exec_args_t conv_args;
    // Deconv fwd maps onto conv bwd_data: src plays diff_dst, dst plays
    // diff_src of the nested primitive.
    conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
    conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
    if (pd()->with_bias() && pd()->conv_supports_bias_)
        conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

    // Create intermediate memory for f32 output if needed.
    auto dst = args.at(DNNL_ARG_DST);
    memory_t tmp_memory(dst.mem->engine(), pd()->conv_pd_->diff_src_md(),
            scratchpad.get_memory_storage(key_deconv_bias));
    memory_arg_t tmp_conv_output = {&tmp_memory, false};

    // Redirect the conv output into the f32 scratchpad whenever extra
    // post-processing (bias or attributes) happens after the convolution.
    conv_args[DNNL_ARG_DIFF_SRC]
            = ref_bias || non_default_attr ? tmp_conv_output : dst;

    // When sum post-op happens, we need to copy original destination memory
    // prior call to external convolution happens.
    if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1) {
        void *original_dst = scratchpad.get(key_deconv_sum);
        const memory_desc_wrapper dst_d(pd()->dst_md());
        void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
        const auto dt_size = dst_d.data_type_size();

        // Parallel byte-wise copy of the whole (padded) dst buffer.
        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start {0}, end {0};
            balance211(dst_d.nelems(true), nthr, ithr, start, end);
            auto o_dst_start = (char *)original_dst + start * dt_size;
            auto dst_start = (char *)dst + start * dt_size;
            const auto size = (end - start) * dt_size;

            std::memcpy(o_dst_start, dst_start, size);
        });
    }

    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    auto status = conv_p_->execute(conv_ctx);
    if (status != status::success) return status;

    using namespace data_type;

    // Fold src zero-points into the f32 conv output (int8 weights only).
    if (!pd()->attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
        float *conv_output = scratchpad.get<float>(key_deconv_bias);
        const auto wei_dt = pd()->weights_md()->data_type;
        switch (wei_dt) {
            case s8: apply_src_zero_point<s8>(ctx, pd(), conv_output); break;
            case u8: apply_src_zero_point<u8>(ctx, pd(), conv_output); break;
            default: assert(!"unsupported data type" );
        }
    }

    float *conv_output = scratchpad.get<float>(key_deconv_bias);

    const auto &arg_scales = pd()->attr()->scales_;
    const auto &src_scales = arg_scales.get(DNNL_ARG_SRC);
    const auto &wei_scales = arg_scales.get(DNNL_ARG_WEIGHTS);

    // Apply src/wei scales before bias and the remaining attributes.
    if (!src_scales.has_default_values() || !wei_scales.has_default_values()) {
        compute_oscale(ctx, conv_output);
    }

    if (ref_bias) {
        void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
        // With non-default attributes the biased result stays in the f32
        // scratchpad; otherwise it goes straight to user dst.
        void *tmp_output = non_default_attr ? conv_output : dst;
        compute_fwd_bias(ctx, tmp_output, conv_output, non_default_attr);
    }

    if (non_default_attr) {
        // Finalize: post-ops, dst scales, dst zero-points -> user dst.
        void *original_dst = scratchpad.get<void>(key_deconv_sum);
        compute_ref_attrs(ctx, conv_output, original_dst);
    }

    return status::success;
}
554 | |
555 | status_t ref_deconvolution_bwd_data_t::execute(const exec_ctx_t &ctx) const { |
556 | using namespace memory_tracking::names; |
557 | const auto &args = ctx.args(); |
558 | exec_args_t conv_args; |
559 | conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST); |
560 | conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); |
561 | conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC); |
562 | exec_ctx_t conv_ctx(ctx, std::move(conv_args)); |
563 | |
564 | nested_scratchpad_t ns(ctx, key_nested, conv_p_); |
565 | conv_ctx.set_scratchpad_grantor(ns.grantor()); |
566 | conv_p_->execute(conv_ctx); |
567 | return status::success; |
568 | } |
569 | |
570 | void ref_deconvolution_bwd_weights_t::compute_bwd_bias( |
571 | float *diff_bias, const float *diff_dst) const { |
572 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
573 | |
574 | const auto G = pd()->G(); |
575 | const auto MB = pd()->MB(); |
576 | const auto OH = pd()->OH(); |
577 | const auto OW = pd()->OW(); |
578 | const auto OC = pd()->OC() / G; |
579 | const auto OD = pd()->OD(); |
580 | const auto ndims = pd()->desc()->src_desc.ndims; |
581 | |
582 | parallel_nd(G, OC, [&](dim_t g, dim_t oc) { |
583 | float db = 0; |
584 | for_(dim_t mb = 0; mb < MB; ++mb) |
585 | for_(dim_t od = 0; od < OD; ++od) |
586 | for_(dim_t oh = 0; oh < OH; ++oh) |
587 | for (dim_t ow = 0; ow < OW; ++ow) { |
588 | const auto d_dst_off = ref_conv_utils::get_data_off( |
589 | diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow); |
590 | db += diff_dst[d_dst_off]; |
591 | } |
592 | diff_bias[g * OC + oc] = db; |
593 | }); |
594 | } |
595 | |
596 | template <data_type_t dbia_type, data_type_t ddst_type> |
597 | void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw( |
598 | typename prec_traits<dbia_type>::type *diff_bias, |
599 | const typename prec_traits<ddst_type>::type *diff_dst) const { |
600 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
601 | |
602 | const auto OC = pd()->OC(); |
603 | const auto MB = pd()->MB(); |
604 | const auto SP = pd()->OH() * pd()->OW() * pd()->OD(); |
605 | |
606 | parallel_nd(OC, [&](dim_t oc) { |
607 | float db = 0; |
608 | for (dim_t mb = 0; mb < MB; ++mb) { |
609 | PRAGMA_OMP_SIMD(reduction(+ : db)) |
610 | for (dim_t sp = 0; sp < SP; ++sp) { |
611 | auto offset = (size_t)(mb * OC + oc) * SP + sp; |
612 | db += diff_dst[offset]; |
613 | } |
614 | } |
615 | diff_bias[oc] = db; |
616 | }); |
617 | } |
618 | |
619 | template <data_type_t dbia_type, data_type_t ddst_type> |
620 | void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ndhwc( |
621 | typename prec_traits<dbia_type>::type *diff_bias, |
622 | const typename prec_traits<ddst_type>::type *diff_dst) const { |
623 | const auto MB = pd()->MB(); |
624 | const auto SP = pd()->OW() * pd()->OH() * pd()->OD(); |
625 | const auto OC = pd()->OC(); |
626 | |
627 | parallel_nd(OC, [&](dim_t oc) { |
628 | float db = 0; |
629 | for (dim_t mb = 0; mb < MB; ++mb) { |
630 | PRAGMA_OMP_SIMD(reduction(+ : db)) |
631 | for (dim_t sp = 0; sp < SP; ++sp) { |
632 | const dim_t offset = (mb * SP + sp) * OC + oc; |
633 | db += diff_dst[offset]; |
634 | } |
635 | } |
636 | diff_bias[oc] = static_cast<typename prec_traits<dbia_type>::type>(db); |
637 | }); |
638 | } |
639 | |
// Reduces diff_dst into diff_bias for blocked (nCdhwXc) layouts. A whole
// channel block is accumulated at once into a local f32 array so the inner
// loop vectorizes; only the valid (non-padded) lanes of a tail block are
// written out.
template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc(
        typename prec_traits<dbia_type>::type *diff_bias,
        const typename prec_traits<ddst_type>::type *diff_dst) const {
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const auto OC = pd()->OC();
    const auto MB = pd()->MB();
    const auto SP = pd()->OH() * pd()->OW() * pd()->OD();

    const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0];

    parallel_nd(utils::div_up(OC, blksize), [&](dim_t ocb) {
        // Per-lane accumulators for one channel block.
        float db[blksize] = {0};

        for (dim_t mb = 0; mb < MB; ++mb) {
            for (dim_t sp = 0; sp < SP; ++sp) {
                // Offset of the first lane of this block at (mb, sp).
                auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;

                PRAGMA_OMP_SIMD()
                for (dim_t i = 0; i < blksize; ++i)
                    db[i] += diff_dst[offset + i];
            }
        }

        // Lanes mapping to real channels (tail block may be partial).
        const dim_t blk = nstl::min(blksize, OC - ocb * blksize);

        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < blk; ++i)
            diff_bias[ocb * blksize + i] = db[i];
    });
}
672 | |
673 | template <data_type_t dbia_type, data_type_t ddst_type> |
674 | void ref_deconvolution_bwd_weights_t::compute_bias( |
675 | const exec_ctx_t &ctx) const { |
676 | using dbia_data_t = typename prec_traits<dbia_type>::type; |
677 | using ddst_data_t = typename prec_traits<ddst_type>::type; |
678 | |
679 | auto diff_bias = CTX_OUT_MEM(dbia_data_t *, DNNL_ARG_DIFF_BIAS); |
680 | auto diff_dst = CTX_IN_MEM(const ddst_data_t *, DNNL_ARG_DIFF_DST); |
681 | |
682 | using namespace format_tag; |
683 | switch (pd()->dst_tag_) { |
684 | case ncdhw: |
685 | case nchw: |
686 | case ncw: |
687 | compute_bwd_bias_ncdhw<dbia_type, ddst_type>(diff_bias, diff_dst); |
688 | break; |
689 | case ndhwc: |
690 | case nhwc: |
691 | case nwc: |
692 | compute_bwd_bias_ndhwc<dbia_type, ddst_type>(diff_bias, diff_dst); |
693 | break; |
694 | case nCdhw8c: |
695 | case nChw8c: |
696 | case nCw8c: |
697 | assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type)); |
698 | compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 8>( |
699 | diff_bias, diff_dst); |
700 | break; |
701 | case nCdhw16c: |
702 | case nChw16c: |
703 | case nCw16c: |
704 | compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 16>( |
705 | diff_bias, diff_dst); |
706 | break; |
707 | default: |
708 | assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type)); |
709 | compute_bwd_bias((float *)diff_bias, (const float *)diff_dst); |
710 | break; |
711 | } |
712 | } |
713 | |
// Deconvolution bwd_weights maps onto convolution bwd_weights with src and
// diff_dst swapped. The bias gradient is computed here (reference code)
// after the nested primitive finishes, dispatched on the data type pair.
status_t ref_deconvolution_bwd_weights_t::execute(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;
    const auto &args = ctx.args();
    exec_args_t conv_args;
    conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
    conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
    conv_args[DNNL_ARG_DIFF_WEIGHTS] = args.at(DNNL_ARG_DIFF_WEIGHTS);
    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    status_t status = conv_p_->execute(conv_ctx);
    if (status != status::success) return status;

    if (pd()->with_bias()) {
        using namespace data_type;

        // Dispatch on (diff_bias, diff_dst) data types; the supported pairs
        // match the explicit instantiations at the end of this file.
        auto dbia_type = pd()->diff_weights_md(1)->data_type;
        auto ddst_type = pd()->diff_dst_md()->data_type;
        if (utils::everyone_is(f32, dbia_type, ddst_type))
            compute_bias<f32, f32>(ctx);
        else if (utils::everyone_is(bf16, dbia_type, ddst_type))
            compute_bias<bf16, bf16>(ctx);
        else if (dbia_type == f32 && ddst_type == bf16)
            compute_bias<f32, bf16>(ctx);
        else if (utils::everyone_is(f16, dbia_type, ddst_type))
            compute_bias<f16, f16>(ctx);
        else if (dbia_type == f32 && ddst_type == f16)
            compute_bias<f32, f16>(ctx);
        else {
            assert(!"unsupported data type" );
            return status::runtime_error;
        }
    }
    return status::success;
}
750 | |
using namespace data_type;

// Explicit instantiations for every (diff_bias, diff_dst) data type pair
// dispatched from ref_deconvolution_bwd_weights_t::execute().
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, f32>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, bf16>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<bf16, bf16>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, f16>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<f16, f16>(
        const exec_ctx_t &ctx) const;
763 | } // namespace cpu |
764 | } // namespace impl |
765 | } // namespace dnnl |
766 | |
767 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
768 | |