1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19#include "opdesc.hpp"
20#include "primitive_desc_iface.hpp"
21
22#include "c_types_map.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::utils;
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::prop_kind;
30using namespace dnnl::impl::alg_kind;
31using namespace dnnl::impl::types;
32
33namespace dnnl {
34namespace impl {
35status_t conv_desc_init(convolution_desc_t *conv_desc, prop_kind_t prop_kind,
36 alg_kind_t alg_kind, const memory_desc_t *src_desc,
37 const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
38 const memory_desc_t *dst_desc, const dims_t strides,
39 const dims_t dilates, const dims_t padding_l, const dims_t padding_r) {
40 bool args_ok = true
41 && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
42 padding_l)
43 && one_of(alg_kind, convolution_auto, convolution_direct,
44 convolution_winograd);
45 if (!args_ok) return invalid_arguments;
46
47 if (padding_r == nullptr) padding_r = padding_l;
48
49 auto cd = convolution_desc_t();
50 cd.primitive_kind = primitive_kind::convolution;
51 cd.prop_kind = prop_kind;
52 cd.alg_kind = alg_kind;
53
54 cd.diff_src_desc = cd.src_desc = zero_md();
55 cd.diff_dst_desc = cd.dst_desc = zero_md();
56 cd.diff_weights_desc = cd.weights_desc = zero_md();
57 cd.diff_bias_desc = cd.bias_desc = zero_md();
58
59 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
60 const bool with_bias
61 = bias_desc && bias_desc->format_kind != format_kind::undef;
62 const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
63
64 bool runtime_dims_or_strides
65 = memory_desc_wrapper(src_desc).has_runtime_dims_or_strides()
66 || memory_desc_wrapper(weights_desc).has_runtime_dims_or_strides()
67 || memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides();
68 if (with_bias)
69 runtime_dims_or_strides = runtime_dims_or_strides
70 || memory_desc_wrapper(bias_desc).has_runtime_dims_or_strides();
71 if (runtime_dims_or_strides) return unimplemented;
72
73 (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc;
74 (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc;
75 (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc)
76 = *weights_desc;
77 if (with_bias)
78 (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc)
79 = *bias_desc;
80
81 cd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
82 weights_desc->data_type, dst_desc->data_type, prop_kind);
83 if (cd.accum_data_type == data_type::undef) return invalid_arguments;
84
85 bool consistency = memory_desc_wrapper(weights_desc).nelems()
86 && src_desc->ndims == dst_desc->ndims
87 && utils::one_of(src_desc->ndims, 3, 4, 5)
88 && utils::one_of(
89 weights_desc->ndims, src_desc->ndims, src_desc->ndims + 1);
90 if (!consistency) return invalid_arguments;
91
92 const int g = with_groups ? weights_desc->dims[0] : 1;
93 const int bias_dim = prop_kind == backward_data ? src_desc->dims[1]
94 : dst_desc->dims[1];
95 consistency
96 = IMPLICATION(with_bias,
97 bias_desc->ndims == 1 && bias_desc->dims[0] == bias_dim)
98 && src_desc->dims[0] == dst_desc->dims[0]
99 && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
100 && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
101 if (!consistency) return invalid_arguments;
102
103 int sp_dims = src_desc->ndims - 2;
104 utils::array_copy(cd.strides, strides, sp_dims);
105 utils::array_copy(cd.padding[0], padding_l, sp_dims);
106 utils::array_copy(cd.padding[1], padding_r, sp_dims);
107 if (dilates)
108 utils::array_copy(cd.dilates, dilates, sp_dims);
109 else
110 utils::array_set(cd.dilates, 0, sp_dims);
111
112 for (int i = 2; i < src_desc->ndims; ++i) {
113 int src = src_desc->dims[i];
114 int ker = weights_desc->dims[with_groups + i];
115 int dil = cd.dilates[i - 2];
116 int pad_l = padding_l[i - 2];
117 int pad_r = padding_r[i - 2];
118 int str = strides[i - 2];
119 int dst = dst_desc->dims[i];
120 int ker_range = 1 + (ker - 1) * (dil + 1);
121
122 if (str < 1) return invalid_arguments;
123 consistency = consistency && dil >= 0 && pad_l >= 0 && pad_r + str > 0
124 && (src - ker_range + pad_l + pad_r) / str + 1 == dst;
125 }
126 if (!consistency) return invalid_arguments;
127
128 *conv_desc = cd;
129 return success;
130}
131} // namespace impl
132} // namespace dnnl
133
134status_t dnnl_convolution_forward_primitive_desc_create(
135 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
136 prop_kind_t prop_kind, alg_kind_t alg_kind,
137 const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
138 const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
139 const dims_t strides, const dims_t dilates, const dims_t padding_l,
140 const dims_t padding_r, const primitive_attr_t *attr) {
141 if (!one_of(prop_kind, forward_training, forward_inference))
142 return invalid_arguments;
143
144 auto conv_desc = convolution_desc_t();
145 CHECK(dnnl::impl::conv_desc_init(&conv_desc, prop_kind, alg_kind, src_desc,
146 weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
147 padding_r));
148 return primitive_desc_create(primitive_desc_iface, engine,
149 (const op_desc_t *)&conv_desc, nullptr, attr);
150}
151
152status_t dnnl_convolution_backward_data_primitive_desc_create(
153 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
154 alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
155 const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc,
156 const dims_t strides, const dims_t dilates, const dims_t padding_l,
157 const dims_t padding_r, const primitive_desc_iface_t *hint_fwd_pd,
158 const primitive_attr_t *attr) {
159
160 auto conv_desc = convolution_desc_t();
161 CHECK(dnnl::impl::conv_desc_init(&conv_desc, backward_data, alg_kind,
162 diff_src_desc, weights_desc, nullptr, diff_dst_desc, strides,
163 dilates, padding_l, padding_r));
164 return primitive_desc_create(primitive_desc_iface, engine,
165 (const op_desc_t *)&conv_desc, hint_fwd_pd, attr);
166}
167
168status_t dnnl_convolution_backward_weights_primitive_desc_create(
169 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
170 alg_kind_t alg_kind, const memory_desc_t *src_desc,
171 const memory_desc_t *diff_weights_desc,
172 const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
173 const dims_t strides, const dims_t dilates, const dims_t padding_l,
174 const dims_t padding_r, const primitive_desc_iface_t *hint_fwd_pd,
175 const primitive_attr_t *attr) {
176
177 auto conv_desc = convolution_desc_t();
178 CHECK(dnnl::impl::conv_desc_init(&conv_desc, backward_weights, alg_kind,
179 src_desc, diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
180 dilates, padding_l, padding_r));
181 return primitive_desc_create(primitive_desc_iface, engine,
182 (const op_desc_t *)&conv_desc, hint_fwd_pd, attr);
183}
184
185// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
186