1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/x64/jit_brgemm_deconv.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | namespace { |
31 | status_t weights_axes_permutation( |
32 | memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) { |
33 | int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation |
34 | for (int d = 0; d < DNNL_MAX_NDIMS; ++d) |
35 | perm[d] = d; |
36 | nstl::swap(perm[0 + with_groups], perm[1 + with_groups]); |
37 | |
38 | return memory_desc_permute_axes(*o_md, *i_md, perm); |
39 | } |
40 | |
// Expresses a fwd deconvolution as a fwd convolution descriptor.
// Only valid for unit strides: the deconv padding is replaced by the
// "overflow" values a backward-data convolution would use, so that running
// the fwd conv (with spatially inverted weights) computes the deconv.
// Returns status::unimplemented for any non-unit stride.
status_t fwd_conv_desc_create(const deconvolution_desc_t *fwd_deconv_d,
        convolution_desc_t *fwd_conv_d) {
    const memory_desc_t &fwd_weights_md = fwd_deconv_d->weights_desc;
    // create a fwd convolution descriptor with padding adjusted
    // to the perspective of backward propagation, namely:
    // - left padding replaced by left overflow
    // - right padding replaced by right overflow
    const int ndims_spatial = fwd_deconv_d->dst_desc.ndims - 2;
    dims_t overflow_l;
    dims_t overflow_r;
    dim_t ks = 1; // product of spatial kernel dims; 1 means a 1x1 kernel
    for (int i = 0; i < ndims_spatial; i++) {
        // only unit strides are allowed for bwd-to-fwd conversion
        if (fwd_deconv_d->strides[i] != 1) return status::unimplemented;
        // spatial kernel dims are the trailing dims of the weights md
        const dim_t K
                = fwd_weights_md.dims[fwd_weights_md.ndims - ndims_spatial + i];
        ks *= K;
        const dim_t D = fwd_deconv_d->dilates[i];
        const dim_t PL = fwd_deconv_d->padding[0][i]; // left padding
        const dim_t PR = fwd_deconv_d->padding[1][i]; // right padding
        constexpr dim_t S = 1;
        // the following relations hold for unit stride only
        overflow_l[i] = ((K - 1) * (D + 1) - PL) / S;
        overflow_r[i] = ((K - 1) * (D + 1) - PR) / S;
    }

    CHECK(conv_desc_init(fwd_conv_d, prop_kind::forward_training,
            alg_kind::convolution_direct, &fwd_deconv_d->src_desc,
            &fwd_weights_md, &fwd_deconv_d->bias_desc, &fwd_deconv_d->dst_desc,
            fwd_deconv_d->strides, fwd_deconv_d->dilates, overflow_l,
            overflow_r));

    // HACK: Set diff_src_desc and diff_dst_desc as a signal to the primitive
    //       descriptor cache that we are using the bwd-via-fwd version of
    //       fwd conv and thus need a separate cache entry. Only needed for
    //       non-1x1 convs due to spatial inversion of weights. This assumes
    //       that external users only use the API to create conv descs, and
    //       relies on common/convolution.cpp only setting the expected mem descs.
    // TODO: Pass this information via attributes or integrate the bwd-via-fwd
    //       method directly into fwd conv implementations.
    const bool with_spatial_inversion = ks > 1;
    if (with_spatial_inversion) {
        fwd_conv_d->diff_src_desc = fwd_conv_d->src_desc;
        fwd_conv_d->diff_dst_desc = fwd_conv_d->dst_desc;
    }
    return status::success;
}
88 | |
89 | status_t bwd_conv_desc_create(const deconvolution_desc_t *fwd_deconv_d, |
90 | convolution_desc_t *bwd_conv_d) { |
91 | const memory_desc_t *src_md, *dst_md, *deconv_weights_d; |
92 | memory_desc_t src_md_patched; |
93 | const auto src_dt = fwd_deconv_d->dst_desc.data_type; |
94 | |
95 | memory_desc_init_by_md_and_dt( |
96 | src_md_patched, fwd_deconv_d->dst_desc, src_dt); |
97 | src_md = &src_md_patched; |
98 | dst_md = &fwd_deconv_d->src_desc; |
99 | deconv_weights_d = &fwd_deconv_d->weights_desc; |
100 | |
101 | /* create weights desc for convolution */ |
102 | memory_desc_t conv_weights_d; |
103 | const bool with_groups = deconv_weights_d->ndims == src_md->ndims + 1; |
104 | CHECK(weights_axes_permutation( |
105 | &conv_weights_d, deconv_weights_d, with_groups)); |
106 | |
107 | return conv_desc_init(bwd_conv_d, prop_kind::backward_data, |
108 | alg_kind::convolution_direct, src_md, &conv_weights_d, |
109 | &fwd_deconv_d->bias_desc, dst_md, fwd_deconv_d->strides, |
110 | fwd_deconv_d->dilates, fwd_deconv_d->padding[0], |
111 | fwd_deconv_d->padding[1]); |
112 | } |
113 | } // namespace |
114 | |
115 | template <cpu_isa_t isa> |
116 | status_t brgemm_deconvolution_fwd_t<isa>::pd_t::init(engine_t *engine) { |
117 | using namespace data_type; |
118 | using namespace utils; |
119 | using namespace format_tag; |
120 | using smask_t = primitive_attr_t::skip_mask_t; |
121 | const deconvolution_desc_t *fwd_deconv_d = desc(); |
122 | |
123 | const bool ok = is_fwd() |
124 | && (desc()->alg_kind & alg_kind::deconvolution_direct) |
125 | && IMPLICATION(fwd_deconv_d->src_desc.data_type == f16, |
126 | isa == avx512_core_amx_fp16) |
127 | && attr()->has_default_values(smask_t::oscale_runtime |
128 | | smask_t::post_ops | smask_t::zero_points_runtime) |
129 | && output_scales_mask_ok() && post_ops_ok() && zero_points_ok() |
130 | && !has_zero_dim_memory(); |
131 | if (!ok) return status::unimplemented; |
132 | |
133 | convolution_desc_t conv_d = convolution_desc_t(); |
134 | |
135 | assert(fwd_deconv_d->src_desc.data_type != data_type::undef); |
136 | |
137 | const int ndims_spatial = fwd_deconv_d->dst_desc.ndims - 2; |
138 | for (int i = 0; i < ndims_spatial; i++) { |
139 | if (fwd_deconv_d->strides[i] != 1) { |
140 | has_strides_ = true; |
141 | break; |
142 | } |
143 | } |
144 | |
145 | primitive_desc_t *pd; |
146 | |
147 | if (has_strides_) { |
148 | CHECK(bwd_conv_desc_create(fwd_deconv_d, &conv_d)); |
149 | // try creating bwd conv prim desc |
150 | constexpr bool enable_postops |
151 | = true; // postops are enabled only for deconv (used only in strided version) |
152 | using bwd_conv_str_pd_t = typename brgemm_convolution_bwd_strided_t<isa, |
153 | enable_postops>::pd_t; |
154 | CHECK(primitive_desc_t::create<bwd_conv_str_pd_t>(&pd, |
155 | reinterpret_cast<const op_desc_t *>(&conv_d), attr(), engine, |
156 | nullptr)); |
157 | } else { |
158 | CHECK(fwd_conv_desc_create(fwd_deconv_d, &conv_d)); |
159 | do { |
160 | // try creating fwd 1x1 conv prim desc |
161 | using fwd_1x1_conv_pd_t = |
162 | typename brgemm_1x1_convolution_fwd_t<isa>::pd_t; |
163 | status_t s = primitive_desc_t::create<fwd_1x1_conv_pd_t>(&pd, |
164 | reinterpret_cast<const op_desc_t *>(&conv_d), attr(), |
165 | engine, nullptr); |
166 | if (s == status::success) break; |
167 | // try creating fwd conv prim desc |
168 | constexpr bool use_inversion |
169 | = true; // invert weights' spatial indices |
170 | using fwd_conv_pd_t = |
171 | typename brgemm_convolution_fwd_t<isa, use_inversion>::pd_t; |
172 | CHECK(primitive_desc_t::create<fwd_conv_pd_t>(&pd, |
173 | reinterpret_cast<const op_desc_t *>(&conv_d), attr(), |
174 | engine, nullptr)); |
175 | } while (false); |
176 | } |
177 | conv_pd_.reset(pd); |
178 | |
179 | if (weights_md_.format_kind == format_kind::any) { |
180 | if (has_strides_) |
181 | CHECK(weights_axes_permutation( |
182 | &weights_md_, conv_pd_->weights_md(), with_groups())); |
183 | else |
184 | weights_md_ = *conv_pd_->weights_md(); |
185 | } |
186 | if (src_md_.format_kind == format_kind::any) { |
187 | if (has_strides_) |
188 | src_md_ = *conv_pd_->diff_dst_md(); |
189 | else |
190 | src_md_ = *conv_pd_->src_md(); |
191 | } |
192 | if (dst_md_.format_kind == format_kind::any) { |
193 | if (has_strides_) |
194 | dst_md_ = *conv_pd_->diff_src_md(); |
195 | else |
196 | dst_md_ = *conv_pd_->dst_md(); |
197 | } |
198 | attr_.set_default_formats(&dst_md_); |
199 | if (bias_md_.format_kind == format_kind::any) |
200 | CHECK(memory_desc_init_by_tag(bias_md_, x)); |
201 | |
202 | auto scratchpad = scratchpad_registry().registrar(); |
203 | scratchpad.book(memory_tracking::names::key_nested, |
204 | conv_pd_->scratchpad_registry()); |
205 | |
206 | return status::success; |
207 | } |
208 | |
209 | template <cpu_isa_t isa> |
210 | status_t brgemm_deconvolution_fwd_t<isa>::init(engine_t *engine) { |
211 | return pd()->conv_pd_->create_primitive(conv_p_, engine); |
212 | } |
213 | |
214 | template <cpu_isa_t isa> |
215 | status_t brgemm_deconvolution_fwd_t<isa>::execute(const exec_ctx_t &ctx) const { |
216 | const auto &args = ctx.args(); |
217 | exec_args_t conv_args(args); |
218 | if (pd()->has_strides_) { |
219 | conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST); |
220 | conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC); |
221 | conv_args.erase(DNNL_ARG_DST); |
222 | conv_args.erase(DNNL_ARG_SRC); |
223 | } |
224 | |
225 | exec_ctx_t conv_ctx(ctx, std::move(conv_args)); |
226 | |
227 | nested_scratchpad_t ns(ctx, memory_tracking::names::key_nested, conv_p_); |
228 | conv_ctx.set_scratchpad_grantor(ns.grantor()); |
229 | return conv_p_->execute(conv_ctx); |
230 | } |
231 | |
// Explicit instantiations for the supported ISAs (AMX int8/bf16 and AMX fp16).
template struct brgemm_deconvolution_fwd_t<avx512_core_amx>;
template struct brgemm_deconvolution_fwd_t<avx512_core_amx_fp16>;
234 | |
235 | } // namespace x64 |
236 | } // namespace cpu |
237 | } // namespace impl |
238 | } // namespace dnnl |
239 | |
240 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
241 | |