1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
// Parameters of one concat test case. The case tables build this struct with
// aggregate initialization, so the field order below is part of its interface.
// Positive cases omit the two trailing fields, which are then
// value-initialized (false / 0).
struct concat_test_params_t {
    size_t concat_dimension; // axis along which the sources are concatenated
    std::vector<memory::format_tag> srcs_format; // one layout per source
    memory::format_tag dst_format; // destination layout
    std::vector<memory::dims> srcs_cds; // one dims vector per source
    memory::dims dst_cds; // destination dims
    bool expect_to_fail; // when true, the test body must fail (see SetUp)
    dnnl_status_t expected_status; // status expected when expect_to_fail
};
33 | |
34 | template <typename data_t> |
35 | class concat_test_t : public ::testing::TestWithParam<concat_test_params_t> { |
36 | void check_data(const std::vector<memory> &srcs, const memory &dst, |
37 | int concat_dim) { |
38 | auto dst_data = map_memory<const data_t>(dst); |
39 | const auto &dst_d = dst.get_desc(); |
40 | const auto dst_dims = dst_d.get_dims(); |
41 | const auto dst_pdims = dst_d.get_padded_dims(); |
42 | const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get()); |
43 | |
44 | memory::dim acc_concat_dim = 0; |
45 | const auto ndims = dst_d.get_ndims(); |
46 | |
47 | for (size_t num = 0; num < srcs.size(); num++) { |
48 | const auto src_data = map_memory<data_t>(srcs[num]); |
49 | const auto &src_d = srcs[num].get_desc(); |
50 | const auto src_dims = src_d.get_dims(); |
51 | const auto src_pdims = src_d.get_padded_dims(); |
52 | const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get()); |
53 | |
54 | auto N = src_dims[0]; |
55 | auto C = src_dims[1]; |
56 | auto C_PADDED = src_pdims[1]; |
57 | auto D = (ndims == 5) ? src_dims[2] : 1; |
58 | auto H = src_dims[ndims - 2]; |
59 | auto W = src_dims[ndims - 1]; |
60 | |
61 | auto DST_C_PADDED = dst_pdims[1]; |
62 | auto DST_D = (ndims == 5) ? dst_dims[2] : 1; |
63 | auto DST_H = dst_dims[ndims - 2]; |
64 | auto DST_W = dst_dims[ndims - 1]; |
65 | |
66 | for_(memory::dim n = 0; n < N; n++) |
67 | for_(memory::dim c = 0; c < C; c++) |
68 | for_(memory::dim d = 0; d < D; d++) |
69 | for_(memory::dim h = 0; h < H; h++) |
70 | for (memory::dim w = 0; w < W; w++) { |
71 | auto src_idx = w + W * h + H * W * d + D * H * W * c |
72 | + C_PADDED * D * H * W * n; |
73 | |
74 | auto adj_dst_dim = [&](int dim, memory::dim dim_sz) { |
75 | if (concat_dim == dim) return dim_sz + acc_concat_dim; |
76 | return dim_sz; |
77 | }; |
78 | auto dst_idx = adj_dst_dim(ndims - 1, w) |
79 | + DST_W * adj_dst_dim(ndims - 2, h) |
80 | + DST_D * DST_H * DST_W * adj_dst_dim(1, c) |
81 | + DST_C_PADDED * DST_D * DST_H * DST_W |
82 | * adj_dst_dim(0, n); |
83 | if (ndims == 5) dst_idx += DST_H * DST_W * adj_dst_dim(2, d); |
84 | ASSERT_NEAR(src_data[src_mdw.off_l(src_idx, true)], |
85 | dst_data[dst_mdw.off_l(dst_idx, true)], 1e-7); |
86 | } |
87 | |
88 | acc_concat_dim += src_dims[concat_dim]; |
89 | } |
90 | } |
91 | |
92 | protected: |
93 | bool cuda_supported_format_tag(memory::format_tag tag) { |
94 | return impl::utils::one_of(tag, dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd, |
95 | dnnl_abcde, dnnl_abcdef, dnnl_abdec, dnnl_acb, dnnl_acbde, |
96 | dnnl_acbdef, dnnl_acdb, dnnl_acdeb, dnnl_ba, dnnl_bac, |
97 | dnnl_bacd, dnnl_bca, dnnl_bcda, dnnl_bcdea, dnnl_cba, dnnl_cdba, |
98 | dnnl_cdeba, dnnl_decab, dnnl_defcab, dnnl_aBc4b, dnnl_aBcd4b, |
99 | dnnl_aBcde4b); |
100 | } |
101 | |
102 | void SetUp() override { |
103 | auto data_type = data_traits<data_t>::data_type; |
104 | SKIP_IF(unsupported_data_type(data_type), |
105 | "Engine does not support this data type." ); |
106 | concat_test_params_t p |
107 | = ::testing::TestWithParam<decltype(p)>::GetParam(); |
108 | for (size_t i = 0; i < p.srcs_cds.size(); i++) { |
109 | SKIP_IF_CUDA(!cuda_supported_format_tag(p.srcs_format[i]), |
110 | "Unsupported format tag" ); |
111 | } |
112 | |
113 | SKIP_IF_CUDA(!cuda_supported_format_tag(p.dst_format), |
114 | "Unsupported format tag" ); |
115 | catch_expected_failures( |
116 | [=]() { Test(); }, p.expect_to_fail, p.expected_status, false); |
117 | } |
118 | |
119 | virtual void Test() { |
120 | concat_test_params_t p |
121 | = ::testing::TestWithParam<concat_test_params_t>::GetParam(); |
122 | |
123 | int src_dim_sum = 0; |
124 | for (size_t i = 0; i < p.srcs_cds.size(); i++) { |
125 | for (size_t dim = 0; dim < p.dst_cds.size(); dim++) { |
126 | if (dim == p.concat_dimension) |
127 | src_dim_sum += p.srcs_cds[i][dim]; |
128 | else if (p.expect_to_fail == false) { |
129 | ASSERT_TRUE(p.srcs_cds[i][dim] == p.dst_cds[dim]); |
130 | } |
131 | } |
132 | } |
133 | |
134 | if (p.expect_to_fail == false) { |
135 | ASSERT_TRUE(src_dim_sum == p.dst_cds[p.concat_dimension]); |
136 | } |
137 | |
138 | auto eng = get_test_engine(); |
139 | auto strm = make_stream(eng); |
140 | memory::data_type data_type = data_traits<data_t>::data_type; |
141 | |
142 | std::vector<memory::desc> srcs_md; |
143 | std::vector<memory> srcs; |
144 | for (size_t i = 0; i < p.srcs_cds.size(); i++) { |
145 | auto md = memory::desc(p.srcs_cds[i], data_type, p.srcs_format[i]); |
146 | srcs_md.push_back(md); |
147 | } |
148 | |
149 | auto dst_desc = memory::desc(p.dst_cds, data_type, p.dst_format); |
150 | auto concat_pd = concat::primitive_desc( |
151 | eng, dst_desc, static_cast<int>(p.concat_dimension), srcs_md); |
152 | // test construction from a C pd |
153 | concat_pd = concat::primitive_desc(concat_pd.get()); |
154 | |
155 | ASSERT_TRUE(concat_pd.query_md(query::exec_arg_md, DNNL_ARG_DST) |
156 | == concat_pd.dst_desc()); |
157 | |
158 | for (int i = 0; i < (int)srcs.size(); i++) { |
159 | if (p.srcs_format[i] != memory::format_tag::any) { |
160 | ASSERT_TRUE(srcs_md[i] == concat_pd.src_desc(i)); |
161 | } |
162 | } |
163 | |
164 | auto dst = test::make_memory(concat_pd.dst_desc(), eng); |
165 | fill_data<data_t>(dst.get_desc().get_size() / sizeof(data_t), dst); |
166 | check_zero_tail<data_t>(1, dst); |
167 | |
168 | ASSERT_EQ(concat_pd.dst_desc().get_ndims(), dst_desc.get_ndims()); |
169 | |
170 | for (size_t i = 0; i < p.srcs_cds.size(); i++) { |
171 | auto md = concat_pd.src_desc((int)i); |
172 | auto src_memory = test::make_memory(md, eng); |
173 | const size_t sz = src_memory.get_desc().get_size() / sizeof(data_t); |
174 | fill_data<data_t>(sz, src_memory); |
175 | check_zero_tail<data_t>(1, src_memory); |
176 | srcs.push_back(src_memory); |
177 | } |
178 | |
179 | for (int i = 0; i < (int)srcs.size(); i++) |
180 | ASSERT_TRUE(concat_pd.query_md( |
181 | query::exec_arg_md, DNNL_ARG_MULTIPLE_SRC + i) |
182 | == concat_pd.src_desc(i)); |
183 | |
184 | EXPECT_ANY_THROW(concat(concat_pd, {})); |
185 | concat c(concat_pd); |
186 | std::unordered_map<int, memory> args = {{DNNL_ARG_DST, dst}}; |
187 | for (int i = 0; i < (int)srcs.size(); i++) { |
188 | args.insert({DNNL_ARG_MULTIPLE_SRC + i, srcs[i]}); |
189 | } |
190 | c.execute(strm, args); |
191 | strm.wait(); |
192 | |
193 | check_data(srcs, dst, static_cast<int>(p.concat_dimension)); |
194 | check_zero_tail<data_t>(0, dst); |
195 | } |
196 | }; |
197 | |
198 | using concat_test_float = concat_test_t<float>; |
199 | using concat_test_float16 = concat_test_t<float16_t>; |
200 | using concat_test_s8 = concat_test_t<int8_t>; |
201 | using concat_test_bf16 = concat_test_t<bfloat16_t>; |
202 | |
// The work is done in concat_test_t::SetUp(), so the test bodies are empty.
// NOTE(review): the f16 test is named TestConcat while the others use
// TestsConcat; kept as-is because renaming changes registered test names.
TEST_P(concat_test_float, TestsConcat) {}
TEST_P(concat_test_s8, TestsConcat) {}
TEST_P(concat_test_bf16, TestsConcat) {}
TEST_P(concat_test_float16, TestConcat) {}
207 | |
208 | using fmt = memory::format_tag; |
209 | |
// Cases where tensors have a zero-sized dimension.
static auto case_ZeroDim = []() {
    return ::testing::Values(
            // One source is empty along the concat axis (C).
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{4, 0, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
                    {{4, 0, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
                    {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}},
            // N or H (a non-concat axis) is zero in every tensor.
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}},
            concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
                    {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
            concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
                    {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
            concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
                    {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}},
            concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
                    {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}});
};
// f32 uses the plain gtest macro; the bf16/f16 variants go through the
// project's CPU_ macro (presumably CPU-only — see dnnl_test_common.hpp).
INSTANTIATE_TEST_SUITE_P(TestConcat_ZeroDim, concat_test_float, case_ZeroDim());
CPU_INSTANTIATE_TEST_SUITE_P(
        TestConcat_ZeroDim_bf16, concat_test_bf16, case_ZeroDim());
CPU_INSTANTIATE_TEST_SUITE_P(
        TestConcat_ZeroDim_f16, concat_test_float16, case_ZeroDim());
238 | |
// Invalid configurations: every case is expected to fail with
// dnnl_invalid_arguments.
static auto cases_EF = []() {
    return ::testing::Values(
            // Concat-axis sizes 2 + 5 do not add up to the destination's 5.
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{4, 2, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}, true,
                    dnnl_invalid_arguments},
            // Concat along H: 5 + 5 != 5, and C differs between sources.
            concat_test_params_t {2, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{4, 2, 5, 5}, {4, 3, 5, 5}}, {4, 5, 5, 5}, true,
                    dnnl_invalid_arguments},
            // Concat axis 5 is out of range for 4D tensors.
            concat_test_params_t {5, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}, true,
                    dnnl_invalid_arguments},
            // Negative dimension in a source.
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
                    {{4, -1, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}, true,
                    dnnl_invalid_arguments},
            // Concat-axis sizes 4 + 4 != 4.
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
                    {{4, 4, 5, 5}, {4, 4, 5, 5}}, {4, 4, 5, 5}, true,
                    dnnl_invalid_arguments},
            // Empty (N = 0) tensors, but 4 + 4 != 6 along the concat axis.
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{0, 4, 5, 5}, {0, 4, 5, 5}}, {0, 6, 5, 5}, true,
                    dnnl_invalid_arguments},
            // H differs between sources (2 vs 1) on a non-concat axis.
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{2, 4, 2, 5}, {2, 2, 1, 5}}, {2, 6, 2, 5}, true,
                    dnnl_invalid_arguments},
            // Concat-axis sizes 4 + 2 != 7.
            concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
                    {{1, 4, 5, 5}, {1, 2, 5, 5}}, {1, 7, 5, 5}, true,
                    dnnl_invalid_arguments},
            // Destination H (6) differs from the sources' H (5).
            concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
                    {{1, 4, 5, 5}, {1, 2, 5, 5}}, {1, 6, 6, 5}, true,
                    dnnl_invalid_arguments},
            // A source declared with format_tag::any.
            concat_test_params_t {1, {fmt::any, fmt::nchw}, fmt::nchw,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}, true,
                    dnnl_invalid_arguments},
            concat_test_params_t {1, {fmt::nchw, fmt::any}, fmt::nchw,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}, true,
                    dnnl_invalid_arguments});
};
// Expected-failure cases for f32 (plain macro) and bf16/f16 (CPU_ macro).
INSTANTIATE_TEST_SUITE_P(TestConcat_EF, concat_test_float, cases_EF());
CPU_INSTANTIATE_TEST_SUITE_P(TestConcat_EF_bf16, concat_test_bf16, cases_EF());
CPU_INSTANTIATE_TEST_SUITE_P(
        TestConcat_EF_f16, concat_test_float16, cases_EF());
279 | |
// Cases where at least one blocked tensor has a channel count that is not a
// multiple of its channel block (8 or 16), so its layout carries padding.
static auto cases_padded = []() {
    return ::testing::Values(
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
                    {{1, 12, 28, 28}, {1, 12, 28, 28}}, {1, 24, 28, 28}},
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
                    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,
                    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw,
                    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw8c}, fmt::nchw,
                    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
                    {{4, 4, 5, 5}, {4, 6, 5, 5}}, {4, 10, 5, 5}},
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,
                    {{4, 4, 5, 5}, {4, 6, 5, 5}}, {4, 10, 5, 5}},
            concat_test_params_t {1, {fmt::nchw, fmt::nChw16c}, fmt::nChw16c,
                    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
            concat_test_params_t {1, {fmt::nchw, fmt::nChw16c}, fmt::nchw,
                    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
            // right border: the first source ends exactly on a channel-block
            // boundary, the second one fills only part of a block
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
                    {{4, 16, 5, 5}, {4, 3, 5, 5}}, {4, 19, 5, 5}},
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw8c,
                    {{4, 16, 5, 5}, {4, 3, 5, 5}}, {4, 19, 5, 5}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
                    {{4, 8, 5, 5}, {4, 3, 5, 5}}, {4, 11, 5, 5}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nChw16c,
                    {{4, 8, 5, 5}, {4, 3, 5, 5}}, {4, 11, 5, 5}},
            // not over channels: concatenation along H of sources whose
            // 25 channels are padded in the blocked layouts
            concat_test_params_t {2, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,
                    {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}},
            concat_test_params_t {2, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw,
                    {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}},
            concat_test_params_t {2, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,
                    {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}});
};
// Padded-channel cases for f32 (plain macro) and bf16/f16 (CPU_ macro).
INSTANTIATE_TEST_SUITE_P(TestConcat_padded, concat_test_float, cases_padded());
CPU_INSTANTIATE_TEST_SUITE_P(
        TestConcat_padded_bf16, concat_test_bf16, cases_padded());
CPU_INSTANTIATE_TEST_SUITE_P(
        TestConcat_padded_f16, concat_test_float16, cases_padded());
324 | |
// 5D tensors in plain (ncdhw), 8-channel blocked (nCdhw8c) and mixed layouts,
// concatenated along each of the five axes.
static auto cases_3D = []() {
    return ::testing::Values(
            concat_test_params_t {0, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {4, 8, 3, 4, 5}},
            concat_test_params_t {1, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
            concat_test_params_t {2, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 6, 4, 5}},
            concat_test_params_t {3, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 8, 5}},
            concat_test_params_t {4, {fmt::ncdhw, fmt::ncdhw}, fmt::ncdhw,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 4, 10}},
            concat_test_params_t {0, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {4, 8, 3, 4, 5}},
            concat_test_params_t {1, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
            concat_test_params_t {1, {fmt::nCdhw8c, fmt::ncdhw}, fmt::nCdhw8c,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
            concat_test_params_t {1, {fmt::ncdhw, fmt::ncdhw}, fmt::nCdhw8c,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
            concat_test_params_t {2, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 6, 4, 5}},
            concat_test_params_t {3, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 8, 5}},
            concat_test_params_t {4, {fmt::nCdhw8c, fmt::nCdhw8c}, fmt::nCdhw8c,
                    {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 4, 10}});
};
// 5D cases for f32 (plain macro) and bf16/f16 (CPU_ macro).
INSTANTIATE_TEST_SUITE_P(TestConcat3D, concat_test_float, cases_3D());
CPU_INSTANTIATE_TEST_SUITE_P(TestConcat3D_bf16, concat_test_bf16, cases_3D());
CPU_INSTANTIATE_TEST_SUITE_P(TestConcat3D_f16, concat_test_float16, cases_3D());
355 | |
// General 4D cases in plain and channel-blocked layouts.
static auto cases_concat = []() {
    return ::testing::Values(
            // Concatenation along C.
            concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
                    {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
            concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nChw8c,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
            concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},

            // Concatenation along N.
            concat_test_params_t {0, {fmt::nchw, fmt::nchw}, fmt::nchw,
                    {{2, 8, 3, 4}, {2, 8, 3, 4}}, {4, 8, 3, 4}},
            concat_test_params_t {0, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
            concat_test_params_t {0, {fmt::nchw, fmt::nchw}, fmt::nChw8c,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
            concat_test_params_t {0, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw,
                    {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},

            // Each source holds exactly 8 channels (one nChw8c block).
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c,
                    {{2, 8, 1, 1}, {2, 8, 1, 1}}, {2, 16, 1, 1}},

            // Sources in layouts with different channel-block sizes.
            concat_test_params_t {1, {fmt::nChw8c, fmt::nChw16c}, fmt::nChw8c,
                    {{2, 8, 1, 1}, {2, 16, 1, 1}}, {2, 24, 1, 1}});
};
// General cases for f32 (plain macro) and bf16/f16 (CPU_ macro).
INSTANTIATE_TEST_SUITE_P(TestConcat, concat_test_float, cases_concat());
CPU_INSTANTIATE_TEST_SUITE_P(TestConcat_bf16, concat_test_bf16, cases_concat());
CPU_INSTANTIATE_TEST_SUITE_P(
        TestConcat_f16, concat_test_float16, cases_concat());
388 | |
// int8 cases: concatenation along C in the plain nhwc and nchw layouts.
INSTANTIATE_TEST_SUITE_P(TestConcat, concat_test_s8,
        ::testing::Values(
                concat_test_params_t {1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,
                        {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}},
                concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
                        {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}}));
395 | |
// Cases registered only through GPU_INSTANTIATE_TEST_SUITE_P, including the
// NChw16n16c / NCdhw16n16c layouts.
static auto cases_concat_gpu = []() {
    return ::testing::Values(
            // A single source.
            concat_test_params_t {1, {fmt::nchw}, fmt::nchw,
                    {{1, 1, 1, 1}}, {1, 1, 1, 1}},
            concat_test_params_t {
                    1, {fmt::nchw}, fmt::nhwc, {{1, 1, 1, 1}}, {1, 1, 1, 1}},
            // Plain and 16-channel blocked layouts along C.
            concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
                    {{1, 1, 1, 1}, {1, 1, 1, 1}}, {1, 2, 1, 1}},
            concat_test_params_t {1, {fmt::nchw, fmt::nchw}, fmt::nchw,
                    {{4, 5, 5, 5}, {4, 5, 5, 5}}, {4, 10, 5, 5}},
            concat_test_params_t {1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c,
                    {{4, 16, 5, 5}, {4, 32, 5, 5}}, {4, 48, 5, 5}},
            // NChw16n16c along each axis, plus one 5D NCdhw16n16c case.
            concat_test_params_t {0, {fmt::NChw16n16c}, fmt::NChw16n16c,
                    {{16, 16, 1, 1}}, {16, 16, 1, 1}},
            concat_test_params_t {0, {fmt::NChw16n16c, fmt::NChw16n16c},
                    fmt::NChw16n16c, {{16, 16, 1, 1}, {16, 16, 1, 1}},
                    {32, 16, 1, 1}},
            concat_test_params_t {1, {fmt::NChw16n16c, fmt::NChw16n16c},
                    fmt::NChw16n16c, {{16, 16, 1, 1}, {16, 16, 1, 1}},
                    {16, 32, 1, 1}},
            concat_test_params_t {2, {fmt::NChw16n16c, fmt::NChw16n16c},
                    fmt::NChw16n16c, {{16, 16, 1, 1}, {16, 16, 1, 1}},
                    {16, 16, 2, 1}},
            concat_test_params_t {3, {fmt::NChw16n16c, fmt::NChw16n16c},
                    fmt::NChw16n16c, {{16, 16, 1, 1}, {16, 16, 1, 1}},
                    {16, 16, 1, 2}},
            concat_test_params_t {1, {fmt::NChw16n16c, fmt::NChw16n16c},
                    fmt::NChw16n16c, {{16, 16, 5, 5}, {16, 32, 5, 5}},
                    {16, 48, 5, 5}},
            concat_test_params_t {1, {fmt::NCdhw16n16c, fmt::NCdhw16n16c},
                    fmt::NCdhw16n16c, {{16, 16, 5, 5, 5}, {16, 32, 5, 5, 5}},
                    {16, 48, 5, 5, 5}},
            // Concatenation along H from blocked sources into plain nchw.
            concat_test_params_t {2, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,
                    {{4, 16, 5, 5}, {4, 16, 5, 5}}, {4, 16, 10, 5}},
            concat_test_params_t {2, {fmt::NChw16n16c, fmt::NChw16n16c},
                    fmt::nchw, {{16, 16, 5, 5}, {16, 16, 5, 5}},
                    {16, 16, 10, 5}});
};
433 | |
// GPU-specific cases for f32 and f16.
// NOTE(review): the f32 registration reuses the `TestConcat` prefix of the
// generic concat_test_float instantiation; this presumably relies on
// GPU_INSTANTIATE_TEST_SUITE_P decorating the prefix — confirm in
// dnnl_test_common.hpp.
GPU_INSTANTIATE_TEST_SUITE_P(TestConcat, concat_test_float, cases_concat_gpu());
GPU_INSTANTIATE_TEST_SUITE_P(
        TestConcat, concat_test_float16, cases_concat_gpu());
437 | |
438 | } // namespace dnnl |
439 | |