1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24struct test_inner_product_descr_t {
25 memory::dim mb;
26 memory::dim ic;
27 memory::dim oc;
28 memory::dim kd, kh, kw;
29};
30
31template <typename data_t>
32void compute_ref_inner_product_fwd(test_inner_product_descr_t ipd, memory &src,
33 memory &weights, memory &bias, memory &dst) {
34 const bool w_bias = bias.get_desc().get_ndims() != 0;
35 auto src_data = map_memory<data_t>(src);
36 auto weights_data = map_memory<data_t>(weights);
37 auto bias_data = w_bias ? map_memory<data_t>(bias) : nullptr;
38 auto dst_data = map_memory<data_t>(dst);
39
40 const memory::desc src_d = src.get_desc();
41 const memory::desc weights_d = weights.get_desc();
42 const memory::desc bias_d = bias.get_desc();
43 const memory::desc dst_d = dst.get_desc();
44 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get());
45 const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get());
46 const dnnl::impl::memory_desc_wrapper bias_mdw(bias_d.get());
47 const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get());
48
49 auto padded_ic = src_mdw.padded_dims()[1];
50
51 dnnl::impl::parallel_nd(ipd.mb, ipd.oc, [&](memory::dim n, memory::dim oc) {
52 memory::dim oidx = n * ipd.oc + oc;
53 dst_data[dst_mdw.off_l(oidx, true)]
54 = bias_data ? bias_data[bias_mdw.off_l(oc, true)] : data_t {0};
55 for (memory::dim ic = 0; ic < ipd.ic; ic++) {
56 for_(memory::dim kd = 0; kd < ipd.kd; kd++)
57 for_(memory::dim kh = 0; kh < ipd.kh; kh++)
58 for (memory::dim kw = 0; kw < ipd.kw; kw++) {
59 memory::dim iidx = n * padded_ic * ipd.kd * ipd.kh * ipd.kw
60 + ic * ipd.kd * ipd.kh * ipd.kw + kd * ipd.kh * ipd.kw
61 + kh * ipd.kw + kw;
62 memory::dim widx = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw
63 + ic * ipd.kd * ipd.kh * ipd.kw + kd * ipd.kh * ipd.kw
64 + kh * ipd.kw + kw;
65 dst_data[dst_mdw.off_l(oidx, true)]
66 += src_data[src_mdw.off_l(iidx, true)]
67 * weights_data[weights_mdw.off_l(widx, true)];
68 }
69 }
70 });
71}
72
/// One parameter set for inner_product_test_t.
///
/// Instances are brace-initialized positionally (see the
/// INSTANTIATE_TEST_SUITE_P tables), so member order is significant.
/// Trailing members omitted from an initializer are value-initialized.
struct inprod_test_params_t {
    prop_kind aprop_kind; // the fixture asserts this is prop_kind::forward
    memory::format_tag src_format;
    memory::format_tag weights_format;
    // format_tag::undef here means the primitive is created without bias.
    memory::format_tag bias_format;
    memory::format_tag dst_format;
    int ndims; // rank of src/weights when spatial dims are present (3, 4 or 5)
    test_inner_product_descr_t test_ipd;
    bool expect_to_fail; // forwarded to catch_expected_failures()
    dnnl_status_t expected_status; // forwarded to catch_expected_failures()
};
84
85template <typename data_t>
86class inner_product_test_t
87 : public ::testing::TestWithParam<inprod_test_params_t> {
88protected:
89 void SetUp() override {
90 auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
91 SKIP_IF_CUDA(!cuda_check_format_tags(p.src_format, p.weights_format,
92 p.bias_format, p.dst_format),
93 "Unsupported format tag");
94 SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions");
95 catch_expected_failures(
96 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
97 }
98
99 bool cuda_check_format_tags(memory::format_tag src_format,
100 memory::format_tag wei_format, memory::format_tag bia_format,
101 memory::format_tag dst_format) {
102 bool src_ok = src_format == memory::format_tag::ncdhw
103 || src_format == memory::format_tag::ndhwc
104 || src_format == memory::format_tag::nchw
105 || src_format == memory::format_tag::nhwc
106 || src_format == memory::format_tag::ncw
107 || src_format == memory::format_tag::nwc
108 || src_format == memory::format_tag::nc
109 || src_format == memory::format_tag::any;
110 bool wei_ok = wei_format == memory::format_tag::oidhw
111 || wei_format == memory::format_tag::odhwi
112 || wei_format == memory::format_tag::dhwio
113 || wei_format == memory::format_tag::oihw
114 || wei_format == memory::format_tag::ohwi
115 || wei_format == memory::format_tag::hwio
116 || wei_format == memory::format_tag::oiw
117 || wei_format == memory::format_tag::owi
118 || wei_format == memory::format_tag::wio
119 || wei_format == memory::format_tag::io
120 || wei_format == memory::format_tag::oi
121 || wei_format == memory::format_tag::any;
122 bool bia_ok = bia_format == memory::format_tag::undef
123 || bia_format == memory::format_tag::any
124 || bia_format == memory::format_tag::a
125 || bia_format == memory::format_tag::x;
126 bool dst_ok = dst_format == memory::format_tag::any
127 || dst_format == memory::format_tag::nc;
128
129 return src_ok && wei_ok && bia_ok && dst_ok;
130 }
131
132 void Test() {
133 auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
134 test_inner_product_descr_t ipd = p.test_ipd;
135 bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
136 if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
137 bool with_bias = p.bias_format != memory::format_tag::undef;
138
139 ASSERT_EQ(p.aprop_kind, prop_kind::forward);
140 auto eng = get_test_engine();
141 auto strm = make_stream(eng);
142 memory::data_type data_type = data_traits<data_t>::data_type;
143 ASSERT_EQ(data_type, dnnl::memory::data_type::f32);
144
145 memory::dims src_dims = {ipd.mb, ipd.ic}, wei_dims = {ipd.oc, ipd.ic};
146 if (has_spatial) {
147 if (p.ndims == 5) {
148 src_dims.push_back(ipd.kd);
149 wei_dims.push_back(ipd.kd);
150 }
151 if (p.ndims >= 4) {
152 src_dims.push_back(ipd.kh);
153 wei_dims.push_back(ipd.kh);
154 }
155 if (p.ndims >= 3) {
156 src_dims.push_back(ipd.kw);
157 wei_dims.push_back(ipd.kw);
158 }
159 }
160 auto ip_src_desc = create_md(src_dims, data_type, p.src_format);
161 auto ip_weights_desc = create_md(wei_dims, data_type, p.weights_format);
162 auto ip_bias_desc = with_bias
163 ? create_md({ipd.oc}, data_type, p.bias_format)
164 : create_md({}, data_type, p.bias_format);
165 auto ip_dst_desc = create_md({ipd.mb, ipd.oc}, data_type, p.dst_format);
166
167 auto ip_primitive_desc = with_bias
168 ? inner_product_forward::primitive_desc(eng, p.aprop_kind,
169 ip_src_desc, ip_weights_desc, ip_bias_desc, ip_dst_desc)
170 : inner_product_forward::primitive_desc(eng, p.aprop_kind,
171 ip_src_desc, ip_weights_desc, ip_dst_desc);
172
173 ip_primitive_desc = inner_product_forward::primitive_desc(
174 ip_primitive_desc.get()); // test construction from a C pd
175
176 auto ip_src = test::make_memory(ip_primitive_desc.src_desc(), eng);
177 auto ip_weights
178 = test::make_memory(ip_primitive_desc.weights_desc(), eng);
179 auto ip_bias = test::make_memory(ip_primitive_desc.bias_desc(), eng);
180 auto ip_dst = test::make_memory(ip_primitive_desc.dst_desc(), eng);
181 auto dst_ref = test::make_memory(ip_primitive_desc.dst_desc(), eng);
182
183 fill_data<data_t>(
184 ip_src.get_desc().get_size() / sizeof(data_t), ip_src);
185 fill_data<data_t>(
186 ip_weights.get_desc().get_size() / sizeof(data_t), ip_weights);
187 if (with_bias) {
188 fill_data<data_t>(
189 ip_bias.get_desc().get_size() / sizeof(data_t), ip_bias);
190 }
191 check_zero_tail<data_t>(1, ip_src);
192 check_zero_tail<data_t>(1, ip_weights);
193
194 ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
195 == ip_primitive_desc.src_desc());
196 ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_DST)
197 == ip_primitive_desc.dst_desc());
198 ASSERT_TRUE(
199 ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS)
200 == ip_primitive_desc.weights_desc());
201 ASSERT_TRUE(
202 ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_BIAS)
203 == ip_primitive_desc.bias_desc());
204
205 ASSERT_EQ(ip_primitive_desc.get_prop_kind(), p.aprop_kind);
206
207 EXPECT_ANY_THROW(inner_product_forward(ip_primitive_desc, {}));
208 inner_product_forward(ip_primitive_desc)
209 .execute(strm,
210 {{DNNL_ARG_SRC, ip_src}, {DNNL_ARG_WEIGHTS, ip_weights},
211 {DNNL_ARG_BIAS, ip_bias},
212 {DNNL_ARG_DST, ip_dst}});
213 strm.wait();
214
215 compute_ref_inner_product_fwd<data_t>(
216 ipd, ip_src, ip_weights, ip_bias, dst_ref);
217 check_zero_tail<data_t>(1, dst_ref);
218 compare_data<data_t>(dst_ref, ip_dst);
219
220 check_zero_tail<data_t>(0, ip_dst);
221 }
222};
223
224using inner_product_test_float = inner_product_test_t<float>;
225using inprod_test_params_float = inprod_test_params_t;
226
227#define EXPAND_SIZES_3D(...) \
228 5, { __VA_ARGS__ }
229#define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) \
230 4, { mb, ic, oc, 1, kh, kw }
231#define EXPAND_SIZES_1D(mb, ic, oc, kw) \
232 3, { mb, ic, oc, 1, 1, kw }
233
// Intentionally empty: inner_product_test_t::SetUp() builds, executes and
// checks the primitive for each parameter, so the body has nothing to do.
TEST_P(inner_product_test_float, TestsInnerProduct) {}
235
// Zero minibatch (mb == 0): not flagged expect_to_fail, so an empty problem
// must be created and executed without error.
INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardZeroDim,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {prop_kind::forward,
                memory::format_tag::any, memory::format_tag::any,
                memory::format_tag::any, memory::format_tag::any,
                EXPAND_SIZES_2D(0, 32, 48, 6, 6)}));
242
// Expected-failure cases: ic == 0, mb < 0 and ic < 0. Each is flagged
// expect_to_fail with dnnl_invalid_arguments as the expected status
// (contrast with TestInnerProductForwardZeroDim, where mb == 0 is valid).
INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardEF, inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 0, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, -1, 48, 6, 6), true,
                        dnnl_invalid_arguments}));
260
// No-bias cases with channel-blocked src/weights layouts (16c, 8c, 4c).
// Most use an input-channel count that is not a multiple of the block size,
// so the blocked tensors carry channel padding that the reference and the
// zero-tail checks must handle.
INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardNoBias_padded,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 14, 25, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 20, 15, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 6, 15, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 10, 5, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw4c, memory::format_tag::aBcd4b,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 16, 5, 5, 5)}));
286
// Same padded blocked-layout shapes as the first four
// TestInnerProductForwardNoBias_padded cases, but with a bias (format x).
// NOTE(review): the GPU_ prefix suggests these are registered only for GPU
// engines - confirm against the macro definition.
GPU_INSTANTIATE_TEST_SUITE_P(TestInnerProductForward_padded,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {prop_kind::forward,
                                  memory::format_tag::nChw16c,
                                  memory::format_tag::aBcd16b,
                                  memory::format_tag::x, memory::format_tag::nc,
                                  EXPAND_SIZES_2D(4, 14, 25, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 20, 15, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 6, 15, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 10, 5, 5, 5)}));
307
308INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardNoBias,
309 inner_product_test_float,
310 ::testing::Values(
311 inprod_test_params_float {prop_kind::forward,
312 memory::format_tag::any, memory::format_tag::any,
313 memory::format_tag::undef, memory::format_tag::any,
314 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
315 inprod_test_params_float {prop_kind::forward,
316 memory::format_tag::any, memory::format_tag::any,
317 memory::format_tag::undef, memory::format_tag::any,
318 EXPAND_SIZES_2D(2, 512, 48, 2, 2)},
319 inprod_test_params_float {prop_kind::forward,
320 memory::format_tag::nwc, memory::format_tag::wio,
321 memory::format_tag::undef, memory::format_tag::nc,
322 EXPAND_SIZES_1D(2, 32, 48, 5)},
323 inprod_test_params_float {prop_kind::forward,
324 memory::format_tag::nwc, memory::format_tag::owi,
325 memory::format_tag::undef, memory::format_tag::nc,
326 EXPAND_SIZES_1D(2, 32, 48, 5)},
327 inprod_test_params_float {prop_kind::forward,
328 memory::format_tag::nwc, memory::format_tag::oiw,
329 memory::format_tag::undef, memory::format_tag::nc,
330 EXPAND_SIZES_1D(2, 32, 48, 5)},
331 inprod_test_params_float {prop_kind::forward,
332 memory::format_tag::ncw, memory::format_tag::oiw,
333 memory::format_tag::undef, memory::format_tag::nc,
334 EXPAND_SIZES_1D(2, 32, 48, 5)},
335 inprod_test_params_float {prop_kind::forward,
336 memory::format_tag::ncw, memory::format_tag::owi,
337 memory::format_tag::undef, memory::format_tag::nc,
338 EXPAND_SIZES_1D(2, 32, 48, 5)},
339 inprod_test_params_float {prop_kind::forward,
340 memory::format_tag::ncw, memory::format_tag::wio,
341 memory::format_tag::undef, memory::format_tag::nc,
342 EXPAND_SIZES_1D(2, 32, 48, 5)},
343 inprod_test_params_float {prop_kind::forward,
344 memory::format_tag::nhwc, memory::format_tag::hwio,
345 memory::format_tag::undef, memory::format_tag::nc,
346 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
347 inprod_test_params_float {prop_kind::forward,
348 memory::format_tag::nhwc, memory::format_tag::ohwi,
349 memory::format_tag::undef, memory::format_tag::nc,
350 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
351 inprod_test_params_float {prop_kind::forward,
352 memory::format_tag::nhwc, memory::format_tag::oihw,
353 memory::format_tag::undef, memory::format_tag::nc,
354 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
355 inprod_test_params_float {prop_kind::forward,
356 memory::format_tag::nchw, memory::format_tag::oihw,
357 memory::format_tag::undef, memory::format_tag::nc,
358 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
359 inprod_test_params_float {prop_kind::forward,
360 memory::format_tag::nchw, memory::format_tag::hwio,
361 memory::format_tag::undef, memory::format_tag::nc,
362 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
363 inprod_test_params_float {prop_kind::forward,
364 memory::format_tag::nchw, memory::format_tag::ohwi,
365 memory::format_tag::undef, memory::format_tag::nc,
366 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
367 inprod_test_params_float {prop_kind::forward,
368 memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
369 memory::format_tag::undef, memory::format_tag::nc,
370 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
371 inprod_test_params_float {prop_kind::forward,
372 memory::format_tag::nChw16c,
373 memory::format_tag::aBcd16b, memory::format_tag::undef,
374 memory::format_tag::nc,
375 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
376 inprod_test_params_float {prop_kind::forward,
377 memory::format_tag::any, memory::format_tag::aBcd8b,
378 memory::format_tag::undef, memory::format_tag::nc,
379 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
380 inprod_test_params_float {prop_kind::forward,
381 memory::format_tag::nChw8c, memory::format_tag::any,
382 memory::format_tag::undef, memory::format_tag::nc,
383 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
384 inprod_test_params_float {prop_kind::forward,
385 memory::format_tag::nChw16c,
386 memory::format_tag::aBcd16b, memory::format_tag::undef,
387 memory::format_tag::nc,
388 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
389 inprod_test_params_float {prop_kind::forward,
390 memory::format_tag::nc, memory::format_tag::oi,
391 memory::format_tag::undef, memory::format_tag::nc,
392 EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
393 inprod_test_params_float {prop_kind::forward,
394 memory::format_tag::nc, memory::format_tag::oi,
395 memory::format_tag::undef, memory::format_tag::nc,
396 EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
397 inprod_test_params_float {prop_kind::forward,
398 memory::format_tag::nc, memory::format_tag::io,
399 memory::format_tag::undef, memory::format_tag::nc,
400 EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
401
// 5D src/weights cases (ndims == 5, three spatial dims). These cover:
// - plain source layouts (ncdhw, ndhwc) and channel-blocked ones
//   (nCdhw8c, nCdhw16c);
// - several weights layouts;
// - cases both with bias (x) and without it (undef).
INSTANTIATE_TEST_SUITE_P(TestInnerProductForward3D, inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::undef, memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ncdhw, memory::format_tag::dhwio,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 3, 5, 7)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ncdhw, memory::format_tag::odhwi,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 2, 4, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ncdhw, memory::format_tag::oidhw,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ndhwc, memory::format_tag::dhwio,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ndhwc, memory::format_tag::odhwi,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 4, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ndhwc, memory::format_tag::oidhw,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 5, 4)}));
442
// With-bias cases (bias format any or x) for 2D-spatial problems in plain
// and channel-blocked layouts, plus plain nc x oi problems without spatial
// dims.
INSTANTIATE_TEST_SUITE_P(TestInnerProductForward, inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 512, 48, 2, 2)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nhwc, memory::format_tag::oihw,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nhwc, memory::format_tag::hwio,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nchw, memory::format_tag::oihw,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nc, memory::format_tag::oi,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nc, memory::format_tag::oi,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nc, memory::format_tag::oi,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
486} // namespace dnnl
487