1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <memory> |
18 | |
19 | #include "dnnl_test_common.hpp" |
20 | #include "gtest/gtest.h" |
21 | |
22 | #include "oneapi/dnnl/dnnl.hpp" |
23 | #include "oneapi/dnnl/dnnl_debug.h" |
24 | namespace dnnl { |
25 | using fmt = memory::format_tag; |
26 | struct deconvolution_test_params_t { |
27 | dnnl::algorithm aalgorithm; |
28 | test_convolution_formats_t formats; |
29 | test_convolution_attr_t attr; |
30 | test_convolution_sizes_t sizes; |
31 | bool expect_to_fail; |
32 | dnnl_status_t expected_status; |
33 | }; |
34 | template <typename data_t> |
35 | void compute_bias_fwd(const test_convolution_sizes_t &c, |
36 | const dnnl::memory &dst, const dnnl::memory &bias) { |
37 | auto bias_data = map_memory<data_t>(bias); |
38 | auto dst_data = map_memory<data_t>(dst); |
39 | |
40 | const memory::desc bias_d = bias.get_desc(); |
41 | const memory::desc dst_d = dst.get_desc(); |
42 | const dnnl::impl::memory_desc_wrapper bias_mdw(bias_d.get()); |
43 | const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get()); |
44 | |
45 | dnnl::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow, |
46 | [&](memory::dim n, memory::dim g, memory::dim oc, memory::dim oh, |
47 | memory::dim ow) { |
48 | data_t b |
49 | = bias_data[bias_mdw.off_l(g * c.oc / c.ng + oc, true)]; |
50 | memory::dim oidx = n * c.oc * c.oh * c.ow |
51 | + g * c.oc / c.ng * c.oh * c.ow + oc * c.oh * c.ow |
52 | + oh * c.ow + ow; |
53 | dst_data[dst_mdw.off_l(oidx, true)] += b; |
54 | }); |
55 | } |
56 | |
57 | template <typename data_t> |
58 | void compute_bias_bwd(const test_convolution_sizes_t &c, |
59 | const dnnl::memory &dst, const dnnl::memory &bias) { |
60 | auto bias_data = map_memory<data_t>(bias); |
61 | auto dst_data = map_memory<data_t>(dst); |
62 | |
63 | const memory::desc bias_d = bias.get_desc(); |
64 | const memory::desc dst_d = dst.get_desc(); |
65 | const dnnl::impl::memory_desc_wrapper bias_mdw(bias_d.get()); |
66 | const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get()); |
67 | |
68 | dnnl::impl::parallel_nd( |
69 | c.ng, c.oc / c.ng, [&](memory::dim g, memory::dim oc) { |
70 | memory::dim bidx = g * c.oc / c.ng + oc; |
71 | bias_data[bias_mdw.off_l(bidx, true)] = 0.0; |
72 | for_(memory::dim mb = 0; mb < c.mb; ++mb) |
73 | for_(memory::dim oh = 0; oh < c.oh; ++oh) |
74 | for (memory::dim ow = 0; ow < c.ow; ++ow) { |
75 | memory::dim oidx = mb * c.oc * c.oh * c.ow |
76 | + g * c.oc / c.ng * c.oh * c.ow + oc * c.oh * c.ow |
77 | + oh * c.ow + ow; |
78 | bias_data[bias_mdw.off_l(bidx, true)] |
79 | += dst_data[dst_mdw.off_l(oidx, true)]; |
80 | } |
81 | }); |
82 | } |
83 | |
84 | template <typename data_t> |
85 | void transpose_wei(const test_convolution_sizes_t &c, |
86 | const dnnl::memory &weights, const dnnl::memory &weights_tr) { |
87 | |
88 | auto weights_data = map_memory<data_t>(weights); |
89 | const memory::desc weights_d = weights.get_desc(); |
90 | const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get()); |
91 | auto weights_tr_data = map_memory<data_t>(weights_tr); |
92 | const memory::desc weights_tr_d = weights_tr.get_desc(); |
93 | const dnnl::impl::memory_desc_wrapper weights_tr_mdw(weights_tr_d.get()); |
94 | |
95 | dnnl::impl::parallel_nd(c.ng, c.oc / c.ng, c.ic / c.ng, c.kh, c.kw, |
96 | [&](memory::dim g, memory::dim oc, memory::dim ic, memory::dim kh, |
97 | memory::dim kw) { |
98 | memory::dim widx = g * c.oc / c.ng * c.ic / c.ng * c.kh * c.kw |
99 | + oc * c.ic / c.ng * c.kh * c.kw + ic * c.kh * c.kw |
100 | + kh * c.kw + kw; |
101 | memory::dim widx_tr |
102 | = g * c.oc / c.ng * c.ic / c.ng * c.kh * c.kw |
103 | + ic * c.oc / c.ng * c.kh * c.kw + oc * c.kh * c.kw |
104 | + kh * c.kw + kw; |
105 | weights_tr_data[weights_tr_mdw.off_l(widx_tr, true)] |
106 | = weights_data[weights_mdw.off_l(widx, true)]; |
107 | }); |
108 | } |
109 | |
110 | template <typename data_t> |
111 | class deconvolution_test_t |
112 | : public ::testing::TestWithParam<deconvolution_test_params_t> { |
113 | private: |
114 | std::shared_ptr<test_memory> src; |
115 | std::shared_ptr<test_memory> weights; |
116 | std::shared_ptr<test_memory> dst; |
117 | std::shared_ptr<test_memory> bias; |
118 | |
119 | std::shared_ptr<memory::desc> dec_src_desc; |
120 | std::shared_ptr<memory::desc> dec_weights_desc; |
121 | std::shared_ptr<memory::desc> dec_bias_desc; |
122 | std::shared_ptr<memory::desc> dec_dst_desc; |
123 | |
124 | std::shared_ptr<memory::desc> con_src_desc; |
125 | std::shared_ptr<memory::desc> con_bias_desc; |
126 | std::shared_ptr<memory::desc> con_dst_desc; |
127 | std::shared_ptr<memory::desc> con_weights_desc; |
128 | |
129 | engine eng; |
130 | stream strm; |
131 | bool with_bias; |
132 | memory::dims padL; |
133 | memory::dims padR; |
134 | memory::dims strides; |
135 | |
136 | protected: |
137 | void SetUp() override { |
138 | memory::data_type data_type = data_traits<data_t>::data_type; |
139 | SKIP_IF(unsupported_data_type(data_type), |
140 | "Engine does not support this data type." ); |
141 | |
142 | auto p = ::testing::TestWithParam< |
143 | deconvolution_test_params_t>::GetParam(); |
144 | |
145 | SKIP_IF_CUDA( |
146 | !(cuda_check_format_tags(p.formats.src_format, data_type) |
147 | && cuda_check_format_tags( |
148 | p.formats.dst_format, data_type) |
149 | && cuda_check_src_wei_format_tags(p.formats.src_format, |
150 | p.formats.weights_format, p.sizes.ng > 1)), |
151 | "Format is not supported." ); |
152 | |
153 | catch_expected_failures( |
154 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
155 | } |
156 | |
157 | bool cuda_check_format_tags(memory::format_tag tag, memory::data_type dt) { |
158 | return ((impl::utils::one_of(tag, memory::format_tag::ab, |
159 | memory::format_tag::abc, memory::format_tag::abcd, |
160 | memory::format_tag::abcde, memory::format_tag::abcdef, |
161 | memory::format_tag::acb, memory::format_tag::acdb, |
162 | memory::format_tag::acdeb)) |
163 | || (dt == memory::data_type::s8 |
164 | && impl::utils::one_of(tag, memory::format_tag::aBcd4b, |
165 | memory::format_tag::aBcde4b))); |
166 | } |
167 | |
168 | bool cuda_check_src_wei_format_tags( |
169 | memory::format_tag src, memory::format_tag wei, bool is_grouped) { |
170 | if (src == memory::format_tag::abcd) return true; |
171 | if (src == memory::format_tag::acdb) |
172 | return wei |
173 | != (is_grouped ? memory::format_tag::abcde |
174 | : memory::format_tag::abcd); |
175 | return false; |
176 | } |
177 | |
178 | void Test() { |
179 | auto p = ::testing::TestWithParam< |
180 | deconvolution_test_params_t>::GetParam(); |
181 | |
182 | eng = get_test_engine(); |
183 | strm = make_stream(eng); |
184 | |
185 | ASSERT_EQ(p.aalgorithm, algorithm::deconvolution_direct); |
186 | memory::data_type data_type = data_traits<data_t>::data_type; |
187 | |
188 | test_convolution_sizes_t dd = p.sizes; |
189 | with_bias = p.formats.bias_format != memory::format_tag::undef; |
190 | |
191 | memory::dims src_dims = {dd.mb, dd.ic, dd.ih, dd.iw}; |
192 | memory::dims dst_dims = {dd.mb, dd.oc, dd.oh, dd.ow}; |
193 | memory::dims weights_dims, c_weights_dims; |
194 | if (dd.ng > 1) { |
195 | weights_dims = {dd.ng, dd.oc / dd.ng, dd.ic / dd.ng, dd.kh, dd.kw}; |
196 | c_weights_dims |
197 | = {dd.ng, dd.ic / dd.ng, dd.oc / dd.ng, dd.kh, dd.kw}; |
198 | } else { |
199 | weights_dims = {dd.oc, dd.ic, dd.kh, dd.kw}; |
200 | c_weights_dims = {dd.ic, dd.oc, dd.kh, dd.kw}; |
201 | } |
202 | memory::dims bias_dims; |
203 | if (with_bias) |
204 | bias_dims = {dd.oc}; |
205 | else |
206 | bias_dims = {}; |
207 | |
208 | dec_src_desc = std::make_shared<memory::desc>( |
209 | src_dims, data_type, p.formats.src_format); |
210 | dec_dst_desc = std::make_shared<memory::desc>( |
211 | dst_dims, data_type, p.formats.src_format); |
212 | dec_weights_desc = std::make_shared<memory::desc>( |
213 | weights_dims, data_type, p.formats.weights_format); |
214 | dec_bias_desc = std::make_shared<memory::desc>( |
215 | bias_dims, data_type, p.formats.bias_format); |
216 | |
217 | con_src_desc = std::make_shared<memory::desc>( |
218 | dst_dims, data_type, p.formats.src_format); |
219 | con_dst_desc = std::make_shared<memory::desc>( |
220 | src_dims, data_type, p.formats.src_format); |
221 | con_weights_desc = std::make_shared<memory::desc>( |
222 | c_weights_dims, data_type, p.formats.weights_format); |
223 | |
224 | src = std::make_shared<test_memory>(*dec_src_desc, eng); |
225 | weights = std::make_shared<test_memory>(*dec_weights_desc, eng); |
226 | bias = std::make_shared<test_memory>(*dec_bias_desc, eng); |
227 | dst = std::make_shared<test_memory>(*dec_dst_desc, eng); |
228 | |
229 | strides = {dd.strh, dd.strw}; |
230 | padL = {dd.padh, dd.padw}; |
231 | padR = {right_padding(dd.oh, dd.ih, dd.kh, dd.padh, dd.strh, dd.dilh), |
232 | right_padding(dd.ow, dd.iw, dd.kw, dd.padw, dd.strw, dd.dilw)}; |
233 | SKIP_IF_CUDA(p.sizes.padh < padR[0] || p.sizes.padw < padR[1], |
234 | "Padding not supported" ); |
235 | Forward(); |
236 | BackwardData(); |
237 | BackwardWeights(); |
238 | } |
239 | |
240 | void Forward() { |
241 | auto aprop_kind = prop_kind::forward; |
242 | deconvolution_test_params_t p = ::testing::TestWithParam< |
243 | deconvolution_test_params_t>::GetParam(); |
244 | auto conv_src = test_memory(*con_src_desc, eng); |
245 | auto conv_dst = src; |
246 | test_convolution_sizes_t dd = p.sizes; |
247 | |
248 | fill_data<data_t>(src->get_size() / sizeof(data_t), src->get()); |
249 | |
250 | fill_data<data_t>(weights->get_size() / sizeof(data_t), weights->get()); |
251 | if (with_bias) { |
252 | fill_data<data_t>(bias->get_size() / sizeof(data_t), bias->get()); |
253 | } |
254 | |
255 | auto weights_tr = test::make_memory(*con_weights_desc, eng); |
256 | transpose_wei<data_t>(dd, weights->get(), weights_tr); |
257 | auto deconv_primitive_desc = with_bias |
258 | ? deconvolution_forward::primitive_desc(eng, aprop_kind, |
259 | algorithm::deconvolution_direct, *dec_src_desc, |
260 | *dec_weights_desc, *dec_bias_desc, *dec_dst_desc, |
261 | strides, padL, padR) |
262 | : deconvolution_forward::primitive_desc(eng, aprop_kind, |
263 | algorithm::deconvolution_direct, *dec_src_desc, |
264 | *dec_weights_desc, *dec_dst_desc, strides, padL, padR); |
265 | |
266 | deconv_primitive_desc = deconvolution_forward::primitive_desc( |
267 | deconv_primitive_desc.get()); // test construction from a C pd |
268 | |
269 | ASSERT_TRUE( |
270 | deconv_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC) |
271 | == deconv_primitive_desc.src_desc()); |
272 | ASSERT_TRUE( |
273 | deconv_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_DST) |
274 | == deconv_primitive_desc.dst_desc()); |
275 | ASSERT_TRUE(deconv_primitive_desc.query_md( |
276 | query::exec_arg_md, DNNL_ARG_WEIGHTS) |
277 | == deconv_primitive_desc.weights_desc()); |
278 | ASSERT_TRUE(deconv_primitive_desc.query_md( |
279 | query::exec_arg_md, DNNL_ARG_BIAS) |
280 | == deconv_primitive_desc.bias_desc()); |
281 | |
282 | ASSERT_EQ(deconv_primitive_desc.get_algorithm(), |
283 | algorithm::deconvolution_direct); |
284 | ASSERT_EQ(deconv_primitive_desc.get_prop_kind(), aprop_kind); |
285 | ASSERT_EQ(deconv_primitive_desc.get_strides(), strides); |
286 | ASSERT_EQ(deconv_primitive_desc.get_padding_l(), padL); |
287 | ASSERT_EQ(deconv_primitive_desc.get_padding_r(), padR); |
288 | |
289 | EXPECT_ANY_THROW(deconvolution_forward(deconv_primitive_desc, {})); |
290 | deconvolution_forward(deconv_primitive_desc) |
291 | .execute(strm, |
292 | {{DNNL_ARG_SRC, src->get()}, |
293 | {DNNL_ARG_WEIGHTS, weights->get()}, |
294 | {DNNL_ARG_BIAS, bias->get()}, |
295 | {DNNL_ARG_DST, dst->get()}}); |
296 | strm.wait(); |
297 | |
298 | auto conv_primitive_desc = convolution_forward::primitive_desc(eng, |
299 | prop_kind::forward_training, algorithm::convolution_direct, |
300 | *con_src_desc, *con_weights_desc, *con_dst_desc, strides, padL, |
301 | padR); |
302 | |
303 | auto conv_bwd_data_primitive_desc |
304 | = convolution_backward_data::primitive_desc(eng, |
305 | algorithm::convolution_direct, *con_src_desc, |
306 | *con_weights_desc, *con_dst_desc, strides, padL, padR, |
307 | conv_primitive_desc); |
308 | |
309 | convolution_backward_data(conv_bwd_data_primitive_desc) |
310 | .execute(strm, |
311 | {{DNNL_ARG_DIFF_DST, conv_dst->get()}, |
312 | {DNNL_ARG_WEIGHTS, weights_tr}, |
313 | {DNNL_ARG_DIFF_SRC, conv_src.get()}}); |
314 | strm.wait(); |
315 | |
316 | if (with_bias) |
317 | compute_bias_fwd<data_t>(dd, conv_src.get(), bias->get()); |
318 | compare_data<data_t>(conv_src.get(), dst->get()); |
319 | } |
320 | |
321 | void BackwardData() { |
322 | auto p = ::testing::TestWithParam< |
323 | deconvolution_test_params_t>::GetParam(); |
324 | auto conv_src = dst; |
325 | auto conv_dst = test_memory(*con_dst_desc, eng); |
326 | test_convolution_sizes_t dd = p.sizes; |
327 | |
328 | fill_data<data_t>(weights->get_size() / sizeof(data_t), weights->get()); |
329 | |
330 | fill_data<data_t>(dst->get_size() / sizeof(data_t), dst->get()); |
331 | |
332 | auto weights_tr = test::make_memory(*con_weights_desc, eng); |
333 | transpose_wei<data_t>(dd, weights->get(), weights_tr); |
334 | |
335 | auto deconv_primitive_desc = deconvolution_forward::primitive_desc(eng, |
336 | prop_kind::forward_training, algorithm::deconvolution_direct, |
337 | *dec_src_desc, *dec_weights_desc, *dec_dst_desc, strides, padL, |
338 | padR); |
339 | |
340 | auto deconv_bwd_data_primitive_desc |
341 | = deconvolution_backward_data::primitive_desc(eng, |
342 | algorithm::deconvolution_direct, *dec_src_desc, |
343 | *dec_weights_desc, *dec_dst_desc, strides, padL, padR, |
344 | deconv_primitive_desc); |
345 | deconv_bwd_data_primitive_desc |
346 | = deconvolution_backward_data::primitive_desc( |
347 | deconv_bwd_data_primitive_desc |
348 | .get()); // test construction from a C pd |
349 | |
350 | ASSERT_TRUE(deconv_bwd_data_primitive_desc.query_md( |
351 | query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
352 | == deconv_bwd_data_primitive_desc.diff_src_desc()); |
353 | ASSERT_TRUE(deconv_bwd_data_primitive_desc.query_md( |
354 | query::exec_arg_md, DNNL_ARG_DIFF_DST) |
355 | == deconv_bwd_data_primitive_desc.diff_dst_desc()); |
356 | ASSERT_TRUE(deconv_bwd_data_primitive_desc.query_md( |
357 | query::exec_arg_md, DNNL_ARG_WEIGHTS) |
358 | == deconv_bwd_data_primitive_desc.weights_desc()); |
359 | |
360 | ASSERT_EQ(deconv_bwd_data_primitive_desc.get_algorithm(), |
361 | algorithm::deconvolution_direct); |
362 | ASSERT_EQ(deconv_bwd_data_primitive_desc.get_prop_kind(), |
363 | prop_kind::backward_data); |
364 | ASSERT_EQ(deconv_bwd_data_primitive_desc.get_strides(), strides); |
365 | ASSERT_EQ(deconv_bwd_data_primitive_desc.get_padding_l(), padL); |
366 | ASSERT_EQ(deconv_bwd_data_primitive_desc.get_padding_r(), padR); |
367 | |
368 | deconvolution_backward_data(deconv_bwd_data_primitive_desc) |
369 | .execute(strm, |
370 | {{DNNL_ARG_DIFF_DST, dst->get()}, |
371 | {DNNL_ARG_WEIGHTS, weights->get()}, |
372 | {DNNL_ARG_DIFF_SRC, src->get()}}); |
373 | strm.wait(); |
374 | |
375 | auto conv_primitive_desc = convolution_forward::primitive_desc(eng, |
376 | prop_kind::forward_training, algorithm::convolution_direct, |
377 | *con_src_desc, *con_weights_desc, *con_dst_desc, strides, padL, |
378 | padR); |
379 | |
380 | convolution_forward(conv_primitive_desc) |
381 | .execute(strm, |
382 | {{DNNL_ARG_SRC, conv_src->get()}, |
383 | {DNNL_ARG_WEIGHTS, weights_tr}, |
384 | {DNNL_ARG_DST, conv_dst.get()}}); |
385 | strm.wait(); |
386 | |
387 | compare_data<data_t>(conv_dst.get(), src->get()); |
388 | } |
389 | |
390 | void BackwardWeights() { |
391 | auto p = ::testing::TestWithParam< |
392 | deconvolution_test_params_t>::GetParam(); |
393 | auto conv_src = dst; |
394 | auto conv_dst = src; |
395 | auto conv_weights = test::make_memory(*con_weights_desc, eng); |
396 | test_convolution_sizes_t dd = p.sizes; |
397 | |
398 | fill_data<data_t>(src->get_size() / sizeof(data_t), src->get()); |
399 | |
400 | fill_data<data_t>(dst->get_size() / sizeof(data_t), dst->get()); |
401 | |
402 | auto deconv_primitive_desc = deconvolution_forward::primitive_desc(eng, |
403 | prop_kind::forward_training, algorithm::deconvolution_direct, |
404 | *dec_src_desc, *dec_weights_desc, *dec_bias_desc, *dec_dst_desc, |
405 | {dd.strh, dd.strw}, {dd.padh, dd.padw}, padR); |
406 | |
407 | auto deconv_bwd_weights_primitive_desc |
408 | = deconvolution_backward_weights::primitive_desc(eng, |
409 | algorithm::deconvolution_direct, *dec_src_desc, |
410 | *dec_weights_desc, *dec_bias_desc, *dec_dst_desc, |
411 | strides, padL, padR, deconv_primitive_desc); |
412 | |
413 | ASSERT_TRUE(deconv_bwd_weights_primitive_desc.query_md( |
414 | query::exec_arg_md, DNNL_ARG_SRC) |
415 | == deconv_bwd_weights_primitive_desc.src_desc()); |
416 | ASSERT_TRUE(deconv_bwd_weights_primitive_desc.query_md( |
417 | query::exec_arg_md, DNNL_ARG_DIFF_DST) |
418 | == deconv_bwd_weights_primitive_desc.diff_dst_desc()); |
419 | ASSERT_TRUE(deconv_bwd_weights_primitive_desc.query_md( |
420 | query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS) |
421 | == deconv_bwd_weights_primitive_desc.diff_weights_desc()); |
422 | ASSERT_TRUE(deconv_bwd_weights_primitive_desc.query_md( |
423 | query::exec_arg_md, DNNL_ARG_DIFF_BIAS) |
424 | == deconv_bwd_weights_primitive_desc.diff_bias_desc()); |
425 | |
426 | ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_algorithm(), |
427 | algorithm::deconvolution_direct); |
428 | ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_prop_kind(), |
429 | prop_kind::backward_weights); |
430 | ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_strides(), strides); |
431 | ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_padding_l(), padL); |
432 | ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_padding_r(), padR); |
433 | |
434 | deconvolution_backward_weights(deconv_bwd_weights_primitive_desc) |
435 | .execute(strm, |
436 | {{DNNL_ARG_DIFF_DST, dst->get()}, |
437 | {DNNL_ARG_SRC, src->get()}, |
438 | {DNNL_ARG_DIFF_WEIGHTS, weights->get()}, |
439 | {DNNL_ARG_DIFF_BIAS, bias->get()}}); |
440 | strm.wait(); |
441 | |
442 | auto conv_primitive_desc = convolution_forward::primitive_desc(eng, |
443 | prop_kind::forward_training, algorithm::convolution_direct, |
444 | *con_src_desc, *con_weights_desc, *con_dst_desc, strides, padL, |
445 | padR); |
446 | |
447 | deconv_bwd_weights_primitive_desc |
448 | = deconvolution_backward_weights::primitive_desc( |
449 | deconv_bwd_weights_primitive_desc |
450 | .get()); // test construction from a C pd |
451 | |
452 | auto conv_bwd_weights_primitive_desc |
453 | = convolution_backward_weights::primitive_desc(eng, |
454 | algorithm::convolution_direct, *con_src_desc, |
455 | *con_weights_desc, *con_dst_desc, strides, padL, padR, |
456 | conv_primitive_desc); |
457 | |
458 | convolution_backward_weights(conv_bwd_weights_primitive_desc) |
459 | .execute(strm, |
460 | {{DNNL_ARG_DIFF_DST, conv_dst->get()}, |
461 | {DNNL_ARG_SRC, conv_src->get()}, |
462 | {DNNL_ARG_DIFF_WEIGHTS, conv_weights}}); |
463 | strm.wait(); |
464 | |
465 | auto weights_tr = test::make_memory(*con_weights_desc, eng); |
466 | transpose_wei<data_t>(dd, weights->get(), weights_tr); |
467 | |
468 | compare_data<data_t>(weights_tr, conv_weights); |
469 | |
470 | if (with_bias) { |
471 | auto ref_bias = test::make_memory(*dec_bias_desc, eng); |
472 | compute_bias_bwd<data_t>(dd, dst->get(), ref_bias); |
473 | compare_data<data_t>(ref_bias, bias->get()); |
474 | } |
475 | } |
476 | }; |
477 | |
478 | using deconvolution_test_float = deconvolution_test_t<float>; |
479 | |
480 | TEST_P(deconvolution_test_float, TestDeconvolution) {} |
481 | |
// Expands four format-tag names into a braced list of
// dnnl::memory::format_tag values (src, weights, bias, dst).
#define EXPAND_FORMATS(src, weights, bias, dst) \
    { \
        dnnl::memory::format_tag::src, dnnl::memory::format_tag::weights, \
                dnnl::memory::format_tag::bias, \
                dnnl::memory::format_tag::dst \
    }

// Algorithm used by every case in this file.
#define ALGORITHM dnnl::algorithm::deconvolution_direct

// Builds one deconvolution_test_params_t: algorithm, formats, empty
// attributes, then the variadic arguments as the sizes initializer. The
// remaining members (expect_to_fail, expected_status) are left to
// aggregate initialization.
#define PARAMS(src, weights, bias, dst, ...) \
    deconvolution_test_params_t { \
        ALGORITHM, EXPAND_FORMATS(src, weights, bias, dst), {}, \
        { \
            __VA_ARGS__ \
        } \
    }

// Instantiate the float fixture through the CPU / GPU flavored
// INSTANTIATE_TEST_SUITE_P wrappers.
#define CPU_INST_TEST_CASE(str, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            str, deconvolution_test_float, ::testing::Values(__VA_ARGS__))
#define GPU_INST_TEST_CASE(str, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P( \
            str, deconvolution_test_float, ::testing::Values(__VA_ARGS__))

// Format-tag shorthands used by the blocked-layout suites below.
#define FMT_BIAS x
#define FMT_DATA_BLOCKED nChw8c
#define FMT_WEIGHTS_BLOCKED Ohwi8o
#define FMT_DATA_BLOCKED_GPU NChw16n16c
#define FMT_WEIGHTS_BLOCKED_GPU IOhw16i16o
509 | |
510 | CPU_INST_TEST_CASE(SimpleSmall_NCHW, |
511 | PARAMS(nchw, oihw, x, nchw, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1), |
512 | PARAMS(nchw, oihw, x, nchw, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
513 | PARAMS(nhwc, oihw, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
514 | PARAMS(nhwc, hwio, x, nhwc, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1), |
515 | PARAMS(nhwc, hwio, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
516 | PARAMS(nhwc, goihw, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
517 | PARAMS(nhwc, hwigo, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1) |
518 | |
519 | ); |
520 | |
521 | CPU_INST_TEST_CASE(SimpleSmall_Blocked, |
522 | PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, |
523 | FMT_DATA_BLOCKED, 2, 1, 32, 12, 12, 32, 13, 13, 3, 3, 0, 0, 1, |
524 | 1), |
525 | PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, |
526 | FMT_DATA_BLOCKED, 2, 1, 32, 4, 4, 32, 3, 3, 3, 3, 1, 1, 1, 1), |
527 | PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, |
528 | FMT_DATA_BLOCKED, 2, 1, 32, 4, 4, 32, 4, 4, 3, 3, 0, 0, 1, 1), |
529 | PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, |
530 | FMT_DATA_BLOCKED, 2, 1, 32, 2, 2, 32, 3, 3, 3, 3, 0, 0, 1, 1), |
531 | PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, |
532 | FMT_DATA_BLOCKED, 2, 1, 32, 2, 2, 32, 2, 2, 3, 3, 1, 1, 1, 1), |
533 | PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, |
534 | FMT_DATA_BLOCKED, 2, 1, 48, 13, 13, 32, 13, 13, 3, 3, 1, 1, 1, |
535 | 1), |
536 | PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, |
537 | FMT_DATA_BLOCKED, 2, 1, 48, 11, 11, 32, 13, 13, 3, 3, 0, 0, 1, |
538 | 1)); |
539 | |
540 | GPU_INST_TEST_CASE(SimpleSmall_Blocked, |
541 | PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS, |
542 | FMT_DATA_BLOCKED_GPU, 32, 1, 32, 12, 12, 32, 10, 10, 3, 3, 0, 0, |
543 | 1, 1), |
544 | PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS, |
545 | FMT_DATA_BLOCKED_GPU, 32, 1, 32, 4, 4, 32, 3, 3, 3, 3, 1, 1, 1, |
546 | 1), |
547 | PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS, |
548 | FMT_DATA_BLOCKED_GPU, 32, 1, 32, 4, 4, 32, 4, 4, 3, 3, 0, 0, 1, |
549 | 1), |
550 | PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS, |
551 | FMT_DATA_BLOCKED_GPU, 32, 1, 32, 2, 2, 32, 3, 3, 3, 3, 0, 0, 1, |
552 | 1), |
553 | PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS, |
554 | FMT_DATA_BLOCKED_GPU, 32, 1, 32, 2, 2, 32, 2, 2, 3, 3, 1, 1, 1, |
555 | 1), |
556 | PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS, |
557 | FMT_DATA_BLOCKED_GPU, 32, 1, 48, 13, 13, 32, 13, 13, 3, 3, 1, 1, |
558 | 1, 1), |
559 | PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS, |
560 | FMT_DATA_BLOCKED_GPU, 32, 1, 48, 11, 11, 32, 13, 13, 3, 3, 0, 0, |
561 | 1, 1)); |
562 | |
563 | GPU_INST_TEST_CASE(SimpleSmall_NCHW, |
564 | PARAMS(nchw, oihw, x, nchw, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1), |
565 | PARAMS(nchw, oihw, x, nchw, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
566 | PARAMS(nhwc, oihw, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
567 | PARAMS(nhwc, hwio, x, nhwc, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1), |
568 | PARAMS(nhwc, hwio, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
569 | PARAMS(nhwc, goihw, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
570 | PARAMS(nhwc, hwigo, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1)); |
571 | |
572 | GPU_INST_TEST_CASE(SimpleSmall_NHWC, |
573 | PARAMS(nchw, oihw, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
574 | PARAMS(nhwc, ohwi, x, nhwc, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1), |
575 | PARAMS(nhwc, ohwi, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1), |
576 | PARAMS(nchw, goihw, x, nchw, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1), |
577 | PARAMS(nchw, goihw, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1), |
578 | PARAMS(nhwc, gohwi, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1)); |
579 | } // namespace dnnl |
580 | |