1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_CONVOLUTION_PD_HPP
18#define COMMON_CONVOLUTION_PD_HPP
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include "c_types_map.hpp"
23#include "primitive_desc.hpp"
24#include "utils.hpp"
25
26namespace dnnl {
27namespace impl {
28
29status_t conv_desc_init(convolution_desc_t *conv_desc, prop_kind_t prop_kind,
30 alg_kind_t alg_kind, const memory_desc_t *src_desc,
31 const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
32 const memory_desc_t *dst_desc, const dims_t strides,
33 const dims_t dilates, const dims_t padding_l, const dims_t padding_r);
34
35memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc);
36memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc);
37memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc);
38memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc);
39const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc);
40const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc);
41const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc);
42const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc);
43
44struct convolution_fwd_pd_t;
45
46struct convolution_pd_t : public primitive_desc_t {
47 static constexpr auto base_pkind = primitive_kind::convolution;
48
49 const convolution_desc_t *desc() const { return &desc_; }
50 const op_desc_t *op_desc() const override {
51 return reinterpret_cast<const op_desc_t *>(this->desc());
52 }
53
54 status_t query(query_t what, int idx, void *result) const override {
55 switch (what) {
56 case query::prop_kind:
57 *(prop_kind_t *)result = desc()->prop_kind;
58 break;
59 case query::alg_kind:
60 *(alg_kind_t *)result = desc()->alg_kind;
61 break;
62 case query::strides:
63 *(const dims_t **)result = &desc()->strides;
64 break;
65 case query::dilations:
66 *(const dims_t **)result = &desc()->dilates;
67 break;
68 case query::padding_l:
69 *(const dims_t **)result = &desc()->padding[0];
70 break;
71 case query::padding_r:
72 *(const dims_t **)result = &desc()->padding[1];
73 break;
74 default: return primitive_desc_t::query(what, idx, result);
75 }
76 return status::success;
77 }
78
79 /* common conv aux functions */
80
81 dim_t MB() const { return invariant_src_md()->dims[0]; }
82
83 dim_t IC() const { return invariant_src_md()->dims[1]; }
84 dim_t OC() const { return invariant_dst_md()->dims[1]; }
85 dim_t G() const { return with_groups() ? invariant_wei_md()->dims[0] : 1; }
86
87 dim_t ID() const {
88 return ndims() >= 5 ? invariant_src_md()->dims[ndims() - 3] : 1;
89 }
90 dim_t IH() const {
91 return ndims() >= 4 ? invariant_src_md()->dims[ndims() - 2] : 1;
92 }
93 dim_t IW() const { return invariant_src_md()->dims[ndims() - 1]; }
94
95 dim_t OD() const {
96 return ndims() >= 5 ? invariant_dst_md()->dims[ndims() - 3] : 1;
97 }
98 dim_t OH() const {
99 return ndims() >= 4 ? invariant_dst_md()->dims[ndims() - 2] : 1;
100 }
101 dim_t OW() const { return invariant_dst_md()->dims[ndims() - 1]; }
102
103 dim_t KD() const {
104 return ndims() >= 5
105 ? invariant_wei_md()->dims[ndims() + with_groups() - 3]
106 : 1;
107 }
108 dim_t KH() const {
109 return ndims() >= 4
110 ? invariant_wei_md()->dims[ndims() + with_groups() - 2]
111 : 1;
112 }
113 dim_t KW() const {
114 return invariant_wei_md()->dims[ndims() + with_groups() - 1];
115 }
116
117 dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
118 dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
119 dim_t KSW() const { return desc_.strides[ndims() - 3]; }
120
121 dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
122 dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
123 dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
124
125 dim_t padFront() const {
126 return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0;
127 }
128 dim_t padBack() const {
129 return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0;
130 }
131 dim_t padT() const {
132 return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0;
133 }
134 dim_t padB() const {
135 return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0;
136 }
137 dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
138 dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
139
140 int ndims() const { return invariant_src_md()->ndims; }
141
142 bool with_bias() const {
143 auto *bia_d = desc()->prop_kind == prop_kind::backward_weights
144 ? &desc()->diff_bias_desc
145 : &desc()->bias_desc;
146 return !memory_desc_wrapper(bia_d).is_zero();
147 }
148 bool with_groups() const {
149 return invariant_wei_md()->ndims == ndims() + 1;
150 }
151
152 bool is_fwd() const {
153 return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
154 prop_kind::forward_inference);
155 }
156
157 bool is_bwd_d() const {
158 return desc_.prop_kind == prop_kind::backward_data;
159 }
160
161 bool is_bwd_w() const {
162 return desc_.prop_kind == prop_kind::backward_weights;
163 }
164
165 bool has_zero_dim_memory() const {
166 const auto s_d = memory_desc_wrapper(*invariant_src_md());
167 const auto d_d = memory_desc_wrapper(*invariant_dst_md());
168 return s_d.has_zero_dim() || d_d.has_zero_dim();
169 }
170
171 const memory_desc_t *invariant_src_md() const {
172 return desc()->prop_kind == prop_kind::backward_data ? diff_src_md()
173 : src_md();
174 }
175 const memory_desc_t *invariant_wei_md(int index = 0) const {
176 return desc()->prop_kind == prop_kind::backward_weights
177 ? diff_weights_md(index)
178 : weights_md(index);
179 }
180 const memory_desc_t *invariant_bia_md() const {
181 return invariant_wei_md(1);
182 }
183 const memory_desc_t *invariant_dst_md() const {
184 return is_fwd() ? dst_md() : diff_dst_md();
185 }
186 memory_desc_t *invariant_src_md() {
187 auto *const_this = (const convolution_pd_t *)this;
188 return const_cast<memory_desc_t *>(const_this->invariant_src_md());
189 }
190 memory_desc_t *invariant_wei_md(int index = 0) {
191 auto *const_this = (const convolution_pd_t *)this;
192 return const_cast<memory_desc_t *>(const_this->invariant_wei_md(index));
193 }
194 memory_desc_t *invariant_bia_md() {
195 auto *const_this = (const convolution_pd_t *)this;
196 return const_cast<memory_desc_t *>(const_this->invariant_bia_md());
197 }
198 memory_desc_t *invariant_dst_md() {
199 auto *const_this = (const convolution_pd_t *)this;
200 return const_cast<memory_desc_t *>(const_this->invariant_dst_md());
201 }
202
203protected:
204 convolution_desc_t desc_;
205 const convolution_fwd_pd_t *hint_fwd_pd_;
206
207 convolution_pd_t(const convolution_desc_t *adesc,
208 const primitive_attr_t *attr,
209 const convolution_fwd_pd_t *hint_fwd_pd)
210 : primitive_desc_t(attr, base_pkind)
211 , desc_(*adesc)
212 , hint_fwd_pd_(hint_fwd_pd) {}
213
214 bool set_default_formats_common_template(memory_desc_t &src_md,
215 format_tag_t src_tag, memory_desc_t &wei_md, format_tag_t wei_tag,
216 memory_desc_t &dst_md, format_tag_t dst_tag,
217 memory_desc_t &bia_md) {
218 using namespace format_tag;
219
220#define IS_OK(f) \
221 do { \
222 if ((f) != status::success) return false; \
223 } while (0)
224 if (src_md.format_kind == format_kind::any
225 && !utils::one_of(src_tag, any, undef))
226 IS_OK(memory_desc_init_by_tag(src_md, src_tag));
227 if (dst_md.format_kind == format_kind::any
228 && !utils::one_of(dst_tag, any, undef))
229 IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
230 if (wei_md.format_kind == format_kind::any
231 && !utils::one_of(wei_tag, any, undef))
232 IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
233 if (with_bias() && bia_md.format_kind == format_kind::any)
234 IS_OK(memory_desc_init_by_tag(bia_md, x));
235#undef IS_OK
236
237 return true;
238 }
239
240 bool set_default_alg_kind(alg_kind_t alg_kind) {
241 assert(utils::one_of(alg_kind, alg_kind::convolution_direct,
242 alg_kind::convolution_winograd));
243 if (desc_.alg_kind == alg_kind::convolution_auto)
244 desc_.alg_kind = alg_kind;
245 return desc_.alg_kind == alg_kind;
246 }
247
248 bool expect_data_types(data_type_t src_dt, data_type_t wei_dt,
249 data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const {
250 bool ok = true
251 && (src_dt == data_type::undef
252 || invariant_src_md()->data_type == src_dt)
253 && (wei_dt == data_type::undef
254 || invariant_wei_md()->data_type == wei_dt)
255 && (dst_dt == data_type::undef
256 || invariant_dst_md()->data_type == dst_dt)
257 && (acc_dt == data_type::undef
258 || desc_.accum_data_type == acc_dt);
259 if (with_bias() && bia_dt != data_type::undef)
260 ok = ok && invariant_bia_md()->data_type == bia_dt;
261 return ok;
262 }
263};
264
265struct convolution_fwd_pd_t : public convolution_pd_t {
266 typedef convolution_fwd_pd_t base_class;
267 typedef convolution_fwd_pd_t hint_class;
268
269 arg_usage_t arg_usage(int arg) const override {
270 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS))
271 return arg_usage_t::input;
272
273 if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input;
274
275 if (arg == DNNL_ARG_DST) return arg_usage_t::output;
276
277 return primitive_desc_t::arg_usage(arg);
278 }
279
280 const memory_desc_t *arg_md(int arg) const override {
281 switch (arg) {
282 case DNNL_ARG_SRC: return src_md(0);
283 case DNNL_ARG_WEIGHTS: return weights_md(0);
284 case DNNL_ARG_BIAS: return weights_md(1);
285 case DNNL_ARG_DST: return dst_md(0);
286 default: return convolution_pd_t::arg_md(arg);
287 }
288 }
289
290 const memory_desc_t *src_md(int index = 0) const override {
291 return index == 0 ? &src_md_ : &glob_zero_md;
292 }
293 const memory_desc_t *dst_md(int index = 0) const override {
294 return index == 0 ? &dst_md_ : &glob_zero_md;
295 }
296 const memory_desc_t *weights_md(int index = 0) const override {
297 if (index == 0) return &weights_md_;
298 if (index == 1 && with_bias()) return &bias_md_;
299 return &glob_zero_md;
300 }
301
302 int n_inputs() const override {
303 return 2 + with_bias() + attr_post_op_dw_inputs() + n_binary_po_inputs()
304 + n_prelu_po_inputs();
305 }
306
307 int n_outputs() const override { return 1; }
308
309protected:
310 memory_desc_t src_md_;
311 memory_desc_t weights_md_;
312 memory_desc_t bias_md_;
313 memory_desc_t dst_md_;
314
315 convolution_fwd_pd_t(const convolution_desc_t *adesc,
316 const primitive_attr_t *attr,
317 const convolution_fwd_pd_t *hint_fwd_pd)
318 : convolution_pd_t(adesc, attr, hint_fwd_pd)
319 , src_md_(desc_.src_desc)
320 , weights_md_(desc_.weights_desc)
321 , bias_md_(desc_.bias_desc)
322 , dst_md_(desc_.dst_desc) {}
323
324 bool set_default_formats_common(
325 format_tag_t src_tag, format_tag_t wei_tag, format_tag_t dst_tag) {
326 return set_default_formats_common_template(src_md_, src_tag,
327 weights_md_, wei_tag, dst_md_, dst_tag, bias_md_);
328 }
329
330 int attr_post_op_dw_inputs() const {
331 const auto &po = attr_.post_ops_;
332 int conv = po.find(primitive_kind::convolution);
333 if (conv == -1) return 0;
334 return po.entry_[conv].depthwise_conv.bias_dt == data_type::undef ? 1
335 : 2;
336 }
337};
338
339struct convolution_bwd_data_pd_t : public convolution_pd_t {
340 typedef convolution_bwd_data_pd_t base_class;
341 typedef convolution_fwd_pd_t hint_class;
342
343 arg_usage_t arg_usage(int arg) const override {
344 if (utils::one_of(arg, DNNL_ARG_WEIGHTS, DNNL_ARG_DIFF_DST))
345 return arg_usage_t::input;
346
347 if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output;
348
349 return primitive_desc_t::arg_usage(arg);
350 }
351
352 const memory_desc_t *arg_md(int arg) const override {
353 switch (arg) {
354 case DNNL_ARG_DIFF_SRC: return diff_src_md(0);
355 case DNNL_ARG_WEIGHTS: return weights_md(0);
356 case DNNL_ARG_BIAS: return weights_md(1);
357 case DNNL_ARG_DIFF_DST: return diff_dst_md(0);
358 default: return convolution_pd_t::arg_md(arg);
359 }
360 }
361
362 const memory_desc_t *diff_src_md(int index = 0) const override {
363 return index == 0 ? &diff_src_md_ : &glob_zero_md;
364 }
365 const memory_desc_t *diff_dst_md(int index = 0) const override {
366 return index == 0 ? &diff_dst_md_ : &glob_zero_md;
367 }
368 const memory_desc_t *weights_md(int index = 0) const override {
369 if (index == 0) return &weights_md_;
370 if (index == 1 && with_bias()) return &bias_md_;
371 return &glob_zero_md;
372 }
373
374 int n_inputs() const override { return 2 + with_bias(); }
375 int n_outputs() const override { return 1; }
376
377 virtual bool support_bias() const { return false; }
378
379protected:
380 memory_desc_t diff_src_md_;
381 memory_desc_t weights_md_;
382 memory_desc_t bias_md_;
383 memory_desc_t diff_dst_md_;
384
385 convolution_bwd_data_pd_t(const convolution_desc_t *adesc,
386 const primitive_attr_t *attr,
387 const convolution_fwd_pd_t *hint_fwd_pd)
388 : convolution_pd_t(adesc, attr, hint_fwd_pd)
389 , diff_src_md_(desc_.diff_src_desc)
390 , weights_md_(desc_.weights_desc)
391 , bias_md_(desc_.bias_desc)
392 , diff_dst_md_(desc_.diff_dst_desc) {}
393
394 bool set_default_formats_common(format_tag_t diff_src_tag,
395 format_tag_t wei_tag, format_tag_t diff_dst_tag) {
396 return set_default_formats_common_template(diff_src_md_, diff_src_tag,
397 weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
398 }
399};
400
401struct convolution_bwd_weights_pd_t : public convolution_pd_t {
402 typedef convolution_bwd_weights_pd_t base_class;
403 typedef convolution_fwd_pd_t hint_class;
404
405 convolution_bwd_weights_pd_t(const convolution_desc_t *adesc,
406 const primitive_attr_t *attr,
407 const convolution_fwd_pd_t *hint_fwd_pd)
408 : convolution_pd_t(adesc, attr, hint_fwd_pd)
409 , src_md_(desc_.src_desc)
410 , diff_weights_md_(desc_.diff_weights_desc)
411 , diff_bias_md_(desc_.diff_bias_desc)
412 , diff_dst_md_(desc_.diff_dst_desc) {}
413
414 arg_usage_t arg_usage(int arg) const override {
415 if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST))
416 return arg_usage_t::input;
417
418 if (arg == DNNL_ARG_DIFF_WEIGHTS) return arg_usage_t::output;
419
420 if (arg == DNNL_ARG_DIFF_BIAS && with_bias())
421 return arg_usage_t::output;
422
423 return primitive_desc_t::arg_usage(arg);
424 }
425
426 const memory_desc_t *arg_md(int arg) const override {
427 switch (arg) {
428 case DNNL_ARG_SRC: return src_md(0);
429 case DNNL_ARG_DIFF_WEIGHTS: return diff_weights_md(0);
430 case DNNL_ARG_DIFF_BIAS: return diff_weights_md(1);
431 case DNNL_ARG_DIFF_DST: return diff_dst_md(0);
432 default: return convolution_pd_t::arg_md(arg);
433 }
434 }
435
436 const memory_desc_t *src_md(int index = 0) const override {
437 return index == 0 ? &src_md_ : &glob_zero_md;
438 }
439 const memory_desc_t *diff_dst_md(int index = 0) const override {
440 return index == 0 ? &diff_dst_md_ : &glob_zero_md;
441 }
442 const memory_desc_t *diff_weights_md(int index = 0) const override {
443 if (index == 0) return &diff_weights_md_;
444 if (index == 1 && with_bias()) return &diff_bias_md_;
445 return &glob_zero_md;
446 }
447
448 int n_inputs() const override { return 2; }
449 int n_outputs() const override { return 1 + with_bias(); }
450
451protected:
452 memory_desc_t src_md_;
453 memory_desc_t diff_weights_md_;
454 memory_desc_t diff_bias_md_;
455 memory_desc_t diff_dst_md_;
456
457 bool set_default_formats_common(format_tag_t src_tag,
458 format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) {
459 return set_default_formats_common_template(src_md_, src_tag,
460 diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag,
461 diff_bias_md_);
462 }
463};
464
465} // namespace impl
466} // namespace dnnl
467
468#endif
469
470// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
471