/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_brgemm_deconv.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace {
status_t weights_axes_permutation(
        memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) {
    int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation
    for (int d = 0; d < DNNL_MAX_NDIMS; ++d)
        perm[d] = d;
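    // Deconvolution weights act as convolution weights with the input and
    // output channel axes transposed, so swap those two axes (shifted by one
    // position when a leading groups dimension is present).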
    nstl::swap(perm[0 + with_groups], perm[1 + with_groups]);

    return memory_desc_permute_axes(*o_md, *i_md, perm);
}

status_t fwd_conv_desc_create(const deconvolution_desc_t *fwd_deconv_d,
        convolution_desc_t *fwd_conv_d) {
    const memory_desc_t &fwd_weights_md = fwd_deconv_d->weights_desc;
    // create a fwd convolution descriptor with padding adjusted
    // to the perspective of backward propagation, namely:
    // - left padding replaced by left overflow
    // - right padding replaced by right overflow
    const int ndims_spatial = fwd_deconv_d->dst_desc.ndims - 2;
    dims_t overflow_l;
    dims_t overflow_r;
    dim_t ks = 1;
    for (int i = 0; i < ndims_spatial; i++) {
        // only unit strides are allowed for bwd-to-fwd conversion
        if (fwd_deconv_d->strides[i] != 1) return status::unimplemented;
        const dim_t K
                = fwd_weights_md.dims[fwd_weights_md.ndims - ndims_spatial + i];
        ks *= K;
        const dim_t D = fwd_deconv_d->dilates[i];
        const dim_t PL = fwd_deconv_d->padding[0][i]; // left padding
        const dim_t PR = fwd_deconv_d->padding[1][i]; // right padding
        constexpr dim_t S = 1;
        // the following relations hold for unit stride only
        overflow_l[i] = ((K - 1) * (D + 1) - PL) / S;
        overflow_r[i] = ((K - 1) * (D + 1) - PR) / S;
    }
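    // Illustrative example (hypothetical shapes): a 3x3 deconvolution with
    // unit stride, no dilation (D = 0) and symmetric padding PL = PR = 1
    // gives overflow_l = overflow_r = (3 - 1) * 1 - 1 = 1, i.e. a "same"
    // 3x3 forward convolution over spatially inverted weights.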

    CHECK(conv_desc_init(fwd_conv_d, prop_kind::forward_training,
            alg_kind::convolution_direct, &fwd_deconv_d->src_desc,
            &fwd_weights_md, &fwd_deconv_d->bias_desc, &fwd_deconv_d->dst_desc,
            fwd_deconv_d->strides, fwd_deconv_d->dilates, overflow_l,
            overflow_r));

    // HACK: Set diff_src_desc and diff_dst_desc as a signal to the primitive
    //       descriptor cache that we are using the bwd-via-fwd version of
    //       fwd conv and thus need a separate cache entry. Only needed for
    //       non-1x1 convs due to spatial inversion of weights. This assumes
    //       that external users only use the API to create conv descs, and
    //       relies on common/convolution.cpp only setting the expected mem descs.
    // TODO: Pass this information via attributes or integrate the bwd-via-fwd
    //       method directly into fwd conv implementations.
    const bool with_spatial_inversion = ks > 1;
    if (with_spatial_inversion) {
        fwd_conv_d->diff_src_desc = fwd_conv_d->src_desc;
        fwd_conv_d->diff_dst_desc = fwd_conv_d->dst_desc;
    }
    return status::success;
}

status_t bwd_conv_desc_create(const deconvolution_desc_t *fwd_deconv_d,
        convolution_desc_t *bwd_conv_d) {
    const memory_desc_t *src_md, *dst_md, *deconv_weights_d;
    memory_desc_t src_md_patched;
    const auto src_dt = fwd_deconv_d->dst_desc.data_type;

    CHECK(memory_desc_init_by_md_and_dt(
            src_md_patched, fwd_deconv_d->dst_desc, src_dt));
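    // A forward deconvolution is computed as the backward-data pass of a
    // convolution: the deconv dst plays the role of the conv diff_dst
    // ("src" here) and the deconv src the role of the conv diff_src ("dst").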
    src_md = &src_md_patched;
    dst_md = &fwd_deconv_d->src_desc;
    deconv_weights_d = &fwd_deconv_d->weights_desc;

    /* create weights desc for convolution */
    memory_desc_t conv_weights_d;
    // grouped weights carry one extra leading (groups) dimension
    const bool with_groups = deconv_weights_d->ndims == src_md->ndims + 1;
    CHECK(weights_axes_permutation(
            &conv_weights_d, deconv_weights_d, with_groups));

    return conv_desc_init(bwd_conv_d, prop_kind::backward_data,
            alg_kind::convolution_direct, src_md, &conv_weights_d,
            &fwd_deconv_d->bias_desc, dst_md, fwd_deconv_d->strides,
            fwd_deconv_d->dilates, fwd_deconv_d->padding[0],
            fwd_deconv_d->padding[1]);
}
} // namespace

template <cpu_isa_t isa>
status_t brgemm_deconvolution_fwd_t<isa>::pd_t::init(engine_t *engine) {
    using namespace data_type;
    using namespace utils;
    using namespace format_tag;
    using smask_t = primitive_attr_t::skip_mask_t;
    const deconvolution_desc_t *fwd_deconv_d = desc();

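    // Supported configuration: direct deconvolution with attributes limited
    // to runtime output scales, post-ops and runtime zero points; f16 data
    // additionally requires the AMX fp16 ISA.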
    const bool ok = is_fwd()
            && desc()->alg_kind == alg_kind::deconvolution_direct
            && IMPLICATION(fwd_deconv_d->src_desc.data_type == f16,
                    isa == avx512_core_amx_fp16)
            && attr()->has_default_values(smask_t::oscale_runtime
                    | smask_t::post_ops | smask_t::zero_points_runtime)
            && output_scales_mask_ok() && post_ops_ok() && zero_points_ok()
            && !has_zero_dim_memory();
    if (!ok) return status::unimplemented;

    convolution_desc_t conv_d = convolution_desc_t();

    assert(fwd_deconv_d->src_desc.data_type != data_type::undef);

    const int ndims_spatial = fwd_deconv_d->dst_desc.ndims - 2;
    for (int i = 0; i < ndims_spatial; i++) {
        if (fwd_deconv_d->strides[i] != 1) {
            has_strides_ = true;
            break;
        }
    }

    primitive_desc_t *pd = nullptr;

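    // Dispatch: a strided deconvolution cannot be expressed as a forward
    // convolution, so it is lowered to a backward-data convolution; the
    // unit-stride case runs as a forward convolution over spatially inverted
    // weights (1x1 kernels need no inversion).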
    if (has_strides_) {
        CHECK(bwd_conv_desc_create(fwd_deconv_d, &conv_d));
        // try creating bwd conv prim desc; post-ops are enabled here because
        // this strided path is reachable from deconv only
        constexpr bool enable_postops = true;
        using bwd_conv_str_pd_t = typename brgemm_convolution_bwd_strided_t<isa,
                enable_postops>::pd_t;
        CHECK(primitive_desc_t::create<bwd_conv_str_pd_t>(&pd,
                reinterpret_cast<const op_desc_t *>(&conv_d), attr(), engine,
                nullptr));
    } else {
        CHECK(fwd_conv_desc_create(fwd_deconv_d, &conv_d));
        do {
            // try creating fwd 1x1 conv prim desc first
            using fwd_1x1_conv_pd_t =
                    typename brgemm_1x1_convolution_fwd_t<isa>::pd_t;
            status_t s = primitive_desc_t::create<fwd_1x1_conv_pd_t>(&pd,
                    reinterpret_cast<const op_desc_t *>(&conv_d), attr(),
                    engine, nullptr);
            if (s == status::success) break;
            // otherwise fall back to the generic fwd conv with the weights'
            // spatial indices inverted
            constexpr bool use_inversion = true;
            using fwd_conv_pd_t =
                    typename brgemm_convolution_fwd_t<isa, use_inversion>::pd_t;
            CHECK(primitive_desc_t::create<fwd_conv_pd_t>(&pd,
                    reinterpret_cast<const op_desc_t *>(&conv_d), attr(),
                    engine, nullptr));
        } while (false);
    }
    conv_pd_.reset(pd);

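    // Resolve memory descriptors left as format_kind::any from the nested
    // conv pd; the strided path queries the diff_* descriptors of the bwd
    // conv instead.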
    if (weights_md_.format_kind == format_kind::any) {
        if (has_strides_)
            CHECK(weights_axes_permutation(
                    &weights_md_, conv_pd_->weights_md(), with_groups()));
        else
            weights_md_ = *conv_pd_->weights_md();
    }
    if (src_md_.format_kind == format_kind::any) {
        if (has_strides_)
            src_md_ = *conv_pd_->diff_dst_md();
        else
            src_md_ = *conv_pd_->src_md();
    }
    if (dst_md_.format_kind == format_kind::any) {
        if (has_strides_)
            dst_md_ = *conv_pd_->diff_src_md();
        else
            dst_md_ = *conv_pd_->dst_md();
    }
    CHECK(attr_.set_default_formats(&dst_md_));
    if (bias_md_.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(bias_md_, x));

    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(memory_tracking::names::key_nested,
            conv_pd_->scratchpad_registry());

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_deconvolution_fwd_t<isa>::init(engine_t *engine) {
    return pd()->conv_pd_->create_primitive(conv_p_, engine);
}

template <cpu_isa_t isa>
status_t brgemm_deconvolution_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
    const auto &args = ctx.args();
    exec_args_t conv_args(args);
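    // For the strided path the nested primitive is a bwd-data convolution, so
    // remap the deconv arguments to their conv counterparts:
    // DST -> DIFF_SRC and SRC -> DIFF_DST.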
    if (pd()->has_strides_) {
        conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST);
        conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
        conv_args.erase(DNNL_ARG_DST);
        conv_args.erase(DNNL_ARG_SRC);
    }

    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, memory_tracking::names::key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    return conv_p_->execute(conv_ctx);
}

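// Explicit instantiations: this implementation targets the AMX ISAs only.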
template struct brgemm_deconvolution_fwd_t<avx512_core_amx>;
template struct brgemm_deconvolution_fwd_t<avx512_core_amx_fp16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s