1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
24 | struct test_inner_product_descr_t { |
25 | memory::dim mb; |
26 | memory::dim ic; |
27 | memory::dim oc; |
28 | memory::dim kd, kh, kw; |
29 | }; |
30 | |
31 | template <typename data_t> |
32 | void compute_ref_inner_product_fwd(test_inner_product_descr_t ipd, memory &src, |
33 | memory &weights, memory &bias, memory &dst) { |
34 | const bool w_bias = bias.get_desc().get_ndims() != 0; |
35 | auto src_data = map_memory<data_t>(src); |
36 | auto weights_data = map_memory<data_t>(weights); |
37 | auto bias_data = w_bias ? map_memory<data_t>(bias) : nullptr; |
38 | auto dst_data = map_memory<data_t>(dst); |
39 | |
40 | const memory::desc src_d = src.get_desc(); |
41 | const memory::desc weights_d = weights.get_desc(); |
42 | const memory::desc bias_d = bias.get_desc(); |
43 | const memory::desc dst_d = dst.get_desc(); |
44 | const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get()); |
45 | const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get()); |
46 | const dnnl::impl::memory_desc_wrapper bias_mdw(bias_d.get()); |
47 | const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get()); |
48 | |
49 | auto padded_ic = src_mdw.padded_dims()[1]; |
50 | |
51 | dnnl::impl::parallel_nd(ipd.mb, ipd.oc, [&](memory::dim n, memory::dim oc) { |
52 | memory::dim oidx = n * ipd.oc + oc; |
53 | dst_data[dst_mdw.off_l(oidx, true)] |
54 | = bias_data ? bias_data[bias_mdw.off_l(oc, true)] : data_t {0}; |
55 | for (memory::dim ic = 0; ic < ipd.ic; ic++) { |
56 | for_(memory::dim kd = 0; kd < ipd.kd; kd++) |
57 | for_(memory::dim kh = 0; kh < ipd.kh; kh++) |
58 | for (memory::dim kw = 0; kw < ipd.kw; kw++) { |
59 | memory::dim iidx = n * padded_ic * ipd.kd * ipd.kh * ipd.kw |
60 | + ic * ipd.kd * ipd.kh * ipd.kw + kd * ipd.kh * ipd.kw |
61 | + kh * ipd.kw + kw; |
62 | memory::dim widx = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw |
63 | + ic * ipd.kd * ipd.kh * ipd.kw + kd * ipd.kh * ipd.kw |
64 | + kh * ipd.kw + kw; |
65 | dst_data[dst_mdw.off_l(oidx, true)] |
66 | += src_data[src_mdw.off_l(iidx, true)] |
67 | * weights_data[weights_mdw.off_l(widx, true)]; |
68 | } |
69 | } |
70 | }); |
71 | } |
72 | |
73 | struct inprod_test_params_t { |
74 | prop_kind aprop_kind; |
75 | memory::format_tag src_format; |
76 | memory::format_tag weights_format; |
77 | memory::format_tag bias_format; |
78 | memory::format_tag dst_format; |
79 | int ndims; |
80 | test_inner_product_descr_t test_ipd; |
81 | bool expect_to_fail; |
82 | dnnl_status_t expected_status; |
83 | }; |
84 | |
85 | template <typename data_t> |
86 | class inner_product_test_t |
87 | : public ::testing::TestWithParam<inprod_test_params_t> { |
88 | protected: |
89 | void SetUp() override { |
90 | auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam(); |
91 | SKIP_IF_CUDA(!cuda_check_format_tags(p.src_format, p.weights_format, |
92 | p.bias_format, p.dst_format), |
93 | "Unsupported format tag" ); |
94 | SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions" ); |
95 | catch_expected_failures( |
96 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
97 | } |
98 | |
99 | bool cuda_check_format_tags(memory::format_tag src_format, |
100 | memory::format_tag wei_format, memory::format_tag bia_format, |
101 | memory::format_tag dst_format) { |
102 | bool src_ok = src_format == memory::format_tag::ncdhw |
103 | || src_format == memory::format_tag::ndhwc |
104 | || src_format == memory::format_tag::nchw |
105 | || src_format == memory::format_tag::nhwc |
106 | || src_format == memory::format_tag::ncw |
107 | || src_format == memory::format_tag::nwc |
108 | || src_format == memory::format_tag::nc |
109 | || src_format == memory::format_tag::any; |
110 | bool wei_ok = wei_format == memory::format_tag::oidhw |
111 | || wei_format == memory::format_tag::odhwi |
112 | || wei_format == memory::format_tag::dhwio |
113 | || wei_format == memory::format_tag::oihw |
114 | || wei_format == memory::format_tag::ohwi |
115 | || wei_format == memory::format_tag::hwio |
116 | || wei_format == memory::format_tag::oiw |
117 | || wei_format == memory::format_tag::owi |
118 | || wei_format == memory::format_tag::wio |
119 | || wei_format == memory::format_tag::io |
120 | || wei_format == memory::format_tag::oi |
121 | || wei_format == memory::format_tag::any; |
122 | bool bia_ok = bia_format == memory::format_tag::undef |
123 | || bia_format == memory::format_tag::any |
124 | || bia_format == memory::format_tag::a |
125 | || bia_format == memory::format_tag::x; |
126 | bool dst_ok = dst_format == memory::format_tag::any |
127 | || dst_format == memory::format_tag::nc; |
128 | |
129 | return src_ok && wei_ok && bia_ok && dst_ok; |
130 | } |
131 | |
132 | void Test() { |
133 | auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam(); |
134 | test_inner_product_descr_t ipd = p.test_ipd; |
135 | bool has_spatial = ipd.kh > 1 || ipd.kw > 1; |
136 | if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1; |
137 | bool with_bias = p.bias_format != memory::format_tag::undef; |
138 | |
139 | ASSERT_EQ(p.aprop_kind, prop_kind::forward); |
140 | auto eng = get_test_engine(); |
141 | auto strm = make_stream(eng); |
142 | memory::data_type data_type = data_traits<data_t>::data_type; |
143 | ASSERT_EQ(data_type, dnnl::memory::data_type::f32); |
144 | |
145 | memory::dims src_dims = {ipd.mb, ipd.ic}, wei_dims = {ipd.oc, ipd.ic}; |
146 | if (has_spatial) { |
147 | if (p.ndims == 5) { |
148 | src_dims.push_back(ipd.kd); |
149 | wei_dims.push_back(ipd.kd); |
150 | } |
151 | if (p.ndims >= 4) { |
152 | src_dims.push_back(ipd.kh); |
153 | wei_dims.push_back(ipd.kh); |
154 | } |
155 | if (p.ndims >= 3) { |
156 | src_dims.push_back(ipd.kw); |
157 | wei_dims.push_back(ipd.kw); |
158 | } |
159 | } |
160 | auto ip_src_desc = create_md(src_dims, data_type, p.src_format); |
161 | auto ip_weights_desc = create_md(wei_dims, data_type, p.weights_format); |
162 | auto ip_bias_desc = with_bias |
163 | ? create_md({ipd.oc}, data_type, p.bias_format) |
164 | : create_md({}, data_type, p.bias_format); |
165 | auto ip_dst_desc = create_md({ipd.mb, ipd.oc}, data_type, p.dst_format); |
166 | |
167 | auto ip_primitive_desc = with_bias |
168 | ? inner_product_forward::primitive_desc(eng, p.aprop_kind, |
169 | ip_src_desc, ip_weights_desc, ip_bias_desc, ip_dst_desc) |
170 | : inner_product_forward::primitive_desc(eng, p.aprop_kind, |
171 | ip_src_desc, ip_weights_desc, ip_dst_desc); |
172 | |
173 | ip_primitive_desc = inner_product_forward::primitive_desc( |
174 | ip_primitive_desc.get()); // test construction from a C pd |
175 | |
176 | auto ip_src = test::make_memory(ip_primitive_desc.src_desc(), eng); |
177 | auto ip_weights |
178 | = test::make_memory(ip_primitive_desc.weights_desc(), eng); |
179 | auto ip_bias = test::make_memory(ip_primitive_desc.bias_desc(), eng); |
180 | auto ip_dst = test::make_memory(ip_primitive_desc.dst_desc(), eng); |
181 | auto dst_ref = test::make_memory(ip_primitive_desc.dst_desc(), eng); |
182 | |
183 | fill_data<data_t>( |
184 | ip_src.get_desc().get_size() / sizeof(data_t), ip_src); |
185 | fill_data<data_t>( |
186 | ip_weights.get_desc().get_size() / sizeof(data_t), ip_weights); |
187 | if (with_bias) { |
188 | fill_data<data_t>( |
189 | ip_bias.get_desc().get_size() / sizeof(data_t), ip_bias); |
190 | } |
191 | check_zero_tail<data_t>(1, ip_src); |
192 | check_zero_tail<data_t>(1, ip_weights); |
193 | |
194 | ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC) |
195 | == ip_primitive_desc.src_desc()); |
196 | ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_DST) |
197 | == ip_primitive_desc.dst_desc()); |
198 | ASSERT_TRUE( |
199 | ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS) |
200 | == ip_primitive_desc.weights_desc()); |
201 | ASSERT_TRUE( |
202 | ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_BIAS) |
203 | == ip_primitive_desc.bias_desc()); |
204 | |
205 | ASSERT_EQ(ip_primitive_desc.get_prop_kind(), p.aprop_kind); |
206 | |
207 | EXPECT_ANY_THROW(inner_product_forward(ip_primitive_desc, {})); |
208 | inner_product_forward(ip_primitive_desc) |
209 | .execute(strm, |
210 | {{DNNL_ARG_SRC, ip_src}, {DNNL_ARG_WEIGHTS, ip_weights}, |
211 | {DNNL_ARG_BIAS, ip_bias}, |
212 | {DNNL_ARG_DST, ip_dst}}); |
213 | strm.wait(); |
214 | |
215 | compute_ref_inner_product_fwd<data_t>( |
216 | ipd, ip_src, ip_weights, ip_bias, dst_ref); |
217 | check_zero_tail<data_t>(1, dst_ref); |
218 | compare_data<data_t>(dst_ref, ip_dst); |
219 | |
220 | check_zero_tail<data_t>(0, ip_dst); |
221 | } |
222 | }; |
223 | |
// Fixture instantiated for `float` data; this is the type the test suites in
// this file are registered against.
using inner_product_test_float = inner_product_test_t<float>;
// Alias of the parameter struct, used to spell out the individual test cases.
using inprod_test_params_float = inprod_test_params_t;
226 | |
// Each helper expands to `ndims, {mb, ic, oc, kd, kh, kw}`: the tensor rank
// followed by a brace-enclosed size list. Kernel extents that the given rank
// does not take as arguments are filled in with 1.
#define EXPAND_SIZES_3D(...) 5, { __VA_ARGS__ }
#define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) 4, { mb, ic, oc, 1, kh, kw }
#define EXPAND_SIZES_1D(mb, ic, oc, kw) 3, { mb, ic, oc, 1, 1, kw }
233 | |
// The body is intentionally empty: the fixture's SetUp() already runs the
// whole scenario by invoking Test() through catch_expected_failures().
TEST_P(inner_product_test_float, TestsInnerProduct) {}
235 | |
236 | INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardZeroDim, |
237 | inner_product_test_float, |
238 | ::testing::Values(inprod_test_params_float {prop_kind::forward, |
239 | memory::format_tag::any, memory::format_tag::any, |
240 | memory::format_tag::any, memory::format_tag::any, |
241 | EXPAND_SIZES_2D(0, 32, 48, 6, 6)})); |
242 | |
243 | INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardEF, inner_product_test_float, |
244 | ::testing::Values( |
245 | inprod_test_params_float {prop_kind::forward, |
246 | memory::format_tag::any, memory::format_tag::any, |
247 | memory::format_tag::any, memory::format_tag::any, |
248 | EXPAND_SIZES_2D(2, 0, 48, 6, 6), true, |
249 | dnnl_invalid_arguments}, |
250 | inprod_test_params_float {prop_kind::forward, |
251 | memory::format_tag::any, memory::format_tag::any, |
252 | memory::format_tag::any, memory::format_tag::any, |
253 | EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true, |
254 | dnnl_invalid_arguments}, |
255 | inprod_test_params_float {prop_kind::forward, |
256 | memory::format_tag::any, memory::format_tag::any, |
257 | memory::format_tag::any, memory::format_tag::any, |
258 | EXPAND_SIZES_2D(2, -1, 48, 6, 6), true, |
259 | dnnl_invalid_arguments})); |
260 | |
261 | INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardNoBias_padded, |
262 | inner_product_test_float, |
263 | ::testing::Values( |
264 | inprod_test_params_float {prop_kind::forward, |
265 | memory::format_tag::nChw16c, |
266 | memory::format_tag::aBcd16b, memory::format_tag::undef, |
267 | memory::format_tag::nc, |
268 | EXPAND_SIZES_2D(4, 14, 25, 5, 5)}, |
269 | inprod_test_params_float {prop_kind::forward, |
270 | memory::format_tag::nChw16c, |
271 | memory::format_tag::aBcd16b, memory::format_tag::undef, |
272 | memory::format_tag::nc, |
273 | EXPAND_SIZES_2D(4, 20, 15, 5, 5)}, |
274 | inprod_test_params_float {prop_kind::forward, |
275 | memory::format_tag::nChw8c, memory::format_tag::aBcd8b, |
276 | memory::format_tag::undef, memory::format_tag::nc, |
277 | EXPAND_SIZES_2D(4, 6, 15, 5, 5)}, |
278 | inprod_test_params_float {prop_kind::forward, |
279 | memory::format_tag::nChw8c, memory::format_tag::aBcd8b, |
280 | memory::format_tag::undef, memory::format_tag::nc, |
281 | EXPAND_SIZES_2D(4, 10, 5, 5, 5)}, |
282 | inprod_test_params_float {prop_kind::forward, |
283 | memory::format_tag::nChw4c, memory::format_tag::aBcd4b, |
284 | memory::format_tag::undef, memory::format_tag::nc, |
285 | EXPAND_SIZES_2D(4, 16, 5, 5, 5)})); |
286 | |
287 | GPU_INSTANTIATE_TEST_SUITE_P(TestInnerProductForward_padded, |
288 | inner_product_test_float, |
289 | ::testing::Values(inprod_test_params_float {prop_kind::forward, |
290 | memory::format_tag::nChw16c, |
291 | memory::format_tag::aBcd16b, |
292 | memory::format_tag::x, memory::format_tag::nc, |
293 | EXPAND_SIZES_2D(4, 14, 25, 5, 5)}, |
294 | inprod_test_params_float {prop_kind::forward, |
295 | memory::format_tag::nChw16c, |
296 | memory::format_tag::aBcd16b, memory::format_tag::x, |
297 | memory::format_tag::nc, |
298 | EXPAND_SIZES_2D(4, 20, 15, 5, 5)}, |
299 | inprod_test_params_float {prop_kind::forward, |
300 | memory::format_tag::nChw8c, memory::format_tag::aBcd8b, |
301 | memory::format_tag::x, memory::format_tag::nc, |
302 | EXPAND_SIZES_2D(4, 6, 15, 5, 5)}, |
303 | inprod_test_params_float {prop_kind::forward, |
304 | memory::format_tag::nChw8c, memory::format_tag::aBcd8b, |
305 | memory::format_tag::x, memory::format_tag::nc, |
306 | EXPAND_SIZES_2D(4, 10, 5, 5, 5)})); |
307 | |
308 | INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardNoBias, |
309 | inner_product_test_float, |
310 | ::testing::Values( |
311 | inprod_test_params_float {prop_kind::forward, |
312 | memory::format_tag::any, memory::format_tag::any, |
313 | memory::format_tag::undef, memory::format_tag::any, |
314 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
315 | inprod_test_params_float {prop_kind::forward, |
316 | memory::format_tag::any, memory::format_tag::any, |
317 | memory::format_tag::undef, memory::format_tag::any, |
318 | EXPAND_SIZES_2D(2, 512, 48, 2, 2)}, |
319 | inprod_test_params_float {prop_kind::forward, |
320 | memory::format_tag::nwc, memory::format_tag::wio, |
321 | memory::format_tag::undef, memory::format_tag::nc, |
322 | EXPAND_SIZES_1D(2, 32, 48, 5)}, |
323 | inprod_test_params_float {prop_kind::forward, |
324 | memory::format_tag::nwc, memory::format_tag::owi, |
325 | memory::format_tag::undef, memory::format_tag::nc, |
326 | EXPAND_SIZES_1D(2, 32, 48, 5)}, |
327 | inprod_test_params_float {prop_kind::forward, |
328 | memory::format_tag::nwc, memory::format_tag::oiw, |
329 | memory::format_tag::undef, memory::format_tag::nc, |
330 | EXPAND_SIZES_1D(2, 32, 48, 5)}, |
331 | inprod_test_params_float {prop_kind::forward, |
332 | memory::format_tag::ncw, memory::format_tag::oiw, |
333 | memory::format_tag::undef, memory::format_tag::nc, |
334 | EXPAND_SIZES_1D(2, 32, 48, 5)}, |
335 | inprod_test_params_float {prop_kind::forward, |
336 | memory::format_tag::ncw, memory::format_tag::owi, |
337 | memory::format_tag::undef, memory::format_tag::nc, |
338 | EXPAND_SIZES_1D(2, 32, 48, 5)}, |
339 | inprod_test_params_float {prop_kind::forward, |
340 | memory::format_tag::ncw, memory::format_tag::wio, |
341 | memory::format_tag::undef, memory::format_tag::nc, |
342 | EXPAND_SIZES_1D(2, 32, 48, 5)}, |
343 | inprod_test_params_float {prop_kind::forward, |
344 | memory::format_tag::nhwc, memory::format_tag::hwio, |
345 | memory::format_tag::undef, memory::format_tag::nc, |
346 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
347 | inprod_test_params_float {prop_kind::forward, |
348 | memory::format_tag::nhwc, memory::format_tag::ohwi, |
349 | memory::format_tag::undef, memory::format_tag::nc, |
350 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
351 | inprod_test_params_float {prop_kind::forward, |
352 | memory::format_tag::nhwc, memory::format_tag::oihw, |
353 | memory::format_tag::undef, memory::format_tag::nc, |
354 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
355 | inprod_test_params_float {prop_kind::forward, |
356 | memory::format_tag::nchw, memory::format_tag::oihw, |
357 | memory::format_tag::undef, memory::format_tag::nc, |
358 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
359 | inprod_test_params_float {prop_kind::forward, |
360 | memory::format_tag::nchw, memory::format_tag::hwio, |
361 | memory::format_tag::undef, memory::format_tag::nc, |
362 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
363 | inprod_test_params_float {prop_kind::forward, |
364 | memory::format_tag::nchw, memory::format_tag::ohwi, |
365 | memory::format_tag::undef, memory::format_tag::nc, |
366 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
367 | inprod_test_params_float {prop_kind::forward, |
368 | memory::format_tag::nChw8c, memory::format_tag::aBcd8b, |
369 | memory::format_tag::undef, memory::format_tag::nc, |
370 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
371 | inprod_test_params_float {prop_kind::forward, |
372 | memory::format_tag::nChw16c, |
373 | memory::format_tag::aBcd16b, memory::format_tag::undef, |
374 | memory::format_tag::nc, |
375 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
376 | inprod_test_params_float {prop_kind::forward, |
377 | memory::format_tag::any, memory::format_tag::aBcd8b, |
378 | memory::format_tag::undef, memory::format_tag::nc, |
379 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
380 | inprod_test_params_float {prop_kind::forward, |
381 | memory::format_tag::nChw8c, memory::format_tag::any, |
382 | memory::format_tag::undef, memory::format_tag::nc, |
383 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
384 | inprod_test_params_float {prop_kind::forward, |
385 | memory::format_tag::nChw16c, |
386 | memory::format_tag::aBcd16b, memory::format_tag::undef, |
387 | memory::format_tag::nc, |
388 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
389 | inprod_test_params_float {prop_kind::forward, |
390 | memory::format_tag::nc, memory::format_tag::oi, |
391 | memory::format_tag::undef, memory::format_tag::nc, |
392 | EXPAND_SIZES_2D(2, 32, 1152, 1, 1)}, |
393 | inprod_test_params_float {prop_kind::forward, |
394 | memory::format_tag::nc, memory::format_tag::oi, |
395 | memory::format_tag::undef, memory::format_tag::nc, |
396 | EXPAND_SIZES_2D(2, 2, 4, 1, 1)}, |
397 | inprod_test_params_float {prop_kind::forward, |
398 | memory::format_tag::nc, memory::format_tag::io, |
399 | memory::format_tag::undef, memory::format_tag::nc, |
400 | EXPAND_SIZES_2D(2, 8, 16, 1, 1)})); |
401 | |
402 | INSTANTIATE_TEST_SUITE_P(TestInnerProductForward3D, inner_product_test_float, |
403 | ::testing::Values( |
404 | inprod_test_params_float {prop_kind::forward, |
405 | memory::format_tag::any, memory::format_tag::any, |
406 | memory::format_tag::undef, memory::format_tag::any, |
407 | EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)}, |
408 | inprod_test_params_float {prop_kind::forward, |
409 | memory::format_tag::ncdhw, memory::format_tag::dhwio, |
410 | memory::format_tag::x, memory::format_tag::nc, |
411 | EXPAND_SIZES_3D(2, 32, 48, 3, 5, 7)}, |
412 | inprod_test_params_float {prop_kind::forward, |
413 | memory::format_tag::ncdhw, memory::format_tag::odhwi, |
414 | memory::format_tag::undef, memory::format_tag::nc, |
415 | EXPAND_SIZES_3D(2, 32, 48, 2, 4, 6)}, |
416 | inprod_test_params_float {prop_kind::forward, |
417 | memory::format_tag::ncdhw, memory::format_tag::oidhw, |
418 | memory::format_tag::undef, memory::format_tag::nc, |
419 | EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)}, |
420 | inprod_test_params_float {prop_kind::forward, |
421 | memory::format_tag::nCdhw8c, |
422 | memory::format_tag::aBcde8b, memory::format_tag::x, |
423 | memory::format_tag::nc, |
424 | EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)}, |
425 | inprod_test_params_float {prop_kind::forward, |
426 | memory::format_tag::nCdhw16c, |
427 | memory::format_tag::aBcde16b, memory::format_tag::x, |
428 | memory::format_tag::nc, |
429 | EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)}, |
430 | inprod_test_params_float {prop_kind::forward, |
431 | memory::format_tag::ndhwc, memory::format_tag::dhwio, |
432 | memory::format_tag::undef, memory::format_tag::nc, |
433 | EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)}, |
434 | inprod_test_params_float {prop_kind::forward, |
435 | memory::format_tag::ndhwc, memory::format_tag::odhwi, |
436 | memory::format_tag::undef, memory::format_tag::nc, |
437 | EXPAND_SIZES_3D(2, 16, 48, 3, 4, 5)}, |
438 | inprod_test_params_float {prop_kind::forward, |
439 | memory::format_tag::ndhwc, memory::format_tag::oidhw, |
440 | memory::format_tag::undef, memory::format_tag::nc, |
441 | EXPAND_SIZES_3D(2, 16, 48, 3, 5, 4)})); |
442 | |
443 | INSTANTIATE_TEST_SUITE_P(TestInnerProductForward, inner_product_test_float, |
444 | ::testing::Values( |
445 | inprod_test_params_float {prop_kind::forward, |
446 | memory::format_tag::any, memory::format_tag::any, |
447 | memory::format_tag::any, memory::format_tag::any, |
448 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
449 | inprod_test_params_float {prop_kind::forward, |
450 | memory::format_tag::any, memory::format_tag::any, |
451 | memory::format_tag::any, memory::format_tag::any, |
452 | EXPAND_SIZES_2D(2, 512, 48, 2, 2)}, |
453 | inprod_test_params_float {prop_kind::forward, |
454 | memory::format_tag::nhwc, memory::format_tag::oihw, |
455 | memory::format_tag::x, memory::format_tag::nc, |
456 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
457 | inprod_test_params_float {prop_kind::forward, |
458 | memory::format_tag::nhwc, memory::format_tag::hwio, |
459 | memory::format_tag::x, memory::format_tag::nc, |
460 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
461 | inprod_test_params_float {prop_kind::forward, |
462 | memory::format_tag::nchw, memory::format_tag::oihw, |
463 | memory::format_tag::x, memory::format_tag::nc, |
464 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
465 | inprod_test_params_float {prop_kind::forward, |
466 | memory::format_tag::nChw8c, memory::format_tag::aBcd8b, |
467 | memory::format_tag::x, memory::format_tag::nc, |
468 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
469 | inprod_test_params_float {prop_kind::forward, |
470 | memory::format_tag::nChw16c, |
471 | memory::format_tag::aBcd16b, memory::format_tag::x, |
472 | memory::format_tag::nc, |
473 | EXPAND_SIZES_2D(2, 32, 48, 6, 6)}, |
474 | inprod_test_params_float {prop_kind::forward, |
475 | memory::format_tag::nc, memory::format_tag::oi, |
476 | memory::format_tag::x, memory::format_tag::nc, |
477 | EXPAND_SIZES_2D(2, 32, 1152, 1, 1)}, |
478 | inprod_test_params_float {prop_kind::forward, |
479 | memory::format_tag::nc, memory::format_tag::oi, |
480 | memory::format_tag::x, memory::format_tag::nc, |
481 | EXPAND_SIZES_2D(2, 2, 4, 1, 1)}, |
482 | inprod_test_params_float {prop_kind::forward, |
483 | memory::format_tag::nc, memory::format_tag::oi, |
484 | memory::format_tag::x, memory::format_tag::nc, |
485 | EXPAND_SIZES_2D(2, 8, 16, 1, 1)})); |
486 | } // namespace dnnl |
487 | |