1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <memory>
18
19#include "dnnl_test_common.hpp"
20#include "gtest/gtest.h"
21
22#include "oneapi/dnnl/dnnl.hpp"
23#include "oneapi/dnnl/dnnl_debug.h"
24namespace dnnl {
// Shorthand alias for memory::format_tag.
// NOTE(review): not referenced anywhere in this file.
using fmt = memory::format_tag;

/// One parameterization of deconvolution_test_t. Built by the PARAMS macro
/// in this file via aggregate initialization, so the field order matters.
struct deconvolution_test_params_t {
    // Algorithm under test; Test() asserts it equals deconvolution_direct.
    dnnl::algorithm aalgorithm;
    // Format tags for src / weights / bias / dst. A bias tag of `undef`
    // disables the bias (see Test()).
    test_convolution_formats_t formats;
    // Primitive attributes; PARAMS leaves them value-initialized and
    // deconvolution_test_t does not read them.
    test_convolution_attr_t attr;
    // Problem sizes (minibatch, groups, channels, spatial, kernel, padding,
    // strides, ...).
    test_convolution_sizes_t sizes;
    // Forwarded, together with expected_status, to catch_expected_failures()
    // in SetUp().
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
34template <typename data_t>
35void compute_bias_fwd(const test_convolution_sizes_t &c,
36 const dnnl::memory &dst, const dnnl::memory &bias) {
37 auto bias_data = map_memory<data_t>(bias);
38 auto dst_data = map_memory<data_t>(dst);
39
40 const memory::desc bias_d = bias.get_desc();
41 const memory::desc dst_d = dst.get_desc();
42 const dnnl::impl::memory_desc_wrapper bias_mdw(bias_d.get());
43 const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get());
44
45 dnnl::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
46 [&](memory::dim n, memory::dim g, memory::dim oc, memory::dim oh,
47 memory::dim ow) {
48 data_t b
49 = bias_data[bias_mdw.off_l(g * c.oc / c.ng + oc, true)];
50 memory::dim oidx = n * c.oc * c.oh * c.ow
51 + g * c.oc / c.ng * c.oh * c.ow + oc * c.oh * c.ow
52 + oh * c.ow + ow;
53 dst_data[dst_mdw.off_l(oidx, true)] += b;
54 });
55}
56
57template <typename data_t>
58void compute_bias_bwd(const test_convolution_sizes_t &c,
59 const dnnl::memory &dst, const dnnl::memory &bias) {
60 auto bias_data = map_memory<data_t>(bias);
61 auto dst_data = map_memory<data_t>(dst);
62
63 const memory::desc bias_d = bias.get_desc();
64 const memory::desc dst_d = dst.get_desc();
65 const dnnl::impl::memory_desc_wrapper bias_mdw(bias_d.get());
66 const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get());
67
68 dnnl::impl::parallel_nd(
69 c.ng, c.oc / c.ng, [&](memory::dim g, memory::dim oc) {
70 memory::dim bidx = g * c.oc / c.ng + oc;
71 bias_data[bias_mdw.off_l(bidx, true)] = 0.0;
72 for_(memory::dim mb = 0; mb < c.mb; ++mb)
73 for_(memory::dim oh = 0; oh < c.oh; ++oh)
74 for (memory::dim ow = 0; ow < c.ow; ++ow) {
75 memory::dim oidx = mb * c.oc * c.oh * c.ow
76 + g * c.oc / c.ng * c.oh * c.ow + oc * c.oh * c.ow
77 + oh * c.ow + ow;
78 bias_data[bias_mdw.off_l(bidx, true)]
79 += dst_data[dst_mdw.off_l(oidx, true)];
80 }
81 });
82}
83
84template <typename data_t>
85void transpose_wei(const test_convolution_sizes_t &c,
86 const dnnl::memory &weights, const dnnl::memory &weights_tr) {
87
88 auto weights_data = map_memory<data_t>(weights);
89 const memory::desc weights_d = weights.get_desc();
90 const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get());
91 auto weights_tr_data = map_memory<data_t>(weights_tr);
92 const memory::desc weights_tr_d = weights_tr.get_desc();
93 const dnnl::impl::memory_desc_wrapper weights_tr_mdw(weights_tr_d.get());
94
95 dnnl::impl::parallel_nd(c.ng, c.oc / c.ng, c.ic / c.ng, c.kh, c.kw,
96 [&](memory::dim g, memory::dim oc, memory::dim ic, memory::dim kh,
97 memory::dim kw) {
98 memory::dim widx = g * c.oc / c.ng * c.ic / c.ng * c.kh * c.kw
99 + oc * c.ic / c.ng * c.kh * c.kw + ic * c.kh * c.kw
100 + kh * c.kw + kw;
101 memory::dim widx_tr
102 = g * c.oc / c.ng * c.ic / c.ng * c.kh * c.kw
103 + ic * c.oc / c.ng * c.kh * c.kw + oc * c.kh * c.kw
104 + kh * c.kw + kw;
105 weights_tr_data[weights_tr_mdw.off_l(widx_tr, true)]
106 = weights_data[weights_mdw.off_l(widx, true)];
107 });
108}
109
110template <typename data_t>
111class deconvolution_test_t
112 : public ::testing::TestWithParam<deconvolution_test_params_t> {
113private:
114 std::shared_ptr<test_memory> src;
115 std::shared_ptr<test_memory> weights;
116 std::shared_ptr<test_memory> dst;
117 std::shared_ptr<test_memory> bias;
118
119 std::shared_ptr<memory::desc> dec_src_desc;
120 std::shared_ptr<memory::desc> dec_weights_desc;
121 std::shared_ptr<memory::desc> dec_bias_desc;
122 std::shared_ptr<memory::desc> dec_dst_desc;
123
124 std::shared_ptr<memory::desc> con_src_desc;
125 std::shared_ptr<memory::desc> con_bias_desc;
126 std::shared_ptr<memory::desc> con_dst_desc;
127 std::shared_ptr<memory::desc> con_weights_desc;
128
129 engine eng;
130 stream strm;
131 bool with_bias;
132 memory::dims padL;
133 memory::dims padR;
134 memory::dims strides;
135
136protected:
137 void SetUp() override {
138 memory::data_type data_type = data_traits<data_t>::data_type;
139 SKIP_IF(unsupported_data_type(data_type),
140 "Engine does not support this data type.");
141
142 auto p = ::testing::TestWithParam<
143 deconvolution_test_params_t>::GetParam();
144
145 SKIP_IF_CUDA(
146 !(cuda_check_format_tags(p.formats.src_format, data_type)
147 && cuda_check_format_tags(
148 p.formats.dst_format, data_type)
149 && cuda_check_src_wei_format_tags(p.formats.src_format,
150 p.formats.weights_format, p.sizes.ng > 1)),
151 "Format is not supported.");
152
153 catch_expected_failures(
154 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
155 }
156
157 bool cuda_check_format_tags(memory::format_tag tag, memory::data_type dt) {
158 return ((impl::utils::one_of(tag, memory::format_tag::ab,
159 memory::format_tag::abc, memory::format_tag::abcd,
160 memory::format_tag::abcde, memory::format_tag::abcdef,
161 memory::format_tag::acb, memory::format_tag::acdb,
162 memory::format_tag::acdeb))
163 || (dt == memory::data_type::s8
164 && impl::utils::one_of(tag, memory::format_tag::aBcd4b,
165 memory::format_tag::aBcde4b)));
166 }
167
168 bool cuda_check_src_wei_format_tags(
169 memory::format_tag src, memory::format_tag wei, bool is_grouped) {
170 if (src == memory::format_tag::abcd) return true;
171 if (src == memory::format_tag::acdb)
172 return wei
173 != (is_grouped ? memory::format_tag::abcde
174 : memory::format_tag::abcd);
175 return false;
176 }
177
178 void Test() {
179 auto p = ::testing::TestWithParam<
180 deconvolution_test_params_t>::GetParam();
181
182 eng = get_test_engine();
183 strm = make_stream(eng);
184
185 ASSERT_EQ(p.aalgorithm, algorithm::deconvolution_direct);
186 memory::data_type data_type = data_traits<data_t>::data_type;
187
188 test_convolution_sizes_t dd = p.sizes;
189 with_bias = p.formats.bias_format != memory::format_tag::undef;
190
191 memory::dims src_dims = {dd.mb, dd.ic, dd.ih, dd.iw};
192 memory::dims dst_dims = {dd.mb, dd.oc, dd.oh, dd.ow};
193 memory::dims weights_dims, c_weights_dims;
194 if (dd.ng > 1) {
195 weights_dims = {dd.ng, dd.oc / dd.ng, dd.ic / dd.ng, dd.kh, dd.kw};
196 c_weights_dims
197 = {dd.ng, dd.ic / dd.ng, dd.oc / dd.ng, dd.kh, dd.kw};
198 } else {
199 weights_dims = {dd.oc, dd.ic, dd.kh, dd.kw};
200 c_weights_dims = {dd.ic, dd.oc, dd.kh, dd.kw};
201 }
202 memory::dims bias_dims;
203 if (with_bias)
204 bias_dims = {dd.oc};
205 else
206 bias_dims = {};
207
208 dec_src_desc = std::make_shared<memory::desc>(
209 src_dims, data_type, p.formats.src_format);
210 dec_dst_desc = std::make_shared<memory::desc>(
211 dst_dims, data_type, p.formats.src_format);
212 dec_weights_desc = std::make_shared<memory::desc>(
213 weights_dims, data_type, p.formats.weights_format);
214 dec_bias_desc = std::make_shared<memory::desc>(
215 bias_dims, data_type, p.formats.bias_format);
216
217 con_src_desc = std::make_shared<memory::desc>(
218 dst_dims, data_type, p.formats.src_format);
219 con_dst_desc = std::make_shared<memory::desc>(
220 src_dims, data_type, p.formats.src_format);
221 con_weights_desc = std::make_shared<memory::desc>(
222 c_weights_dims, data_type, p.formats.weights_format);
223
224 src = std::make_shared<test_memory>(*dec_src_desc, eng);
225 weights = std::make_shared<test_memory>(*dec_weights_desc, eng);
226 bias = std::make_shared<test_memory>(*dec_bias_desc, eng);
227 dst = std::make_shared<test_memory>(*dec_dst_desc, eng);
228
229 strides = {dd.strh, dd.strw};
230 padL = {dd.padh, dd.padw};
231 padR = {right_padding(dd.oh, dd.ih, dd.kh, dd.padh, dd.strh, dd.dilh),
232 right_padding(dd.ow, dd.iw, dd.kw, dd.padw, dd.strw, dd.dilw)};
233 SKIP_IF_CUDA(p.sizes.padh < padR[0] || p.sizes.padw < padR[1],
234 "Padding not supported");
235 Forward();
236 BackwardData();
237 BackwardWeights();
238 }
239
240 void Forward() {
241 auto aprop_kind = prop_kind::forward;
242 deconvolution_test_params_t p = ::testing::TestWithParam<
243 deconvolution_test_params_t>::GetParam();
244 auto conv_src = test_memory(*con_src_desc, eng);
245 auto conv_dst = src;
246 test_convolution_sizes_t dd = p.sizes;
247
248 fill_data<data_t>(src->get_size() / sizeof(data_t), src->get());
249
250 fill_data<data_t>(weights->get_size() / sizeof(data_t), weights->get());
251 if (with_bias) {
252 fill_data<data_t>(bias->get_size() / sizeof(data_t), bias->get());
253 }
254
255 auto weights_tr = test::make_memory(*con_weights_desc, eng);
256 transpose_wei<data_t>(dd, weights->get(), weights_tr);
257 auto deconv_primitive_desc = with_bias
258 ? deconvolution_forward::primitive_desc(eng, aprop_kind,
259 algorithm::deconvolution_direct, *dec_src_desc,
260 *dec_weights_desc, *dec_bias_desc, *dec_dst_desc,
261 strides, padL, padR)
262 : deconvolution_forward::primitive_desc(eng, aprop_kind,
263 algorithm::deconvolution_direct, *dec_src_desc,
264 *dec_weights_desc, *dec_dst_desc, strides, padL, padR);
265
266 deconv_primitive_desc = deconvolution_forward::primitive_desc(
267 deconv_primitive_desc.get()); // test construction from a C pd
268
269 ASSERT_TRUE(
270 deconv_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
271 == deconv_primitive_desc.src_desc());
272 ASSERT_TRUE(
273 deconv_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_DST)
274 == deconv_primitive_desc.dst_desc());
275 ASSERT_TRUE(deconv_primitive_desc.query_md(
276 query::exec_arg_md, DNNL_ARG_WEIGHTS)
277 == deconv_primitive_desc.weights_desc());
278 ASSERT_TRUE(deconv_primitive_desc.query_md(
279 query::exec_arg_md, DNNL_ARG_BIAS)
280 == deconv_primitive_desc.bias_desc());
281
282 ASSERT_EQ(deconv_primitive_desc.get_algorithm(),
283 algorithm::deconvolution_direct);
284 ASSERT_EQ(deconv_primitive_desc.get_prop_kind(), aprop_kind);
285 ASSERT_EQ(deconv_primitive_desc.get_strides(), strides);
286 ASSERT_EQ(deconv_primitive_desc.get_padding_l(), padL);
287 ASSERT_EQ(deconv_primitive_desc.get_padding_r(), padR);
288
289 EXPECT_ANY_THROW(deconvolution_forward(deconv_primitive_desc, {}));
290 deconvolution_forward(deconv_primitive_desc)
291 .execute(strm,
292 {{DNNL_ARG_SRC, src->get()},
293 {DNNL_ARG_WEIGHTS, weights->get()},
294 {DNNL_ARG_BIAS, bias->get()},
295 {DNNL_ARG_DST, dst->get()}});
296 strm.wait();
297
298 auto conv_primitive_desc = convolution_forward::primitive_desc(eng,
299 prop_kind::forward_training, algorithm::convolution_direct,
300 *con_src_desc, *con_weights_desc, *con_dst_desc, strides, padL,
301 padR);
302
303 auto conv_bwd_data_primitive_desc
304 = convolution_backward_data::primitive_desc(eng,
305 algorithm::convolution_direct, *con_src_desc,
306 *con_weights_desc, *con_dst_desc, strides, padL, padR,
307 conv_primitive_desc);
308
309 convolution_backward_data(conv_bwd_data_primitive_desc)
310 .execute(strm,
311 {{DNNL_ARG_DIFF_DST, conv_dst->get()},
312 {DNNL_ARG_WEIGHTS, weights_tr},
313 {DNNL_ARG_DIFF_SRC, conv_src.get()}});
314 strm.wait();
315
316 if (with_bias)
317 compute_bias_fwd<data_t>(dd, conv_src.get(), bias->get());
318 compare_data<data_t>(conv_src.get(), dst->get());
319 }
320
321 void BackwardData() {
322 auto p = ::testing::TestWithParam<
323 deconvolution_test_params_t>::GetParam();
324 auto conv_src = dst;
325 auto conv_dst = test_memory(*con_dst_desc, eng);
326 test_convolution_sizes_t dd = p.sizes;
327
328 fill_data<data_t>(weights->get_size() / sizeof(data_t), weights->get());
329
330 fill_data<data_t>(dst->get_size() / sizeof(data_t), dst->get());
331
332 auto weights_tr = test::make_memory(*con_weights_desc, eng);
333 transpose_wei<data_t>(dd, weights->get(), weights_tr);
334
335 auto deconv_primitive_desc = deconvolution_forward::primitive_desc(eng,
336 prop_kind::forward_training, algorithm::deconvolution_direct,
337 *dec_src_desc, *dec_weights_desc, *dec_dst_desc, strides, padL,
338 padR);
339
340 auto deconv_bwd_data_primitive_desc
341 = deconvolution_backward_data::primitive_desc(eng,
342 algorithm::deconvolution_direct, *dec_src_desc,
343 *dec_weights_desc, *dec_dst_desc, strides, padL, padR,
344 deconv_primitive_desc);
345 deconv_bwd_data_primitive_desc
346 = deconvolution_backward_data::primitive_desc(
347 deconv_bwd_data_primitive_desc
348 .get()); // test construction from a C pd
349
350 ASSERT_TRUE(deconv_bwd_data_primitive_desc.query_md(
351 query::exec_arg_md, DNNL_ARG_DIFF_SRC)
352 == deconv_bwd_data_primitive_desc.diff_src_desc());
353 ASSERT_TRUE(deconv_bwd_data_primitive_desc.query_md(
354 query::exec_arg_md, DNNL_ARG_DIFF_DST)
355 == deconv_bwd_data_primitive_desc.diff_dst_desc());
356 ASSERT_TRUE(deconv_bwd_data_primitive_desc.query_md(
357 query::exec_arg_md, DNNL_ARG_WEIGHTS)
358 == deconv_bwd_data_primitive_desc.weights_desc());
359
360 ASSERT_EQ(deconv_bwd_data_primitive_desc.get_algorithm(),
361 algorithm::deconvolution_direct);
362 ASSERT_EQ(deconv_bwd_data_primitive_desc.get_prop_kind(),
363 prop_kind::backward_data);
364 ASSERT_EQ(deconv_bwd_data_primitive_desc.get_strides(), strides);
365 ASSERT_EQ(deconv_bwd_data_primitive_desc.get_padding_l(), padL);
366 ASSERT_EQ(deconv_bwd_data_primitive_desc.get_padding_r(), padR);
367
368 deconvolution_backward_data(deconv_bwd_data_primitive_desc)
369 .execute(strm,
370 {{DNNL_ARG_DIFF_DST, dst->get()},
371 {DNNL_ARG_WEIGHTS, weights->get()},
372 {DNNL_ARG_DIFF_SRC, src->get()}});
373 strm.wait();
374
375 auto conv_primitive_desc = convolution_forward::primitive_desc(eng,
376 prop_kind::forward_training, algorithm::convolution_direct,
377 *con_src_desc, *con_weights_desc, *con_dst_desc, strides, padL,
378 padR);
379
380 convolution_forward(conv_primitive_desc)
381 .execute(strm,
382 {{DNNL_ARG_SRC, conv_src->get()},
383 {DNNL_ARG_WEIGHTS, weights_tr},
384 {DNNL_ARG_DST, conv_dst.get()}});
385 strm.wait();
386
387 compare_data<data_t>(conv_dst.get(), src->get());
388 }
389
390 void BackwardWeights() {
391 auto p = ::testing::TestWithParam<
392 deconvolution_test_params_t>::GetParam();
393 auto conv_src = dst;
394 auto conv_dst = src;
395 auto conv_weights = test::make_memory(*con_weights_desc, eng);
396 test_convolution_sizes_t dd = p.sizes;
397
398 fill_data<data_t>(src->get_size() / sizeof(data_t), src->get());
399
400 fill_data<data_t>(dst->get_size() / sizeof(data_t), dst->get());
401
402 auto deconv_primitive_desc = deconvolution_forward::primitive_desc(eng,
403 prop_kind::forward_training, algorithm::deconvolution_direct,
404 *dec_src_desc, *dec_weights_desc, *dec_bias_desc, *dec_dst_desc,
405 {dd.strh, dd.strw}, {dd.padh, dd.padw}, padR);
406
407 auto deconv_bwd_weights_primitive_desc
408 = deconvolution_backward_weights::primitive_desc(eng,
409 algorithm::deconvolution_direct, *dec_src_desc,
410 *dec_weights_desc, *dec_bias_desc, *dec_dst_desc,
411 strides, padL, padR, deconv_primitive_desc);
412
413 ASSERT_TRUE(deconv_bwd_weights_primitive_desc.query_md(
414 query::exec_arg_md, DNNL_ARG_SRC)
415 == deconv_bwd_weights_primitive_desc.src_desc());
416 ASSERT_TRUE(deconv_bwd_weights_primitive_desc.query_md(
417 query::exec_arg_md, DNNL_ARG_DIFF_DST)
418 == deconv_bwd_weights_primitive_desc.diff_dst_desc());
419 ASSERT_TRUE(deconv_bwd_weights_primitive_desc.query_md(
420 query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS)
421 == deconv_bwd_weights_primitive_desc.diff_weights_desc());
422 ASSERT_TRUE(deconv_bwd_weights_primitive_desc.query_md(
423 query::exec_arg_md, DNNL_ARG_DIFF_BIAS)
424 == deconv_bwd_weights_primitive_desc.diff_bias_desc());
425
426 ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_algorithm(),
427 algorithm::deconvolution_direct);
428 ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_prop_kind(),
429 prop_kind::backward_weights);
430 ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_strides(), strides);
431 ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_padding_l(), padL);
432 ASSERT_EQ(deconv_bwd_weights_primitive_desc.get_padding_r(), padR);
433
434 deconvolution_backward_weights(deconv_bwd_weights_primitive_desc)
435 .execute(strm,
436 {{DNNL_ARG_DIFF_DST, dst->get()},
437 {DNNL_ARG_SRC, src->get()},
438 {DNNL_ARG_DIFF_WEIGHTS, weights->get()},
439 {DNNL_ARG_DIFF_BIAS, bias->get()}});
440 strm.wait();
441
442 auto conv_primitive_desc = convolution_forward::primitive_desc(eng,
443 prop_kind::forward_training, algorithm::convolution_direct,
444 *con_src_desc, *con_weights_desc, *con_dst_desc, strides, padL,
445 padR);
446
447 deconv_bwd_weights_primitive_desc
448 = deconvolution_backward_weights::primitive_desc(
449 deconv_bwd_weights_primitive_desc
450 .get()); // test construction from a C pd
451
452 auto conv_bwd_weights_primitive_desc
453 = convolution_backward_weights::primitive_desc(eng,
454 algorithm::convolution_direct, *con_src_desc,
455 *con_weights_desc, *con_dst_desc, strides, padL, padR,
456 conv_primitive_desc);
457
458 convolution_backward_weights(conv_bwd_weights_primitive_desc)
459 .execute(strm,
460 {{DNNL_ARG_DIFF_DST, conv_dst->get()},
461 {DNNL_ARG_SRC, conv_src->get()},
462 {DNNL_ARG_DIFF_WEIGHTS, conv_weights}});
463 strm.wait();
464
465 auto weights_tr = test::make_memory(*con_weights_desc, eng);
466 transpose_wei<data_t>(dd, weights->get(), weights_tr);
467
468 compare_data<data_t>(weights_tr, conv_weights);
469
470 if (with_bias) {
471 auto ref_bias = test::make_memory(*dec_bias_desc, eng);
472 compute_bias_bwd<data_t>(dd, dst->get(), ref_bias);
473 compare_data<data_t>(ref_bias, bias->get());
474 }
475 }
476};
477
// Only the float instantiation of the fixture is registered in this file.
using deconvolution_test_float = deconvolution_test_t<float>;

// The test body is intentionally empty: all checks run from the fixture's
// SetUp(), which calls Test().
TEST_P(deconvolution_test_float, TestDeconvolution) {}
481
// Brace-initializer for the `formats` member, built from
// memory::format_tag enumerator names passed in {src, weights, bias, dst}
// order (assumed to match test_convolution_formats_t's field order; confirm).
#define EXPAND_FORMATS(src, weights, bias, dst) \
    { \
        dnnl::memory::format_tag::src, dnnl::memory::format_tag::weights, \
                dnnl::memory::format_tag::bias, dnnl::memory::format_tag::dst \
    }

// Every case in this file uses the direct deconvolution algorithm.
#define ALGORITHM dnnl::algorithm::deconvolution_direct

// Builds a deconvolution_test_params_t: algorithm, formats, value-initialized
// attributes, and sizes from the trailing arguments. expect_to_fail and
// expected_status are omitted and therefore value-initialized.
#define PARAMS(src, weights, bias, dst, ...) \
    deconvolution_test_params_t { \
        ALGORITHM, EXPAND_FORMATS(src, weights, bias, dst), {}, { \
            __VA_ARGS__ \
        } \
    }

// Instantiate the float fixture with the given params through
// CPU_INSTANTIATE_TEST_SUITE_P / GPU_INSTANTIATE_TEST_SUITE_P, which are
// defined in the common test headers (presumably engine-gated; confirm).
#define CPU_INST_TEST_CASE(str, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            str, deconvolution_test_float, ::testing::Values(__VA_ARGS__))
#define GPU_INST_TEST_CASE(str, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P( \
            str, deconvolution_test_float, ::testing::Values(__VA_ARGS__))

// Format-tag shorthands used by the instantiations below.
#define FMT_BIAS x
#define FMT_DATA_BLOCKED nChw8c
#define FMT_WEIGHTS_BLOCKED Ohwi8o
#define FMT_DATA_BLOCKED_GPU NChw16n16c
#define FMT_WEIGHTS_BLOCKED_GPU IOhw16i16o
509
// The numeric PARAMS arguments fill test_convolution_sizes_t, presumably in
// the order mb, ng, ic, ih, iw, oc, oh, ow, kh, kw, padh, padw, strh, strw
// (confirm against the struct definition in the common test headers).
// All cases use a bias (`x` tag), so the bias path is always exercised.

// CPU: plain and channels-last layouts, including grouped weight tags.
CPU_INST_TEST_CASE(SimpleSmall_NCHW,
        PARAMS(nchw, oihw, x, nchw, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1),
        PARAMS(nchw, oihw, x, nchw, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, oihw, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, hwio, x, nhwc, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1),
        PARAMS(nhwc, hwio, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, goihw, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, hwigo, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1)

);

// CPU: blocked data / weights layouts.
CPU_INST_TEST_CASE(SimpleSmall_Blocked,
        PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS,
                FMT_DATA_BLOCKED, 2, 1, 32, 12, 12, 32, 13, 13, 3, 3, 0, 0, 1,
                1),
        PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS,
                FMT_DATA_BLOCKED, 2, 1, 32, 4, 4, 32, 3, 3, 3, 3, 1, 1, 1, 1),
        PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS,
                FMT_DATA_BLOCKED, 2, 1, 32, 4, 4, 32, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS,
                FMT_DATA_BLOCKED, 2, 1, 32, 2, 2, 32, 3, 3, 3, 3, 0, 0, 1, 1),
        PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS,
                FMT_DATA_BLOCKED, 2, 1, 32, 2, 2, 32, 2, 2, 3, 3, 1, 1, 1, 1),
        PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS,
                FMT_DATA_BLOCKED, 2, 1, 48, 13, 13, 32, 13, 13, 3, 3, 1, 1, 1,
                1),
        PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS,
                FMT_DATA_BLOCKED, 2, 1, 48, 11, 11, 32, 13, 13, 3, 3, 0, 0, 1,
                1));

// GPU: blocked data / weights layouts (first size argument is 32 here).
GPU_INST_TEST_CASE(SimpleSmall_Blocked,
        PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS,
                FMT_DATA_BLOCKED_GPU, 32, 1, 32, 12, 12, 32, 10, 10, 3, 3, 0, 0,
                1, 1),
        PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS,
                FMT_DATA_BLOCKED_GPU, 32, 1, 32, 4, 4, 32, 3, 3, 3, 3, 1, 1, 1,
                1),
        PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS,
                FMT_DATA_BLOCKED_GPU, 32, 1, 32, 4, 4, 32, 4, 4, 3, 3, 0, 0, 1,
                1),
        PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS,
                FMT_DATA_BLOCKED_GPU, 32, 1, 32, 2, 2, 32, 3, 3, 3, 3, 0, 0, 1,
                1),
        PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS,
                FMT_DATA_BLOCKED_GPU, 32, 1, 32, 2, 2, 32, 2, 2, 3, 3, 1, 1, 1,
                1),
        PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS,
                FMT_DATA_BLOCKED_GPU, 32, 1, 48, 13, 13, 32, 13, 13, 3, 3, 1, 1,
                1, 1),
        PARAMS(FMT_DATA_BLOCKED_GPU, FMT_WEIGHTS_BLOCKED_GPU, FMT_BIAS,
                FMT_DATA_BLOCKED_GPU, 32, 1, 48, 11, 11, 32, 13, 13, 3, 3, 0, 0,
                1, 1));

// GPU: same plain / channels-last cases as the CPU SimpleSmall_NCHW suite.
GPU_INST_TEST_CASE(SimpleSmall_NCHW,
        PARAMS(nchw, oihw, x, nchw, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1),
        PARAMS(nchw, oihw, x, nchw, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, oihw, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, hwio, x, nhwc, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1),
        PARAMS(nhwc, hwio, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, goihw, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, hwigo, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1));

// GPU: channels-last and mixed src/dst format tags. NOTE(review): the
// fixture builds dst descriptors from the src format tag, so the dst tag
// here only affects the CUDA skip check.
GPU_INST_TEST_CASE(SimpleSmall_NHWC,
        PARAMS(nchw, oihw, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nhwc, ohwi, x, nhwc, 2, 1, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1),
        PARAMS(nhwc, ohwi, x, nhwc, 2, 1, 6, 2, 2, 4, 4, 4, 3, 3, 0, 0, 1, 1),
        PARAMS(nchw, goihw, x, nchw, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1),
        PARAMS(nchw, goihw, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1),
        PARAMS(nhwc, gohwi, x, nhwc, 2, 2, 6, 4, 4, 4, 4, 4, 3, 3, 1, 1, 1, 1));
579} // namespace dnnl
580