1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24struct concat_test_params_t {
25 size_t concat_dimension;
26 std::vector<memory::format_tag> srcs_format;
27 memory::format_tag dst_format;
28 std::vector<memory::dims> srcs_cds;
29 memory::dims dst_cds;
30 bool expect_to_fail;
31 dnnl_status_t expected_status;
32};
33
34template <typename data_t>
35class concat_test_t : public ::testing::TestWithParam<concat_test_params_t> {
36 void check_data(const std::vector<memory> &srcs, const memory &dst,
37 int concat_dim) {
38 auto dst_data = map_memory<const data_t>(dst);
39 const auto &dst_d = dst.get_desc();
40 const auto dst_dims = dst_d.get_dims();
41 const auto dst_pdims = dst_d.get_padded_dims();
42 const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get());
43
44 memory::dim acc_concat_dim = 0;
45 const auto ndims = dst_d.get_ndims();
46
47 for (size_t num = 0; num < srcs.size(); num++) {
48 const auto src_data = map_memory<data_t>(srcs[num]);
49 const auto &src_d = srcs[num].get_desc();
50 const auto src_dims = src_d.get_dims();
51 const auto src_pdims = src_d.get_padded_dims();
52 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get());
53
54 auto N = src_dims[0];
55 auto C = src_dims[1];
56 auto C_PADDED = src_pdims[1];
57 auto D = (ndims == 5) ? src_dims[2] : 1;
58 auto H = src_dims[ndims - 2];
59 auto W = src_dims[ndims - 1];
60
61 auto DST_C_PADDED = dst_pdims[1];
62 auto DST_D = (ndims == 5) ? dst_dims[2] : 1;
63 auto DST_H = dst_dims[ndims - 2];
64 auto DST_W = dst_dims[ndims - 1];
65
66 for_(memory::dim n = 0; n < N; n++)
67 for_(memory::dim c = 0; c < C; c++)
68 for_(memory::dim d = 0; d < D; d++)
69 for_(memory::dim h = 0; h < H; h++)
70 for (memory::dim w = 0; w < W; w++) {
71 auto src_idx = w + W * h + H * W * d + D * H * W * c
72 + C_PADDED * D * H * W * n;
73
74 auto adj_dst_dim = [&](int dim, memory::dim dim_sz) {
75 if (concat_dim == dim) return dim_sz + acc_concat_dim;
76 return dim_sz;
77 };
78 auto dst_idx = adj_dst_dim(ndims - 1, w)
79 + DST_W * adj_dst_dim(ndims - 2, h)
80 + DST_D * DST_H * DST_W * adj_dst_dim(1, c)
81 + DST_C_PADDED * DST_D * DST_H * DST_W
82 * adj_dst_dim(0, n);
83 if (ndims == 5) dst_idx += DST_H * DST_W * adj_dst_dim(2, d);
84 ASSERT_NEAR(src_data[src_mdw.off_l(src_idx, true)],
85 dst_data[dst_mdw.off_l(dst_idx, true)], 1e-7);
86 }
87
88 acc_concat_dim += src_dims[concat_dim];
89 }
90 }
91
92protected:
93 bool cuda_supported_format_tag(memory::format_tag tag) {
94 return impl::utils::one_of(tag, dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd,
95 dnnl_abcde, dnnl_abcdef, dnnl_abdec, dnnl_acb, dnnl_acbde,
96 dnnl_acbdef, dnnl_acdb, dnnl_acdeb, dnnl_ba, dnnl_bac,
97 dnnl_bacd, dnnl_bca, dnnl_bcda, dnnl_bcdea, dnnl_cba, dnnl_cdba,
98 dnnl_cdeba, dnnl_decab, dnnl_defcab, dnnl_aBc4b, dnnl_aBcd4b,
99 dnnl_aBcde4b);
100 }
101
102 void SetUp() override {
103 auto data_type = data_traits<data_t>::data_type;
104 SKIP_IF(unsupported_data_type(data_type),
105 "Engine does not support this data type.");
106 concat_test_params_t p
107 = ::testing::TestWithParam<decltype(p)>::GetParam();
108 for (size_t i = 0; i < p.srcs_cds.size(); i++) {
109 SKIP_IF_CUDA(!cuda_supported_format_tag(p.srcs_format[i]),
110 "Unsupported format tag");
111 }
112
113 SKIP_IF_CUDA(!cuda_supported_format_tag(p.dst_format),
114 "Unsupported format tag");
115 catch_expected_failures(
116 [=]() { Test(); }, p.expect_to_fail, p.expected_status, false);
117 }
118
119 virtual void Test() {
120 concat_test_params_t p
121 = ::testing::TestWithParam<concat_test_params_t>::GetParam();
122
123 int src_dim_sum = 0;
124 for (size_t i = 0; i < p.srcs_cds.size(); i++) {
125 for (size_t dim = 0; dim < p.dst_cds.size(); dim++) {
126 if (dim == p.concat_dimension)
127 src_dim_sum += p.srcs_cds[i][dim];
128 else if (p.expect_to_fail == false) {
129 ASSERT_TRUE(p.srcs_cds[i][dim] == p.dst_cds[dim]);
130 }
131 }
132 }
133
134 if (p.expect_to_fail == false) {
135 ASSERT_TRUE(src_dim_sum == p.dst_cds[p.concat_dimension]);
136 }
137
138 auto eng = get_test_engine();
139 auto strm = make_stream(eng);
140 memory::data_type data_type = data_traits<data_t>::data_type;
141
142 std::vector<memory::desc> srcs_md;
143 std::vector<memory> srcs;
144 for (size_t i = 0; i < p.srcs_cds.size(); i++) {
145 auto md = memory::desc(p.srcs_cds[i], data_type, p.srcs_format[i]);
146 srcs_md.push_back(md);
147 }
148
149 auto dst_desc = memory::desc(p.dst_cds, data_type, p.dst_format);
150 auto concat_pd = concat::primitive_desc(
151 eng, dst_desc, static_cast<int>(p.concat_dimension), srcs_md);
152 // test construction from a C pd
153 concat_pd = concat::primitive_desc(concat_pd.get());
154
155 ASSERT_TRUE(concat_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
156 == concat_pd.dst_desc());
157
158 for (int i = 0; i < (int)srcs.size(); i++) {
159 if (p.srcs_format[i] != memory::format_tag::any) {
160 ASSERT_TRUE(srcs_md[i] == concat_pd.src_desc(i));
161 }
162 }
163
164 auto dst = test::make_memory(concat_pd.dst_desc(), eng);
165 fill_data<data_t>(dst.get_desc().get_size() / sizeof(data_t), dst);
166 check_zero_tail<data_t>(1, dst);
167
168 ASSERT_EQ(concat_pd.dst_desc().get_ndims(), dst_desc.get_ndims());
169
170 for (size_t i = 0; i < p.srcs_cds.size(); i++) {
171 auto md = concat_pd.src_desc((int)i);
172 auto src_memory = test::make_memory(md, eng);
173 const size_t sz = src_memory.get_desc().get_size() / sizeof(data_t);
174 fill_data<data_t>(sz, src_memory);
175 check_zero_tail<data_t>(1, src_memory);
176 srcs.push_back(src_memory);
177 }
178
179 for (int i = 0; i < (int)srcs.size(); i++)
180 ASSERT_TRUE(concat_pd.query_md(
181 query::exec_arg_md, DNNL_ARG_MULTIPLE_SRC + i)
182 == concat_pd.src_desc(i));
183
184 EXPECT_ANY_THROW(concat(concat_pd, {}));
185 concat c(concat_pd);
186 std::unordered_map<int, memory> args = {{DNNL_ARG_DST, dst}};
187 for (int i = 0; i < (int)srcs.size(); i++) {
188 args.insert({DNNL_ARG_MULTIPLE_SRC + i, srcs[i]});
189 }
190 c.execute(strm, args);
191 strm.wait();
192
193 check_data(srcs, dst, static_cast<int>(p.concat_dimension));
194 check_zero_tail<data_t>(0, dst);
195 }
196};
197
198using concat_test_float = concat_test_t<float>;
199using concat_test_float16 = concat_test_t<float16_t>;
200using concat_test_s8 = concat_test_t<int8_t>;
201using concat_test_bf16 = concat_test_t<bfloat16_t>;
202
203TEST_P(concat_test_float, TestsConcat) {}
204TEST_P(concat_test_s8, TestsConcat) {}
205TEST_P(concat_test_bf16, TestsConcat) {}
206TEST_P(concat_test_float16, TestConcat) {}
207
208using fmt = memory::format_tag;
209
210static auto case_ZeroDim = []() {
211 return ::testing::Values(
212 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
213 {{4, 0, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}},
214 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
215 {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}},
216 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
217 {{4, 0, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}},
218 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
219 {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}},
220 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
221 {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
222 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
223 {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}},
224 concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
225 {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
226 concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
227 {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
228 concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
229 {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}},
230 concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
231 {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}});
232};
233INSTANTIATE_TEST_SUITE_P(TestConcat_ZeroDim, concat_test_float, case_ZeroDim());
234CPU_INSTANTIATE_TEST_SUITE_P(
235 TestConcat_ZeroDim_bf16, concat_test_bf16, case_ZeroDim());
236CPU_INSTANTIATE_TEST_SUITE_P(
237 TestConcat_ZeroDim_f16, concat_test_float16, case_ZeroDim());
238
239static auto cases_EF = []() {
240 return ::testing::Values(
241 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
242 {{4, 2, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}, true,
243 dnnl_invalid_arguments},
244 concat_test_params_t {2, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
245 {{4, 2, 5, 5}, {4, 3, 5, 5}}, {4, 5, 5, 5}, true,
246 dnnl_invalid_arguments},
247 concat_test_params_t {5, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
248 {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}, true,
249 dnnl_invalid_arguments},
250 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
251 {{4, -1, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}, true,
252 dnnl_invalid_arguments},
253 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
254 {{4, 4, 5, 5}, {4, 4, 5, 5}}, {4, 4, 5, 5}, true,
255 dnnl_invalid_arguments},
256 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
257 {{0, 4, 5, 5}, {0, 4, 5, 5}}, {0, 6, 5, 5}, true,
258 dnnl_invalid_arguments},
259 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
260 {{2, 4, 2, 5}, {2, 2, 1, 5}}, {2, 6, 2, 5}, true,
261 dnnl_invalid_arguments},
262 concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
263 {{1, 4, 5, 5}, {1, 2, 5, 5}}, {1, 7, 5, 5}, true,
264 dnnl_invalid_arguments},
265 concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
266 {{1, 4, 5, 5}, {1, 2, 5, 5}}, {1, 6, 6, 5}, true,
267 dnnl_invalid_arguments},
268 concat_test_params_t {1, {fmt::any, fmt::nchw}, fmt::nchw,
269 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}, true,
270 dnnl_invalid_arguments},
271 concat_test_params_t {1, {fmt::nchw, fmt::any}, fmt::nchw,
272 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}, true,
273 dnnl_invalid_arguments});
274};
275INSTANTIATE_TEST_SUITE_P(TestConcat_EF, concat_test_float, cases_EF());
276CPU_INSTANTIATE_TEST_SUITE_P(TestConcat_EF_bf16, concat_test_bf16, cases_EF());
277CPU_INSTANTIATE_TEST_SUITE_P(
278 TestConcat_EF_f16, concat_test_float16, cases_EF());
279
280static auto cases_padded = []() {
281 return ::testing::Values(
282 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
283 {{1, 12, 28, 28}, {1, 12, 28, 28}}, {1, 24, 28, 28}},
284 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
285 {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
286 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,
287 {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
288 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw,
289 {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
290 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw8c}, fmt::nchw,
291 {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
292 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
293 {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
294 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
295 {{4, 4, 5, 5}, {4, 6, 5, 5}}, {4, 10, 5, 5}},
296 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,
297 {{4, 4, 5, 5}, {4, 6, 5, 5}}, {4, 10, 5, 5}},
298 concat_test_params_t {1, {fmt::nchw, fmt::nChw16c}, fmt::nChw16c,
299 {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
300 concat_test_params_t {1, {fmt::nchw, fmt::nChw16c}, fmt::nchw,
301 {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
302 // right border
303 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
304 {{4, 16, 5, 5}, {4, 3, 5, 5}}, {4, 19, 5, 5}},
305 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw8c,
306 {{4, 16, 5, 5}, {4, 3, 5, 5}}, {4, 19, 5, 5}},
307 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
308 {{4, 8, 5, 5}, {4, 3, 5, 5}}, {4, 11, 5, 5}},
309 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nChw16c,
310 {{4, 8, 5, 5}, {4, 3, 5, 5}}, {4, 11, 5, 5}},
311 // not over channels
312 concat_test_params_t {2, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,
313 {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}},
314 concat_test_params_t {2, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw,
315 {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}},
316 concat_test_params_t {2, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
317 {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}});
318};
319INSTANTIATE_TEST_SUITE_P(TestConcat_padded, concat_test_float, cases_padded());
320CPU_INSTANTIATE_TEST_SUITE_P(
321 TestConcat_padded_bf16, concat_test_bf16, cases_padded());
322CPU_INSTANTIATE_TEST_SUITE_P(
323 TestConcat_padded_f16, concat_test_float16, cases_padded());
324
325static auto cases_3D = []() {
326 return ::testing::Values(
327 concat_test_params_t {0, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
328 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {4, 8, 3, 4, 5}},
329 concat_test_params_t {1, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
330 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
331 concat_test_params_t {2, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
332 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 6, 4, 5}},
333 concat_test_params_t {3, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
334 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 8, 5}},
335 concat_test_params_t {4, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
336 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 4, 10}},
337 concat_test_params_t {0, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
338 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {4, 8, 3, 4, 5}},
339 concat_test_params_t {1, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
340 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
341 concat_test_params_t {1, {fmt::nCdhw8c, fmt::ncdhw}, fmt::nCdhw8c,
342 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
343 concat_test_params_t {1, {fmt::ncdhw, fmt::ncdhw}, fmt::nCdhw8c,
344 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
345 concat_test_params_t {2, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
346 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 6, 4, 5}},
347 concat_test_params_t {3, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
348 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 8, 5}},
349 concat_test_params_t {4, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
350 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 4, 10}});
351};
352INSTANTIATE_TEST_SUITE_P(TestConcat3D, concat_test_float, cases_3D());
353CPU_INSTANTIATE_TEST_SUITE_P(TestConcat3D_bf16, concat_test_bf16, cases_3D());
354CPU_INSTANTIATE_TEST_SUITE_P(TestConcat3D_f16, concat_test_float16, cases_3D());
355
356static auto cases_concat = []() {
357 return ::testing::Values(
358 concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
359 {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}},
360 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
361 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
362 concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nChw8c,
363 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
364 concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
365 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
366 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw,
367 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
368
369 concat_test_params_t {0, {fmt::nchw, fmt::nchw}, fmt::nchw,
370 {{2, 8, 3, 4}, {2, 8, 3, 4}}, {4, 8, 3, 4}},
371 concat_test_params_t {0, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
372 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
373 concat_test_params_t {0, {fmt::nchw, fmt::nchw}, fmt::nChw8c,
374 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
375 concat_test_params_t {0, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw,
376 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
377
378 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
379 {{2, 8, 1, 1}, {2, 8, 1, 1}}, {2, 16, 1, 1}},
380
381 concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nChw8c,
382 {{2, 8, 1, 1}, {2, 16, 1, 1}}, {2, 24, 1, 1}});
383};
384INSTANTIATE_TEST_SUITE_P(TestConcat, concat_test_float, cases_concat());
385CPU_INSTANTIATE_TEST_SUITE_P(TestConcat_bf16, concat_test_bf16, cases_concat());
386CPU_INSTANTIATE_TEST_SUITE_P(
387 TestConcat_f16, concat_test_float16, cases_concat());
388
389INSTANTIATE_TEST_SUITE_P(TestConcat, concat_test_s8,
390 ::testing::Values(
391 concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
392 {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}},
393 concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
394 {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}}));
395
396static auto cases_concat_gpu = []() {
397 return ::testing::Values(concat_test_params_t {1, {fmt::nchw}, fmt::nchw,
398 {{1, 1, 1, 1}}, {1, 1, 1, 1}},
399 concat_test_params_t {
400 1, {fmt::nchw}, fmt::nhwc, {{1, 1, 1, 1}}, {1, 1, 1, 1}},
401 concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
402 {{1, 1, 1, 1}, {1, 1, 1, 1}}, {1, 2, 1, 1}},
403 concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
404 {{4, 5, 5, 5}, {4, 5, 5, 5}}, {4, 10, 5, 5}},
405 concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
406 {{4, 16, 5, 5}, {4, 32, 5, 5}}, {4, 48, 5, 5}},
407 concat_test_params_t {0, {fmt::NChw16n16c}, fmt::NChw16n16c,
408 {{16, 16, 1, 1}}, {16, 16, 1, 1}},
409 concat_test_params_t {0, {fmt::NChw16n16c, fmt::NChw16n16c},
410 fmt::NChw16n16c, {{16, 16, 1, 1}, {16, 16, 1, 1}},
411 {32, 16, 1, 1}},
412 concat_test_params_t {1, {fmt::NChw16n16c, fmt::NChw16n16c},
413 fmt::NChw16n16c, {{16, 16, 1, 1}, {16, 16, 1, 1}},
414 {16, 32, 1, 1}},
415 concat_test_params_t {2, {fmt::NChw16n16c, fmt::NChw16n16c},
416 fmt::NChw16n16c, {{16, 16, 1, 1}, {16, 16, 1, 1}},
417 {16, 16, 2, 1}},
418 concat_test_params_t {3, {fmt::NChw16n16c, fmt::NChw16n16c},
419 fmt::NChw16n16c, {{16, 16, 1, 1}, {16, 16, 1, 1}},
420 {16, 16, 1, 2}},
421 concat_test_params_t {1, {fmt::NChw16n16c, fmt::NChw16n16c},
422 fmt::NChw16n16c, {{16, 16, 5, 5}, {16, 32, 5, 5}},
423 {16, 48, 5, 5}},
424 concat_test_params_t {1, {fmt::NCdhw16n16c, fmt::NCdhw16n16c},
425 fmt::NCdhw16n16c, {{16, 16, 5, 5, 5}, {16, 32, 5, 5, 5}},
426 {16, 48, 5, 5, 5}},
427 concat_test_params_t {2, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,
428 {{4, 16, 5, 5}, {4, 16, 5, 5}}, {4, 16, 10, 5}},
429 concat_test_params_t {2, {fmt::NChw16n16c, fmt::NChw16n16c},
430 fmt::nchw, {{16, 16, 5, 5}, {16, 16, 5, 5}},
431 {16, 16, 10, 5}});
432};
433
434GPU_INSTANTIATE_TEST_SUITE_P(TestConcat, concat_test_float, cases_concat_gpu());
435GPU_INSTANTIATE_TEST_SUITE_P(
436 TestConcat, concat_test_float16, cases_concat_gpu());
437
438} // namespace dnnl
439