1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef TEST_CONVOLUTION_BACKWARD_WEIGHTS_COMMON_H
18#define TEST_CONVOLUTION_BACKWARD_WEIGHTS_COMMON_H
19
20#include "dnnl_test_common.hpp"
21#include "gtest/gtest.h"
22
23#include "oneapi/dnnl/dnnl.hpp"
24
25namespace dnnl {
26
27template <typename data_t_src, typename data_t_diff_dst,
28 typename data_t_diff_bias>
29void compute_ref_conv_bwd_bias(const test_convolution_sizes_t &c,
30 const memory &diff_dst, const memory &diff_bias) {
31 auto diff_bias_data = map_memory<data_t_diff_bias>(diff_bias);
32 auto diff_dst_data = map_memory<data_t_diff_dst>(diff_dst);
33
34 const memory::desc bias_d = diff_bias.get_desc();
35 const memory::desc dst_d = diff_dst.get_desc();
36 const dnnl::impl::memory_desc_wrapper diff_bias_mdw(bias_d.get());
37 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(dst_d.get());
38
39 auto padded_oc = dst_d.get_padded_dims()[1];
40
41 dnnl::impl::parallel_nd(
42 c.ng, c.oc / c.ng, [&](memory::dim g, memory::dim oc) {
43 memory::dim bidx = g * padded_oc / c.ng + oc;
44 diff_bias_data[diff_bias_mdw.off_l(bidx, true)] = 0.0;
45 for (memory::dim mb = 0; mb < c.mb; ++mb) {
46 for (memory::dim oh = 0; oh < c.oh; ++oh) {
47 for (memory::dim ow = 0; ow < c.ow; ++ow) {
48 memory::dim oidx = mb * padded_oc * c.oh * c.ow
49 + g * padded_oc / c.ng * c.oh * c.ow
50 + oc * c.oh * c.ow + oh * c.ow + ow;
51 diff_bias_data[diff_bias_mdw.off_l(bidx, true)]
52 += diff_dst_data[diff_dst_mdw.off_l(
53 oidx, true)];
54 }
55 }
56 }
57 });
58}
59
60template <typename data_t_src, typename data_t_diff_dst,
61 typename data_t_diff_weights>
62void compute_ref_conv_bwd_weights(const test_convolution_sizes_t &c,
63 const memory &src, const memory &diff_dst, const memory &diff_weights) {
64 auto src_data = map_memory<data_t_src>(src);
65 auto diff_weights_data = map_memory<data_t_diff_weights>(diff_weights);
66 auto diff_dst_data = map_memory<data_t_diff_dst>(diff_dst);
67
68 const memory::desc src_d = src.get_desc();
69 const memory::desc weights_d = diff_weights.get_desc();
70 const memory::desc dst_d = diff_dst.get_desc();
71 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get());
72 const dnnl::impl::memory_desc_wrapper diff_weights_mdw(weights_d.get());
73 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(dst_d.get());
74
75 auto padded_ic = src_d.get_padded_dims()[1];
76 auto padded_oc = dst_d.get_padded_dims()[1];
77
78 dnnl::impl::parallel_nd(c.ng, c.oc / c.ng, c.ic / c.ng, c.kh, c.kw,
79 [&](memory::dim g, memory::dim oc, memory::dim ic, memory::dim kh,
80 memory::dim kw) {
81 memory::dim widx
82 = g * padded_oc / c.ng * padded_ic / c.ng * c.kh * c.kw
83 + oc * padded_ic / c.ng * c.kh * c.kw + ic * c.kh * c.kw
84 + kh * c.kw + kw;
85 diff_weights_data[diff_weights_mdw.off_l(widx, true)] = 0.0;
86 for (memory::dim mb = 0; mb < c.mb; ++mb) {
87 for (memory::dim oh = 0; oh < c.oh; ++oh) {
88 for (memory::dim ow = 0; ow < c.ow; ++ow) {
89 if (ow * c.strw + kw * (1 + c.dilw) < c.padw
90 || oh * c.strh + kh * (1 + c.dilh) < c.padh
91 || ow * c.strw + kw * (1 + c.dilw)
92 >= c.iw + c.padw
93 || oh * c.strh + kh * (1 + c.dilh)
94 >= c.ih + c.padh)
95 continue;
96
97 memory::dim ih
98 = oh * c.strh - c.padh + kh * (1 + c.dilh);
99 memory::dim iw
100 = ow * c.strw - c.padw + kw * (1 + c.dilw);
101 memory::dim sidx = mb * padded_ic * c.ih * c.iw
102 + g * padded_ic / c.ng * c.ih * c.iw
103 + ic * c.ih * c.iw + ih * c.iw + iw;
104 memory::dim didx = mb * padded_oc * c.oh * c.ow
105 + g * padded_oc / c.ng * c.oh * c.ow
106 + oc * c.oh * c.ow + oh * c.ow + ow;
107
108 diff_weights_data[diff_weights_mdw.off_l(
109 widx, true)]
110 += src_data[src_mdw.off_l(sidx, true)]
111 * diff_dst_data[diff_dst_mdw.off_l(
112 didx, true)];
113 }
114 }
115 }
116 });
117}
118
119template <typename data_t_src, typename data_t_diff_dst,
120 typename data_t_diff_weights, typename data_t_diff_bias>
121class convolution_backward_weights_test
122 : public ::testing::TestWithParam<test_convolution_params_t> {
123protected:
124 virtual void SetUp() {
125 auto p = ::testing::TestWithParam<
126 test_convolution_params_t>::GetParam();
127
128 SKIP_IF_CUDA(
129 !(cuda_check_format_tags(p.formats.src_format)
130 && cuda_check_format_tags(p.formats.dst_format)
131 && (cuda_check_format_tags(p.formats.weights_format)
132 || (impl::utils::one_of(
133 p.formats.weights_format,
134 /* weights formats */
135 memory::format_tag::gowi,
136 memory::format_tag::gohwi,
137 memory::format_tag::godhwi,
138 memory::format_tag::owi,
139 memory::format_tag::ohwi,
140 memory::format_tag::odhwi)))
141 && data_traits<data_t_src>::data_type
142 == memory::data_type::f32
143 && data_traits<data_t_diff_dst>::data_type
144 == memory::data_type::f32
145 && data_traits<data_t_diff_weights>::data_type
146 == memory::data_type::f32
147 && check_cuda_alg_format(p.formats.dst_format,
148 p.formats.weights_format, p.aalgorithm)),
149 "format is not supported.");
150
151 catch_expected_failures(
152 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
153 }
154
155 bool cuda_check_format_tags(memory::format_tag tag) {
156 return impl::utils::one_of(tag, memory::format_tag::ab,
157 memory::format_tag::abc, memory::format_tag::abcd,
158 memory::format_tag::abcde, memory::format_tag::abcdef,
159 memory::format_tag::acb, memory::format_tag::acdb,
160 memory::format_tag::acdeb);
161 }
162
163 bool check_cuda_alg_format(memory::format_tag dst_fmt,
164 memory::format_tag wei_fmt, algorithm alg) {
165 bool res = dst_fmt == wei_fmt;
166 if (alg == dnnl::algorithm::convolution_winograd) {
167 res = res
168 && impl::utils::one_of(wei_fmt, memory::format_tag::ab,
169 memory::format_tag::abc, memory::format_tag::abcd,
170 memory::format_tag::abcde,
171 memory::format_tag::abcdef);
172 }
173 return res;
174 }
175
176 void Test() {
177 auto p = ::testing::TestWithParam<
178 test_convolution_params_t>::GetParam();
179
180 ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
181 auto eng = get_test_engine();
182 auto strm = stream(eng);
183 memory::data_type data_type_src = data_traits<data_t_src>::data_type;
184 memory::data_type data_type_diff_dst
185 = data_traits<data_t_diff_dst>::data_type;
186 memory::data_type data_type_diff_weights
187 = data_traits<data_t_diff_weights>::data_type;
188 memory::data_type data_type_diff_bias
189 = data_traits<data_t_diff_bias>::data_type;
190
191 test_convolution_sizes_t cd = p.sizes;
192
193 bool with_bias = p.formats.bias_format != memory::format_tag::undef;
194
195 auto c_src_desc = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, data_type_src,
196 p.formats.src_format);
197 auto c_diff_weights_desc = cd.ng > 1
198 ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw},
199 data_type_diff_weights, p.formats.weights_format)
200 : create_md({cd.oc, cd.ic, cd.kh, cd.kw},
201 data_type_diff_weights, p.formats.weights_format);
202 auto c_diff_bias_desc = create_md(
203 {cd.oc}, data_type_diff_bias, p.formats.bias_format);
204 auto c_diff_dst_desc = create_md({cd.mb, cd.oc, cd.oh, cd.ow},
205 data_type_diff_dst, p.formats.dst_format);
206 auto c_weights_desc_f = cd.ng > 1
207 ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw},
208 data_type_diff_dst, p.formats.weights_format)
209 : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type_diff_dst,
210 p.formats.weights_format);
211 auto c_dst_desc_f = create_md({cd.mb, cd.oc, cd.oh, cd.ow},
212 data_type_diff_weights, p.formats.dst_format);
213 auto c_src = test_memory(c_src_desc, eng);
214 auto c_diff_weights = test_memory(c_diff_weights_desc, eng);
215 auto c_diff_bias = test_memory(c_diff_bias_desc, eng);
216 auto c_diff_dst = test_memory(c_diff_dst_desc, eng);
217 auto weights_primitive_desc_f = test_memory(c_weights_desc_f, eng);
218 auto dst_primitive_desc_f = test_memory(c_dst_desc_f, eng);
219 fill_data<data_t_diff_dst>(
220 c_diff_dst.get_size() / sizeof(data_t_diff_dst),
221 c_diff_dst.get());
222 fill_data<data_t_src>(
223 c_src.get_size() / sizeof(data_t_src), c_src.get());
224 fill_data<data_t_diff_weights>(
225 c_diff_weights.get_size() / sizeof(data_t_diff_weights),
226 c_diff_weights.get());
227
228 check_zero_tail<data_t_diff_dst>(1, c_diff_dst.get());
229 check_zero_tail<data_t_src>(1, c_src.get());
230 check_zero_tail<data_t_diff_weights>(1, c_diff_weights.get());
231
232 memory::dims padR = {
233 right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
234 right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)};
235
236 auto conv_primitive_desc = convolution_forward::primitive_desc(eng,
237 prop_kind::forward_training, p.aalgorithm, c_src_desc,
238 c_weights_desc_f, c_diff_bias_desc, c_dst_desc_f,
239 {cd.strh, cd.strw}, {cd.dilh, cd.dilw}, {cd.padh, cd.padw},
240 padR);
241
242 auto conv_bwd_weights_primitive_desc
243 = convolution_backward_weights::primitive_desc(eng,
244 p.aalgorithm, c_src_desc, c_diff_weights_desc,
245 c_diff_bias_desc, c_diff_dst_desc, {cd.strh, cd.strw},
246 {cd.dilh, cd.dilw}, {cd.padh, cd.padw}, padR,
247 conv_primitive_desc);
248
249 conv_bwd_weights_primitive_desc
250 = convolution_backward_weights::primitive_desc(
251 conv_bwd_weights_primitive_desc
252 .get()); // test construction from a C pd
253
254 ASSERT_TRUE(conv_bwd_weights_primitive_desc.query_md(
255 query::exec_arg_md, DNNL_ARG_SRC)
256 == conv_bwd_weights_primitive_desc.src_desc());
257 ASSERT_TRUE(conv_bwd_weights_primitive_desc.query_md(
258 query::exec_arg_md, DNNL_ARG_DIFF_DST)
259 == conv_bwd_weights_primitive_desc.diff_dst_desc());
260 ASSERT_TRUE(conv_bwd_weights_primitive_desc.query_md(
261 query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS)
262 == conv_bwd_weights_primitive_desc.diff_weights_desc());
263 ASSERT_TRUE(conv_bwd_weights_primitive_desc.query_md(
264 query::exec_arg_md, DNNL_ARG_DIFF_BIAS)
265 == conv_bwd_weights_primitive_desc.diff_bias_desc());
266
267 EXPECT_ANY_THROW(convolution_backward_weights(
268 conv_bwd_weights_primitive_desc, {}));
269 convolution_backward_weights(conv_bwd_weights_primitive_desc)
270 .execute(strm,
271 {{DNNL_ARG_DIFF_DST, c_diff_dst.get()},
272 {DNNL_ARG_SRC, c_src.get()},
273 {DNNL_ARG_DIFF_WEIGHTS, c_diff_weights.get()},
274 {DNNL_ARG_DIFF_BIAS, c_diff_bias.get()}});
275 strm.wait();
276
277 auto ref_diff_weights = test::make_memory(c_diff_weights_desc, eng);
278 auto ref_diff_bias = test::make_memory(c_diff_bias_desc, eng);
279
280 compute_ref_conv_bwd_weights<data_t_src, data_t_diff_dst,
281 data_t_diff_weights>(
282 cd, c_src.get(), c_diff_dst.get(), ref_diff_weights);
283 check_zero_tail<data_t_diff_weights>(1, ref_diff_weights);
284 compare_data<data_t_diff_weights>(
285 ref_diff_weights, c_diff_weights.get());
286 check_zero_tail<data_t_diff_weights>(1, c_diff_weights.get());
287
288 if (with_bias) {
289 compute_ref_conv_bwd_bias<data_t_src, data_t_diff_dst,
290 data_t_diff_bias>(cd, c_diff_dst.get(), ref_diff_bias);
291
292 compare_data<data_t_diff_bias>(ref_diff_bias, c_diff_bias.get());
293 }
294 }
295};
296
297} // namespace dnnl
298#endif
299