1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3* Copyright 2022 Arm Ltd. and affiliates
4*
5* Licensed under the Apache License, Version 2.0 (the "License");
6* you may not use this file except in compliance with the License.
7* You may obtain a copy of the License at
8*
9* http://www.apache.org/licenses/LICENSE-2.0
10*
11* Unless required by applicable law or agreed to in writing, software
12* distributed under the License is distributed on an "AS IS" BASIS,
13* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14* See the License for the specific language governing permissions and
15* limitations under the License.
16*******************************************************************************/
17
18#ifndef CPU_REF_DECONVOLUTION_HPP
19#define CPU_REF_DECONVOLUTION_HPP
20
21#include <assert.h>
22#include <string.h>
23
24#include "common/c_types_map.hpp"
25#include "common/primitive.hpp"
26#include "common/primitive_desc_iterator.hpp"
27#include "common/stream.hpp"
28#include "common/type_helpers.hpp"
29#include "common/utils.hpp"
30
31#include "cpu/primitive_attr_postops.hpp"
32
33#include "cpu/cpu_convolution_pd.hpp"
34#include "cpu/cpu_deconvolution_pd.hpp"
35
36namespace dnnl {
37namespace impl {
38namespace cpu {
39
40static status_t weights_axes_permutation(
41 memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) {
42 int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation
43 for (int d = 0; d < DNNL_MAX_NDIMS; ++d)
44 perm[d] = d;
45 nstl::swap(perm[0 + with_groups], perm[1 + with_groups]);
46
47 return memory_desc_permute_axes(*o_md, *i_md, perm);
48}
49
// Creates the convolution descriptor that implements the given
// deconvolution descriptor. Deconvolution is computed as the "transposed"
// convolution, so propagation kinds and tensor roles are exchanged:
//   deconv fwd    -> conv bwd_data    (deconv dst acts as conv diff_dst)
//   deconv bwd_d  -> conv fwd         (diff_dst -> src, diff_src -> dst)
//   deconv bwd_w  -> conv bwd_weights (src and diff_dst roles swapped)
// @param bias_md  optional bias descriptor forwarded to the conv desc.
// @param src_dt   forward only: data type to force on the conv diff_src
//                 (the deconv dst); must be undef for backward kinds.
static status_t conv_descr_create(const deconvolution_desc_t *dd,
        convolution_desc_t *cd, const memory_desc_t *bias_md = nullptr,
        data_type_t src_dt = data_type::undef) {
    using namespace prop_kind;
    // Deconv algorithms map 1:1 onto conv algorithms.
    alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct
            ? alg_kind::convolution_direct
            : alg_kind::convolution_winograd;

    const memory_desc_t *src_md, *dst_md, *d_weights_d;
    memory_desc_t src_md_patched;
    prop_kind_t prop_kind;

    if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) {
        prop_kind = backward_data;
        // Forward callers must specify the conv output data type (it may
        // be forced to f32 for bias/post-ops handling, see fwd pd).
        assert(src_dt != data_type::undef);
        memory_desc_init_by_md_and_dt(src_md_patched, dd->dst_desc, src_dt);
        src_md = &src_md_patched;
        dst_md = &dd->src_desc;
        d_weights_d = &dd->weights_desc;
    } else if (dd->prop_kind == backward_data) {
        assert(src_dt == data_type::undef);
        prop_kind = forward_training;
        src_md = &dd->diff_dst_desc;
        dst_md = &dd->diff_src_desc;
        d_weights_d = &dd->weights_desc;
    } else {
        // backward_weights: keep the prop kind, swap src/diff_dst roles.
        assert(src_dt == data_type::undef);
        prop_kind = dd->prop_kind;
        src_md = &dd->diff_dst_desc;
        dst_md = &dd->src_desc;
        d_weights_d = &dd->diff_weights_desc;
    }

    /* create weights desc for convolution */
    memory_desc_t c_weights_d;
    // A groups dimension, when present, prepends exactly one extra dim.
    const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
    CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups));

    return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
            bias_md, dst_md, dd->strides, dd->dilates, dd->padding[0],
            dd->padding[1]);
}
92
// Reference deconvolution, forward propagation.
//
// Implemented on top of a convolution backward-data primitive (see
// conv_descr_create()): the deconv dst plays the conv diff_dst role and
// the weights IC/OC axes are swapped. Bias and post-ops are either
// delegated to the underlying convolution (when it supports them) or
// applied by this impl over an intermediate f32 buffer after the
// convolution has run.
struct ref_deconvolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_fwd_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        // Copy requires a deep clone of the nested convolution pd.
        pd_t(const pd_t &other)
            : cpu_deconvolution_fwd_pd_t(other)
            , conv_pd_(other.conv_pd_->clone())
            , conv_supports_bias_(other.conv_supports_bias_)
            , dst_tag_(other.dst_tag_) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);

        // Finds a suitable bwd_data convolution impl and stores its pd in
        // conv_pd_. Only impls whose weights need no extra mangling
        // (extra.flags == 0) are accepted; returns status::unimplemented
        // when none qualifies.
        status_t init_convolution(engine_t *engine) {
            using namespace format_tag;
            using namespace data_type;

            // Create empty attributes for bwd_d conv to pick up the fastest
            // impl available and apply post-ops and/or bias update later in
            // this impl via simple loop.
            primitive_attr_t conv_attr;

            convolution_desc_t cd;
            // When no attributes were requested, try to find a bwd_d conv impl
            // which supports bias update in-place, if requested, in requested
            // dst_dt. If appropriate conv impl was not found, enforce f32
            // diff_src for conv for correct result. If attributes are
            // requested, enforce conv impl to return f32 output no matter what.
            if (attr()->has_default_values()) {
                CHECK(conv_descr_create(
                        desc(), &cd, weights_md(1), dst_md()->data_type));
                primitive_desc_iterator_t it(
                        engine, (op_desc_t *)&cd, &conv_attr, nullptr);
                if (!it.is_initialized()) return status::out_of_memory;

                while (++it != it.end()) {
                    conv_pd_ = *it;
                    if (with_bias()) {
                        // Skip impls that cannot update bias in-place.
                        conv_supports_bias_ = utils::downcast<
                                cpu_convolution_bwd_data_pd_t *>(conv_pd_.get())
                                                      ->support_bias();
                        if (!conv_supports_bias_) continue;
                    }
                    bool ok = conv_pd_->weights_md()->extra.flags == 0;
                    if (ok) return status::success;
                }
            }

            // Intermediate f32 buffer is supported only for given condition.
            if (!attr()->has_default_values() || with_bias()) {
                // Enforce f32 dt for diff src and work with f32 output for bias
                // update or post ops after conv execution.
                CHECK(conv_descr_create(desc(), &cd, nullptr, data_type::f32));
                primitive_desc_iterator_t it(
                        engine, (op_desc_t *)&cd, &conv_attr, nullptr);
                if (!it.is_initialized()) return status::out_of_memory;

                while (++it != it.end()) {
                    conv_pd_ = *it;
                    bool ok = conv_pd_->weights_md()->extra.flags == 0;
                    if (ok) return status::success;
                }
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using smask_t = primitive_attr_t::skip_mask_t;

            // Only direct/winograd deconv; attributes limited to scales,
            // post-ops and zero points, with per-attribute checks below.
            const bool ok = is_fwd()
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values(smask_t::scales_runtime
                            | smask_t::post_ops | smask_t::zero_points_runtime)
                    && scales_mask_ok() && post_ops_ok() && zero_points_ok();
            if (!ok) return status::unimplemented;

            CHECK(init_convolution(engine));

            // Resolve "any" memory formats from the chosen convolution;
            // note the swapped roles (deconv src <-> conv diff_dst, deconv
            // dst <-> conv diff_src) and the weights axes permutation.
            if (weights_md_.format_kind == format_kind::any)
                CHECK(weights_axes_permutation(
                        &weights_md_, conv_pd_->weights_md(), with_groups()));
            if (src_md_.format_kind == format_kind::any)
                src_md_ = *conv_pd_->diff_dst_md();
            if (dst_md_.format_kind == format_kind::any) {
                // re-apply dt manually since it could be changed due to bias
                const auto dst_dt = dst_md_.data_type;
                memory_desc_init_by_md_and_dt(
                        dst_md_, *conv_pd_->diff_src_md(), dst_dt);
            }
            if (bias_md_.format_kind == format_kind::any)
                CHECK(memory_desc_init_by_tag(bias_md_, x));

            // Classify dst layout (plain ncX, nXc, or 8c/16c blocked);
            // presumably used to pick a specialized bias kernel — see .cpp.
            dst_tag_ = memory_desc_matches_one_of_tag(dst_md_,
                    utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                    utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                    utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
                    utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));

            init_scratchpad();
            return attr_.set_default_formats(dst_md(0));
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // nested conv bwd_d pd
        bool conv_supports_bias_ = false; // conv updates bias in-place
        format_tag_t dst_tag_; // dst layout family (see init())

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(key_nested, conv_pd_->scratchpad_registry());

            // This scratchpad is required for intermediate f32 conv output
            // since original memory can be of smaller size and will cause
            // out of boundary access.
            if ((with_bias() && !conv_supports_bias_)
                    || !attr()->has_default_values()) {
                const memory_desc_wrapper diff_src_d(conv_pd_->diff_src_md());
                assert(diff_src_d.data_type_size() == sizeof(float));
                scratchpad.book(key_deconv_bias, diff_src_d.nelems(true),
                        diff_src_d.data_type_size());
            }
            // This scratchpad is required to stash original dst memory for sum
            // post-op. It will be overwritten by conv execution and will not
            // be available to get the correct result.
            const memory_desc_wrapper dst_d(dst_md());
            if (attr()->post_ops_.find(primitive_kind::sum) != -1)
                scratchpad.book(key_deconv_sum, dst_d.nelems(true),
                        dst_d.data_type_size());

            // int32 buffer, one entry per output channel across all groups,
            // for src zero-point handling (usage in .cpp).
            if (!attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
                scratchpad.book<int32_t>(key_deconv_zp, OC() * G());
            }
        }

        // Scales are accepted for src/weights/dst only. Weights scales may
        // be per-tensor (mask 0) or per-output-channel (mask includes the
        // groups bit when grouped); src/dst scales must be per-tensor.
        bool scales_mask_ok() const {
            using namespace data_type;
            const std::vector<int> supported_args
                    = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
            bool ok = attr()->scales_.has_default_values(supported_args);
            for (int arg : supported_args) {
                const auto &mask = attr()->scales_.get(arg).mask_;
                if (arg == DNNL_ARG_WEIGHTS)
                    ok = ok && (mask == 0 || mask == (1 << (int)with_groups()));
                else
                    ok = ok && (mask == 0);
            }
            return ok;
        }

        // Depthwise-fused convolution post-op is not supported here.
        bool post_ops_ok() const {
            return attr()->post_ops_.find(primitive_kind::convolution) == -1;
        }

        // Zero points: only allowed for int8 src, never for weights, and
        // src/dst masks must be 0 (common) or 1 << 1 (per dim-1 values).
        bool zero_points_ok() const {
            using namespace data_type;
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);

            return IMPLICATION(!utils::one_of(src_md()->data_type, s8, u8),
                           attr()->zero_points_.has_default_values())
                    && attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && (mask_src == 0 || mask_src == 1 << 1)
                    && (mask_dst == 0 || mask_dst == 1 << 1);
        }
    };

    ref_deconvolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(pd()->conv_pd_->create_primitive(conv_p_, engine));

        // Post-ops are applied by this impl (the nested conv was created
        // with empty attributes), hence the reference post-ops executor.
        ref_post_ops
                = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
        if (!ref_post_ops) return status::out_of_memory;
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    // Bias application on the f32 conv output; implementations are in
    // ref_deconvolution.cpp. *_common is the generic fallback; the other
    // variants are layout-specialized (plain, channels-last, Xc-blocked).
    void compute_fwd_bias_common(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    void compute_fwd_bias_ncdhw(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    void compute_fwd_bias_ndhwc(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    template <dim_t blk_size>
    void compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    // Output-scale handling (definition in .cpp).
    status_t compute_oscale(const exec_ctx_t &ctx, float *dst) const;

    void compute_fwd_bias(const exec_ctx_t &ctx, void *dst,
            const float *conv_output, bool non_default_attr) const;

    // Applies attributes (scales/zero points/post-ops) to the conv output;
    // original_dst is the stashed dst needed by the sum post-op.
    status_t compute_ref_attrs(const exec_ctx_t &ctx, const float *conv_output,
            void *original_dst) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_; // nested conv bwd_d primitive
    std::unique_ptr<ref_post_ops_t> ref_post_ops;
};
306
// Reference deconvolution, backward data propagation.
//
// Computed via a *forward* convolution (see conv_descr_create()): the
// deconv diff_dst serves as the conv src and the deconv diff_src as the
// conv dst, with the weights IC/OC axes swapped.
struct ref_deconvolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_bwd_data_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}

        // Copy requires a deep clone of the nested convolution pd.
        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_data_pd_t(other)
            , conv_pd_(other.conv_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);

        // Finds a fwd convolution impl whose weights need no extra
        // mangling (extra.flags == 0); attributes are forwarded as-is.
        status_t init_convolution(engine_t *engine) {
            using namespace types;

            convolution_desc_t cd;
            status_t status = conv_descr_create(desc(), &cd);
            if (status != status::success) return status;
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;

            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                conv_pd_ = *it;
                if (conv_pd_->weights_md()->extra.flags == 0)
                    return status::success;
            }

            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace data_type;
            auto dsrc_type = desc()->diff_src_desc.data_type;
            auto wei_type = desc()->weights_desc.data_type;
            auto ddst_type = desc()->diff_dst_desc.data_type;
            // Supported: f32/bf16/f16 weights with matching diff_dst;
            // diff_src may additionally be f32; no attributes.
            bool ok = true && desc()->prop_kind == prop_kind::backward_data
                    && utils::one_of(wei_type, f32, bf16, f16)
                    && ddst_type == wei_type
                    && utils::one_of(dsrc_type, wei_type, f32)
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution(engine));
                // Resolve "any" formats from the conv; roles are swapped
                // (deconv diff_src <-> conv dst, deconv diff_dst <-> conv
                // src) and the weights axes are permuted back.
                if (weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&weights_md_,
                            conv_pd_->weights_md(), with_groups()));
                if (diff_src_md_.format_kind == format_kind::any)
                    diff_src_md_ = *conv_pd_->dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                init_scratchpad();
                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // nested fwd conv pd

    private:
        // Only the nested convolution's scratchpad is needed.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    typedef typename prec_traits<data_type::f32>::type data_t;

    ref_deconvolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        return pd()->conv_pd_->create_primitive(conv_p_, engine);
    }

#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
    // Arm Compute Library builds: forward resource creation to the nested
    // convolution primitive.
    status_t create_resource(
            engine_t *engine, resource_mapper_t &mapper) const override {
        CHECK(conv_p_->create_resource(engine, mapper));
        return status::success;
    }
#endif

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_; // nested fwd conv primitive
};
404
// Reference deconvolution, backward weights propagation.
//
// Computed via a backward-weights convolution with the src/diff_dst roles
// swapped (see conv_descr_create()) and the produced diff weights axes
// permuted back (IC <-> OC). The bias gradient, when requested, is
// computed by this impl (compute_bias(), definitions in .cpp).
struct ref_deconvolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_bwd_weights_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}

        // Copy requires a deep clone of the nested convolution pd.
        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_weights_pd_t(other)
            , conv_pd_(other.conv_pd_->clone())
            , dst_tag_(other.dst_tag_) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t);

        // Finds a bwd_weights convolution impl with mangling-free diff
        // weights (extra.flags == 0); attributes are forwarded as-is.
        status_t init_convolution(engine_t *engine) {
            using namespace types;
            using namespace format_tag;

            convolution_desc_t cd;
            status_t status = conv_descr_create(desc(), &cd);
            if (status != status::success) return status;
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;

            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                conv_pd_ = *it;
                // When bias is requested with bf16 src, accept only conv
                // impls whose src (the deconv diff_dst) is in a plain,
                // channels-last, or 16c-blocked layout — the only layouts
                // the bias reduction here handles for bf16.
                bool bf16_ref_deconv_supports_bias = IMPLICATION(with_bias()
                                && desc()->src_desc.data_type
                                        == data_type::bf16,
                        memory_desc_matches_one_of_tag(*conv_pd_->src_md(),
                                utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                                utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                                utils::pick(ndims() - 3, nCw16c, nChw16c,
                                        nCdhw16c)));
                if (conv_pd_->diff_weights_md()->extra.flags == 0
                        && bf16_ref_deconv_supports_bias) {
                    return status::success;
                }
            }
            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using namespace data_type;
            auto src_type = desc()->src_desc.data_type;
            auto dwei_type = desc()->diff_weights_desc.data_type;
            auto ddst_type = desc()->diff_dst_desc.data_type;
            // Supported: f32/bf16/f16 src with matching diff_dst;
            // diff_weights may additionally be f32; no attributes.
            bool ok = true && desc()->prop_kind == prop_kind::backward_weights
                    && utils::one_of(src_type, f32, bf16, f16)
                    && ddst_type == src_type
                    && utils::one_of(dwei_type, src_type, f32)
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution(engine));
                // Resolve "any" formats from the conv; roles are swapped
                // (deconv src <-> conv diff_dst, deconv diff_dst <-> conv
                // src) and the diff weights axes are permuted back.
                if (diff_weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&diff_weights_md_,
                            conv_pd_->diff_weights_md(), with_groups()));
                if (src_md_.format_kind == format_kind::any)
                    src_md_ = *conv_pd_->diff_dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                if (diff_bias_md_.format_kind == format_kind::any)
                    CHECK(memory_desc_init_by_tag(diff_bias_md_, x));

                // Classify diff_dst layout (plain ncX, nXc, or 8c/16c
                // blocked); presumably used to pick a specialized bias
                // reduction kernel — see .cpp.
                dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_,
                        utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                        utils::pick(ndims() - 3, nwc, nhwc, ndhwc),
                        utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
                        utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
                init_scratchpad();
                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // nested conv bwd_w pd
        format_tag_t dst_tag_; // diff_dst layout family (see init())

    private:
        // Only the nested convolution's scratchpad is needed.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    ref_deconvolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        return pd()->conv_pd_->create_primitive(conv_p_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    // Bias-gradient helpers; implementations in ref_deconvolution.cpp.
    // The *_ncdhw/_ndhwc/_nCdhwXc variants are layout-specialized.
    void compute_bwd_bias(float *diff_bias, const float *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bwd_bias_ncdhw(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bwd_bias_ndhwc(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
    void compute_bwd_bias_nCdhwXc(
            typename prec_traits<dbia_type>::type *diff_bias,
            const typename prec_traits<ddst_type>::type *diff_dst) const;

    template <data_type_t dbia_type, data_type_t ddst_type>
    void compute_bias(const exec_ctx_t &ctx) const;
    std::shared_ptr<primitive_t> conv_p_; // nested conv bwd_w primitive
};
532
533} // namespace cpu
534} // namespace impl
535} // namespace dnnl
536
537#endif
538
539// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
540