1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef TEST_CONVOLUTION_BACKWARD_WEIGHTS_COMMON_H |
18 | #define TEST_CONVOLUTION_BACKWARD_WEIGHTS_COMMON_H |
19 | |
20 | #include "dnnl_test_common.hpp" |
21 | #include "gtest/gtest.h" |
22 | |
23 | #include "oneapi/dnnl/dnnl.hpp" |
24 | |
25 | namespace dnnl { |
26 | |
27 | template <typename data_t_src, typename data_t_diff_dst, |
28 | typename data_t_diff_bias> |
29 | void compute_ref_conv_bwd_bias(const test_convolution_sizes_t &c, |
30 | const memory &diff_dst, const memory &diff_bias) { |
31 | auto diff_bias_data = map_memory<data_t_diff_bias>(diff_bias); |
32 | auto diff_dst_data = map_memory<data_t_diff_dst>(diff_dst); |
33 | |
34 | const memory::desc bias_d = diff_bias.get_desc(); |
35 | const memory::desc dst_d = diff_dst.get_desc(); |
36 | const dnnl::impl::memory_desc_wrapper diff_bias_mdw(bias_d.get()); |
37 | const dnnl::impl::memory_desc_wrapper diff_dst_mdw(dst_d.get()); |
38 | |
39 | auto padded_oc = dst_d.get_padded_dims()[1]; |
40 | |
41 | dnnl::impl::parallel_nd( |
42 | c.ng, c.oc / c.ng, [&](memory::dim g, memory::dim oc) { |
43 | memory::dim bidx = g * padded_oc / c.ng + oc; |
44 | diff_bias_data[diff_bias_mdw.off_l(bidx, true)] = 0.0; |
45 | for (memory::dim mb = 0; mb < c.mb; ++mb) { |
46 | for (memory::dim oh = 0; oh < c.oh; ++oh) { |
47 | for (memory::dim ow = 0; ow < c.ow; ++ow) { |
48 | memory::dim oidx = mb * padded_oc * c.oh * c.ow |
49 | + g * padded_oc / c.ng * c.oh * c.ow |
50 | + oc * c.oh * c.ow + oh * c.ow + ow; |
51 | diff_bias_data[diff_bias_mdw.off_l(bidx, true)] |
52 | += diff_dst_data[diff_dst_mdw.off_l( |
53 | oidx, true)]; |
54 | } |
55 | } |
56 | } |
57 | }); |
58 | } |
59 | |
60 | template <typename data_t_src, typename data_t_diff_dst, |
61 | typename data_t_diff_weights> |
62 | void compute_ref_conv_bwd_weights(const test_convolution_sizes_t &c, |
63 | const memory &src, const memory &diff_dst, const memory &diff_weights) { |
64 | auto src_data = map_memory<data_t_src>(src); |
65 | auto diff_weights_data = map_memory<data_t_diff_weights>(diff_weights); |
66 | auto diff_dst_data = map_memory<data_t_diff_dst>(diff_dst); |
67 | |
68 | const memory::desc src_d = src.get_desc(); |
69 | const memory::desc weights_d = diff_weights.get_desc(); |
70 | const memory::desc dst_d = diff_dst.get_desc(); |
71 | const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get()); |
72 | const dnnl::impl::memory_desc_wrapper diff_weights_mdw(weights_d.get()); |
73 | const dnnl::impl::memory_desc_wrapper diff_dst_mdw(dst_d.get()); |
74 | |
75 | auto padded_ic = src_d.get_padded_dims()[1]; |
76 | auto padded_oc = dst_d.get_padded_dims()[1]; |
77 | |
78 | dnnl::impl::parallel_nd(c.ng, c.oc / c.ng, c.ic / c.ng, c.kh, c.kw, |
79 | [&](memory::dim g, memory::dim oc, memory::dim ic, memory::dim kh, |
80 | memory::dim kw) { |
81 | memory::dim widx |
82 | = g * padded_oc / c.ng * padded_ic / c.ng * c.kh * c.kw |
83 | + oc * padded_ic / c.ng * c.kh * c.kw + ic * c.kh * c.kw |
84 | + kh * c.kw + kw; |
85 | diff_weights_data[diff_weights_mdw.off_l(widx, true)] = 0.0; |
86 | for (memory::dim mb = 0; mb < c.mb; ++mb) { |
87 | for (memory::dim oh = 0; oh < c.oh; ++oh) { |
88 | for (memory::dim ow = 0; ow < c.ow; ++ow) { |
89 | if (ow * c.strw + kw * (1 + c.dilw) < c.padw |
90 | || oh * c.strh + kh * (1 + c.dilh) < c.padh |
91 | || ow * c.strw + kw * (1 + c.dilw) |
92 | >= c.iw + c.padw |
93 | || oh * c.strh + kh * (1 + c.dilh) |
94 | >= c.ih + c.padh) |
95 | continue; |
96 | |
97 | memory::dim ih |
98 | = oh * c.strh - c.padh + kh * (1 + c.dilh); |
99 | memory::dim iw |
100 | = ow * c.strw - c.padw + kw * (1 + c.dilw); |
101 | memory::dim sidx = mb * padded_ic * c.ih * c.iw |
102 | + g * padded_ic / c.ng * c.ih * c.iw |
103 | + ic * c.ih * c.iw + ih * c.iw + iw; |
104 | memory::dim didx = mb * padded_oc * c.oh * c.ow |
105 | + g * padded_oc / c.ng * c.oh * c.ow |
106 | + oc * c.oh * c.ow + oh * c.ow + ow; |
107 | |
108 | diff_weights_data[diff_weights_mdw.off_l( |
109 | widx, true)] |
110 | += src_data[src_mdw.off_l(sidx, true)] |
111 | * diff_dst_data[diff_dst_mdw.off_l( |
112 | didx, true)]; |
113 | } |
114 | } |
115 | } |
116 | }); |
117 | } |
118 | |
119 | template <typename data_t_src, typename data_t_diff_dst, |
120 | typename data_t_diff_weights, typename data_t_diff_bias> |
121 | class convolution_backward_weights_test |
122 | : public ::testing::TestWithParam<test_convolution_params_t> { |
123 | protected: |
124 | virtual void SetUp() { |
125 | auto p = ::testing::TestWithParam< |
126 | test_convolution_params_t>::GetParam(); |
127 | |
128 | SKIP_IF_CUDA( |
129 | !(cuda_check_format_tags(p.formats.src_format) |
130 | && cuda_check_format_tags(p.formats.dst_format) |
131 | && (cuda_check_format_tags(p.formats.weights_format) |
132 | || (impl::utils::one_of( |
133 | p.formats.weights_format, |
134 | /* weights formats */ |
135 | memory::format_tag::gowi, |
136 | memory::format_tag::gohwi, |
137 | memory::format_tag::godhwi, |
138 | memory::format_tag::owi, |
139 | memory::format_tag::ohwi, |
140 | memory::format_tag::odhwi))) |
141 | && data_traits<data_t_src>::data_type |
142 | == memory::data_type::f32 |
143 | && data_traits<data_t_diff_dst>::data_type |
144 | == memory::data_type::f32 |
145 | && data_traits<data_t_diff_weights>::data_type |
146 | == memory::data_type::f32 |
147 | && check_cuda_alg_format(p.formats.dst_format, |
148 | p.formats.weights_format, p.aalgorithm)), |
149 | "format is not supported." ); |
150 | |
151 | catch_expected_failures( |
152 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
153 | } |
154 | |
155 | bool cuda_check_format_tags(memory::format_tag tag) { |
156 | return impl::utils::one_of(tag, memory::format_tag::ab, |
157 | memory::format_tag::abc, memory::format_tag::abcd, |
158 | memory::format_tag::abcde, memory::format_tag::abcdef, |
159 | memory::format_tag::acb, memory::format_tag::acdb, |
160 | memory::format_tag::acdeb); |
161 | } |
162 | |
163 | bool check_cuda_alg_format(memory::format_tag dst_fmt, |
164 | memory::format_tag wei_fmt, algorithm alg) { |
165 | bool res = dst_fmt == wei_fmt; |
166 | if (alg == dnnl::algorithm::convolution_winograd) { |
167 | res = res |
168 | && impl::utils::one_of(wei_fmt, memory::format_tag::ab, |
169 | memory::format_tag::abc, memory::format_tag::abcd, |
170 | memory::format_tag::abcde, |
171 | memory::format_tag::abcdef); |
172 | } |
173 | return res; |
174 | } |
175 | |
176 | void Test() { |
177 | auto p = ::testing::TestWithParam< |
178 | test_convolution_params_t>::GetParam(); |
179 | |
180 | ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct); |
181 | auto eng = get_test_engine(); |
182 | auto strm = stream(eng); |
183 | memory::data_type data_type_src = data_traits<data_t_src>::data_type; |
184 | memory::data_type data_type_diff_dst |
185 | = data_traits<data_t_diff_dst>::data_type; |
186 | memory::data_type data_type_diff_weights |
187 | = data_traits<data_t_diff_weights>::data_type; |
188 | memory::data_type data_type_diff_bias |
189 | = data_traits<data_t_diff_bias>::data_type; |
190 | |
191 | test_convolution_sizes_t cd = p.sizes; |
192 | |
193 | bool with_bias = p.formats.bias_format != memory::format_tag::undef; |
194 | |
195 | auto c_src_desc = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, data_type_src, |
196 | p.formats.src_format); |
197 | auto c_diff_weights_desc = cd.ng > 1 |
198 | ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw}, |
199 | data_type_diff_weights, p.formats.weights_format) |
200 | : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, |
201 | data_type_diff_weights, p.formats.weights_format); |
202 | auto c_diff_bias_desc = create_md( |
203 | {cd.oc}, data_type_diff_bias, p.formats.bias_format); |
204 | auto c_diff_dst_desc = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, |
205 | data_type_diff_dst, p.formats.dst_format); |
206 | auto c_weights_desc_f = cd.ng > 1 |
207 | ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw}, |
208 | data_type_diff_dst, p.formats.weights_format) |
209 | : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type_diff_dst, |
210 | p.formats.weights_format); |
211 | auto c_dst_desc_f = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, |
212 | data_type_diff_weights, p.formats.dst_format); |
213 | auto c_src = test_memory(c_src_desc, eng); |
214 | auto c_diff_weights = test_memory(c_diff_weights_desc, eng); |
215 | auto c_diff_bias = test_memory(c_diff_bias_desc, eng); |
216 | auto c_diff_dst = test_memory(c_diff_dst_desc, eng); |
217 | auto weights_primitive_desc_f = test_memory(c_weights_desc_f, eng); |
218 | auto dst_primitive_desc_f = test_memory(c_dst_desc_f, eng); |
219 | fill_data<data_t_diff_dst>( |
220 | c_diff_dst.get_size() / sizeof(data_t_diff_dst), |
221 | c_diff_dst.get()); |
222 | fill_data<data_t_src>( |
223 | c_src.get_size() / sizeof(data_t_src), c_src.get()); |
224 | fill_data<data_t_diff_weights>( |
225 | c_diff_weights.get_size() / sizeof(data_t_diff_weights), |
226 | c_diff_weights.get()); |
227 | |
228 | check_zero_tail<data_t_diff_dst>(1, c_diff_dst.get()); |
229 | check_zero_tail<data_t_src>(1, c_src.get()); |
230 | check_zero_tail<data_t_diff_weights>(1, c_diff_weights.get()); |
231 | |
232 | memory::dims padR = { |
233 | right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh), |
234 | right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)}; |
235 | |
236 | auto conv_primitive_desc = convolution_forward::primitive_desc(eng, |
237 | prop_kind::forward_training, p.aalgorithm, c_src_desc, |
238 | c_weights_desc_f, c_diff_bias_desc, c_dst_desc_f, |
239 | {cd.strh, cd.strw}, {cd.dilh, cd.dilw}, {cd.padh, cd.padw}, |
240 | padR); |
241 | |
242 | auto conv_bwd_weights_primitive_desc |
243 | = convolution_backward_weights::primitive_desc(eng, |
244 | p.aalgorithm, c_src_desc, c_diff_weights_desc, |
245 | c_diff_bias_desc, c_diff_dst_desc, {cd.strh, cd.strw}, |
246 | {cd.dilh, cd.dilw}, {cd.padh, cd.padw}, padR, |
247 | conv_primitive_desc); |
248 | |
249 | conv_bwd_weights_primitive_desc |
250 | = convolution_backward_weights::primitive_desc( |
251 | conv_bwd_weights_primitive_desc |
252 | .get()); // test construction from a C pd |
253 | |
254 | ASSERT_TRUE(conv_bwd_weights_primitive_desc.query_md( |
255 | query::exec_arg_md, DNNL_ARG_SRC) |
256 | == conv_bwd_weights_primitive_desc.src_desc()); |
257 | ASSERT_TRUE(conv_bwd_weights_primitive_desc.query_md( |
258 | query::exec_arg_md, DNNL_ARG_DIFF_DST) |
259 | == conv_bwd_weights_primitive_desc.diff_dst_desc()); |
260 | ASSERT_TRUE(conv_bwd_weights_primitive_desc.query_md( |
261 | query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS) |
262 | == conv_bwd_weights_primitive_desc.diff_weights_desc()); |
263 | ASSERT_TRUE(conv_bwd_weights_primitive_desc.query_md( |
264 | query::exec_arg_md, DNNL_ARG_DIFF_BIAS) |
265 | == conv_bwd_weights_primitive_desc.diff_bias_desc()); |
266 | |
267 | EXPECT_ANY_THROW(convolution_backward_weights( |
268 | conv_bwd_weights_primitive_desc, {})); |
269 | convolution_backward_weights(conv_bwd_weights_primitive_desc) |
270 | .execute(strm, |
271 | {{DNNL_ARG_DIFF_DST, c_diff_dst.get()}, |
272 | {DNNL_ARG_SRC, c_src.get()}, |
273 | {DNNL_ARG_DIFF_WEIGHTS, c_diff_weights.get()}, |
274 | {DNNL_ARG_DIFF_BIAS, c_diff_bias.get()}}); |
275 | strm.wait(); |
276 | |
277 | auto ref_diff_weights = test::make_memory(c_diff_weights_desc, eng); |
278 | auto ref_diff_bias = test::make_memory(c_diff_bias_desc, eng); |
279 | |
280 | compute_ref_conv_bwd_weights<data_t_src, data_t_diff_dst, |
281 | data_t_diff_weights>( |
282 | cd, c_src.get(), c_diff_dst.get(), ref_diff_weights); |
283 | check_zero_tail<data_t_diff_weights>(1, ref_diff_weights); |
284 | compare_data<data_t_diff_weights>( |
285 | ref_diff_weights, c_diff_weights.get()); |
286 | check_zero_tail<data_t_diff_weights>(1, c_diff_weights.get()); |
287 | |
288 | if (with_bias) { |
289 | compute_ref_conv_bwd_bias<data_t_src, data_t_diff_dst, |
290 | data_t_diff_bias>(cd, c_diff_dst.get(), ref_diff_bias); |
291 | |
292 | compare_data<data_t_diff_bias>(ref_diff_bias, c_diff_bias.get()); |
293 | } |
294 | } |
295 | }; |
296 | |
297 | } // namespace dnnl |
298 | #endif |
299 | |