1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19#include "opdesc.hpp"
20#include "primitive_desc_iface.hpp"
21
22#include "c_types_map.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::utils;
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::prop_kind;
30using namespace dnnl::impl::alg_kind;
31using namespace dnnl::impl::types;
32
33namespace {
34status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
35 prop_kind_t prop_kind, alg_kind_t alg_kind,
36 const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
37 const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
38 const dims_t strides, const dims_t dilates, const dims_t padding_l,
39 const dims_t padding_r) {
40 bool args_ok = true
41 && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides,
42 padding_l)
43 && one_of(alg_kind, deconvolution_direct, deconvolution_winograd);
44 if (!args_ok) return invalid_arguments;
45
46 if (padding_r == nullptr) padding_r = padding_l;
47
48 auto dd = deconvolution_desc_t();
49 dd.primitive_kind = primitive_kind::deconvolution;
50 dd.prop_kind = prop_kind;
51 dd.alg_kind = alg_kind;
52
53 dd.diff_src_desc = dd.src_desc = zero_md();
54 dd.diff_dst_desc = dd.dst_desc = zero_md();
55 dd.diff_weights_desc = dd.weights_desc = zero_md();
56 dd.diff_bias_desc = dd.bias_desc = zero_md();
57
58 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
59 const bool with_bias
60 = bias_desc && bias_desc->format_kind != format_kind::undef;
61 const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
62
63 bool runtime_dims_or_strides
64 = memory_desc_wrapper(src_desc).has_runtime_dims_or_strides()
65 || memory_desc_wrapper(weights_desc).has_runtime_dims_or_strides()
66 || memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides();
67 if (with_bias)
68 runtime_dims_or_strides = runtime_dims_or_strides
69 || memory_desc_wrapper(bias_desc).has_runtime_dims_or_strides();
70 if (runtime_dims_or_strides) return unimplemented;
71
72 (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc;
73 (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc;
74 (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc)
75 = *weights_desc;
76 if (with_bias)
77 (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc)
78 = *bias_desc;
79
80 int sp_dims = src_desc->ndims - 2;
81 utils::array_copy(dd.strides, strides, sp_dims);
82 utils::array_copy(dd.padding[0], padding_l, sp_dims);
83 utils::array_copy(dd.padding[1], padding_r, sp_dims);
84 if (dilates)
85 utils::array_copy(dd.dilates, dilates, sp_dims);
86 else
87 utils::array_set(dd.dilates, 0, sp_dims);
88
89 dd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
90 weights_desc->data_type, dst_desc->data_type, prop_kind);
91 if (dd.accum_data_type == data_type::undef) return invalid_arguments;
92
93 const int g = with_groups ? weights_desc->dims[0] : 1;
94 bool consistency = true && src_desc->ndims == dst_desc->ndims
95 && utils::one_of(src_desc->ndims, 3, 4, 5)
96 && utils::one_of(
97 weights_desc->ndims, src_desc->ndims, src_desc->ndims + 1)
98 && (with_bias ? bias_desc->ndims == 1 : true)
99 && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
100 && src_desc->dims[0] == dst_desc->dims[0]
101 && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
102 && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
103 for (int i = 2; i < src_desc->ndims; ++i) {
104 int src = src_desc->dims[i];
105 int ker = weights_desc->dims[with_groups + i];
106 int dil = dd.dilates[i - 2];
107 int pad_l = padding_l[i - 2];
108 int pad_r = padding_r[i - 2];
109 int str = strides[i - 2];
110 int dst = dst_desc->dims[i];
111 int ker_range = 1 + (ker - 1) * (dil + 1);
112
113 if (str < 1) return invalid_arguments;
114 consistency = consistency && dil >= 0 && pad_l >= 0 && pad_r + str > 0
115 && (dst - ker_range + pad_l + pad_r) / str + 1 == src;
116 }
117 if (!consistency) return invalid_arguments;
118
119 *deconv_desc = dd;
120 return success;
121}
122} // namespace
123
124status_t dnnl_deconvolution_forward_primitive_desc_create(
125 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
126 prop_kind_t prop_kind, alg_kind_t alg_kind,
127 const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
128 const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
129 const dims_t strides, const dims_t dilates, const dims_t padding_l,
130 const dims_t padding_r, const primitive_attr_t *attr) {
131 if (!one_of(prop_kind, forward_training, forward_inference))
132 return invalid_arguments;
133
134 auto deconv_desc = deconvolution_desc_t();
135 CHECK(deconv_desc_init(&deconv_desc, prop_kind, alg_kind, src_desc,
136 weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
137 padding_r));
138 return primitive_desc_create(primitive_desc_iface, engine,
139 (const op_desc_t *)&deconv_desc, nullptr, attr);
140}
141
142status_t dnnl_deconvolution_backward_data_primitive_desc_create(
143 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
144 alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
145 const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc,
146 const dims_t strides, const dims_t dilates, const dims_t padding_l,
147 const dims_t padding_r, const primitive_desc_iface_t *hint_fwd_pd,
148 const primitive_attr_t *attr) {
149
150 auto deconv_desc = deconvolution_desc_t();
151 CHECK(deconv_desc_init(&deconv_desc, backward_data, alg_kind, diff_src_desc,
152 weights_desc, nullptr, diff_dst_desc, strides, dilates, padding_l,
153 padding_r));
154 return primitive_desc_create(primitive_desc_iface, engine,
155 (const op_desc_t *)&deconv_desc, hint_fwd_pd, attr);
156}
157
158status_t dnnl_deconvolution_backward_weights_primitive_desc_create(
159 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
160 alg_kind_t alg_kind, const memory_desc_t *src_desc,
161 const memory_desc_t *diff_weights_desc,
162 const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
163 const dims_t strides, const dims_t dilates, const dims_t padding_l,
164 const dims_t padding_r, const primitive_desc_iface_t *hint_fwd_pd,
165 const primitive_attr_t *attr) {
166
167 auto deconv_desc = deconvolution_desc_t();
168 CHECK(deconv_desc_init(&deconv_desc, backward_weights, alg_kind, src_desc,
169 diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates,
170 padding_l, padding_r));
171 return primitive_desc_create(primitive_desc_iface, engine,
172 (const op_desc_t *)&deconv_desc, hint_fwd_pd, attr);
173}
174
175// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
176