1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "math_utils.hpp"
19#include "oneapi/dnnl/dnnl.hpp"
20#include "gtest/gtest.h"
21
22using namespace dnnl::impl::math;
23
24namespace dnnl {
25
26template <typename data_t_src, typename data_t_wei, typename data_t_acc,
27 typename data_t_dst>
28void compute_ref_conv_eltwise_fwd(const test_convolution_sizes_t &c,
29 const memory &src, const memory &weights, const memory &bias,
30 const memory &dst, bool w_bias, algorithm elt_alg, float elt_alpha,
31 float elt_beta) {
32 auto src_data = map_memory<data_t_src>(src);
33 auto weights_data = map_memory<data_t_wei>(weights);
34 auto bias_data = w_bias ? map_memory<data_t_dst>(bias) : nullptr;
35 auto dst_data = map_memory<data_t_dst>(dst);
36
37 const memory::desc src_d = src.get_desc();
38 const memory::desc weights_d = weights.get_desc();
39 const memory::desc dst_d = dst.get_desc();
40
41 auto padded_ic = src_d.get_padded_dims()[1];
42 auto padded_oc = dst_d.get_padded_dims()[1];
43
44 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get());
45 const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get());
46 const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get());
47
48 dnnl::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
49 [&](memory::dim n, memory::dim g, memory::dim oc, memory::dim oh,
50 memory::dim ow) {
51 memory::dim oidx = n * padded_oc * c.oh * c.ow
52 + g * padded_oc / c.ng * c.oh * c.ow + oc * c.oh * c.ow
53 + oh * c.ow + ow;
54
55 memory::dim didx = dst_mdw.off_l(oidx, true);
56 dst_data[didx] = bias_data ? bias_data[g * c.oc / c.ng + oc]
57 : data_t_dst {0};
58
59 for_(memory::dim ic = 0; ic < c.ic / c.ng; ic++)
60 for_(memory::dim kh = 0; kh < c.kh; kh++)
61 for (memory::dim kw = 0; kw < c.kw; kw++) {
62 memory::dim ih = oh * c.strh - c.padh + kh * (1 + c.dilh);
63 if (ih < 0 || ih >= c.ih) continue;
64 memory::dim iw = ow * c.strw - c.padw + kw * (1 + c.dilw);
65 if (iw < 0 || iw >= c.iw) continue;
66
67 memory::dim iidx = n * padded_ic * c.ih * c.iw
68 + g * padded_ic / c.ng * c.ih * c.iw
69 + ic * c.ih * c.iw + ih * c.iw + iw;
70 memory::dim widx = 0
71 + g * padded_oc / c.ng * padded_ic / c.ng * c.kh
72 * c.kw
73 + oc * padded_ic / c.ng * c.kh * c.kw
74 + ic * c.kh * c.kw + kh * c.kw + kw;
75
76 dst_data[didx] += src_data[src_mdw.off_l(iidx, true)]
77 * weights_data[weights_mdw.off_l(widx, true)];
78 }
79
80 auto &d = dst_data[didx];
81 switch (elt_alg) {
82 case algorithm::eltwise_relu:
83 d = relu_fwd(d, elt_alpha);
84 break;
85 case algorithm::eltwise_tanh: d = tanh_fwd(d); break;
86 case algorithm::eltwise_elu:
87 d = elu_fwd(d, elt_alpha);
88 break;
89 case algorithm::eltwise_square: d = square_fwd(d); break;
90 case algorithm::eltwise_abs: d = abs_fwd(d); break;
91 case algorithm::eltwise_linear:
92 d = linear_fwd(d, elt_alpha, elt_beta);
93 break;
94 case algorithm::eltwise_clip:
95 d = clip_fwd(d, elt_alpha, elt_beta);
96 break;
97 case algorithm::eltwise_soft_relu:
98 d = soft_relu_fwd(d, elt_alpha);
99 break;
100 case algorithm::eltwise_logistic:
101 d = logistic_fwd(d);
102 break;
103 case algorithm::eltwise_exp: d = exp_fwd(d); break;
104 case algorithm::eltwise_swish:
105 d = swish_fwd(d, elt_alpha);
106 break;
107 default: assert(!"unknown alg_kind");
108 }
109 });
110}
111
112template <typename data_t_src, typename data_t_wei, typename data_t_acc,
113 typename data_t_dst>
114class convolution_eltwise_test
115 : public ::testing::TestWithParam<test_convolution_eltwise_params_t> {
116protected:
117 virtual void SetUp() {
118 memory::data_type data_type_src = data_traits<data_t_src>::data_type;
119 memory::data_type data_type_dst = data_traits<data_t_dst>::data_type;
120 memory::data_type data_type_wei = data_traits<data_t_wei>::data_type;
121
122 SKIP_IF(unsupported_data_type(data_type_src),
123 "Engine does not support this data type.");
124 SKIP_IF(unsupported_data_type(data_type_dst),
125 "Engine does not support this data type.");
126 SKIP_IF(unsupported_data_type(data_type_wei),
127 "Engine does not support this data type.");
128
129 test_convolution_eltwise_params_t p = ::testing::TestWithParam<
130 test_convolution_eltwise_params_t>::GetParam();
131
132 SKIP_IF_CUDA(
133 !(cuda_check_format_tags(p.formats.src_format, data_type_src)
134 && cuda_check_format_tags(
135 p.formats.dst_format, data_type_dst)
136 && (cuda_check_format_tags(
137 p.formats.weights_format, data_type_wei)
138 || impl::utils::one_of(p.formats.weights_format,
139 /* weights formats */
140 memory::format_tag::gowi,
141 memory::format_tag::gohwi,
142 memory::format_tag::godhwi,
143 memory::format_tag::owi,
144 memory::format_tag::ohwi,
145 memory::format_tag::odhwi))),
146 "Format is not supported.");
147 SKIP_IF_CUDA(p.alg != algorithm::eltwise_relu
148 && p.alg != algorithm::eltwise_tanh
149 && p.alg != algorithm::eltwise_elu
150 && p.alg != algorithm::eltwise_logistic,
151 "Unsupported algorithm type for CUDA");
152 SKIP_IF_CUDA(p.alg == algorithm::eltwise_relu && p.eltwise_alpha != 0.0,
153 "DNNL only supports relu w/ slope=0 for integers");
154
155 catch_expected_failures(
156 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
157 }
158
159 bool cuda_check_format_tags(memory::format_tag tag, memory::data_type dt) {
160 return ((impl::utils::one_of(tag, memory::format_tag::ab,
161 memory::format_tag::abc, memory::format_tag::abcd,
162 memory::format_tag::abcde, memory::format_tag::abcdef,
163 memory::format_tag::acb, memory::format_tag::acdb,
164 memory::format_tag::acdeb))
165 || (dt == memory::data_type::s8
166 && impl::utils::one_of(tag, memory::format_tag::aBcd4b,
167 memory::format_tag::aBcde4b)));
168 }
169
170 virtual void Test() {
171 test_convolution_eltwise_params_t p = ::testing::TestWithParam<
172 test_convolution_eltwise_params_t>::GetParam();
173 ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
174 auto eng = get_test_engine();
175 auto strm = stream(eng);
176 float eltwise_alpha = p.eltwise_alpha;
177 float eltwise_beta = p.eltwise_beta;
178
179 memory::data_type data_type_src = data_traits<data_t_src>::data_type;
180 memory::data_type data_type_dst = data_traits<data_t_dst>::data_type;
181 memory::data_type data_type_wei = data_traits<data_t_wei>::data_type;
182
183 test_convolution_sizes_t cd = p.sizes;
184
185 auto c_src_desc = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, data_type_src,
186 p.formats.src_format);
187 auto c_weights_desc = cd.ng > 1
188 ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw},
189 data_type_wei, p.formats.weights_format)
190 : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type_wei,
191 p.formats.weights_format);
192 auto c_dst_desc = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, data_type_dst,
193 p.formats.dst_format);
194
195 auto c_src = test::make_memory(c_src_desc, eng);
196 auto c_weights = test::make_memory(c_weights_desc, eng);
197 auto c_dst = test::make_memory(c_dst_desc, eng);
198
199 auto dst_ref = test::make_memory(c_dst_desc, eng);
200
201 fill_data<data_t_src>(c_src.get_desc().get_size() / sizeof(data_t_src),
202 c_src, data_t_src(0), data_t_src(1));
203 check_zero_tail<data_t_src>(1, c_src);
204
205 fill_data<data_t_wei>(
206 c_weights.get_desc().get_size() / sizeof(data_t_wei), c_weights,
207 data_t_wei(0), data_t_wei(1));
208 check_zero_tail<data_t_wei>(1, c_weights);
209
210 bool with_bias = p.formats.bias_format != memory::format_tag::undef;
211 auto c_bias_desc = with_bias
212 ? create_md({cd.oc}, data_type_dst, p.formats.bias_format)
213 : create_md({0}, data_type_dst, p.formats.bias_format);
214 auto c_bias = test::make_memory(c_bias_desc, eng);
215 if (with_bias) {
216 fill_data<data_t_dst>(
217 c_bias.get_desc().get_size() / sizeof(data_t_dst), c_bias,
218 1., true);
219 }
220
221 memory::dims strides = {cd.strh, cd.strw};
222 memory::dims dilations = {cd.dilh, cd.dilw};
223 memory::dims padL = {cd.padh, cd.padw};
224 memory::dims padR = {cd.padh, cd.padw};
225 for (int i = 0; i < 2; ++i) {
226 if ((cd.ih - ((cd.kh - 1) * (cd.dilh + 1) + 1) + cd.padh + padR[0])
227 / cd.strh
228 + 1
229 != cd.oh)
230 ++padR[0];
231 if ((cd.iw - ((cd.kw - 1) * (cd.dilw + 1) + 1) + cd.padw + padR[1])
232 / cd.strw
233 + 1
234 != cd.ow)
235 ++padR[1];
236 }
237
238 SKIP_IF_CUDA(cd.padh < padR[0] || cd.padw < padR[1],
239 "Unsupported padding for CUDA.");
240
241 dnnl::post_ops ops;
242 ops.append_eltwise(p.alg, p.eltwise_alpha, p.eltwise_beta);
243
244 dnnl::primitive_attr attr;
245 attr.set_post_ops(ops);
246
247 auto conv_primitive_desc = with_bias
248 ? convolution_forward::primitive_desc(eng,
249 prop_kind::forward_inference, p.aalgorithm, c_src_desc,
250 c_weights_desc, c_bias_desc, c_dst_desc, strides,
251 dilations, padL, padR, attr)
252 : convolution_forward::primitive_desc(eng,
253 prop_kind::forward_inference, p.aalgorithm, c_src_desc,
254 c_weights_desc, c_dst_desc, strides, dilations, padL,
255 padR, attr);
256
257 ASSERT_EQ(conv_primitive_desc.get_algorithm(), p.aalgorithm);
258 ASSERT_EQ(conv_primitive_desc.get_prop_kind(),
259 prop_kind::forward_inference);
260 ASSERT_EQ(conv_primitive_desc.get_strides(), strides);
261 ASSERT_EQ(conv_primitive_desc.get_dilations(), dilations);
262 ASSERT_EQ(conv_primitive_desc.get_padding_l(), padL);
263 ASSERT_EQ(conv_primitive_desc.get_padding_r(), padR);
264
265 EXPECT_ANY_THROW(convolution_forward(conv_primitive_desc, {}));
266 convolution_forward(conv_primitive_desc)
267 .execute(strm,
268 {{DNNL_ARG_SRC, c_src}, {DNNL_ARG_WEIGHTS, c_weights},
269 {DNNL_ARG_BIAS, c_bias},
270 {DNNL_ARG_DST, c_dst}});
271 strm.wait();
272
273 compute_ref_conv_eltwise_fwd<data_t_src, data_t_wei, data_t_wei,
274 data_t_dst>(cd, c_src, c_weights, c_bias, dst_ref, with_bias,
275 p.alg, eltwise_alpha, eltwise_beta);
276 check_zero_tail<data_t_dst>(1, dst_ref);
277
278 compare_data<data_t_dst>(dst_ref, c_dst, 1e-2);
279 check_zero_tail<data_t_dst>(0, c_dst);
280 }
281};
282
283} // namespace dnnl
284