1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24struct test_inner_product_descr_t {
25 memory::dim mb;
26 memory::dim ic;
27 memory::dim oc;
28 memory::dim kd, kh, kw;
29};
30
31template <typename data_t>
32void compute_ref_inner_product_bwd_bias(const test_inner_product_descr_t &ipd,
33 const memory &diff_dst, const memory &diff_bias) {
34 auto diff_bias_data = map_memory<data_t>(diff_bias);
35 auto diff_dst_data = map_memory<data_t>(diff_dst);
36
37 const memory::desc diff_bias_d = diff_bias.get_desc();
38 const memory::desc diff_dst_d = diff_dst.get_desc();
39 const dnnl::impl::memory_desc_wrapper diff_bias_mdw(diff_bias_d.get());
40 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get());
41
42 dnnl::impl::parallel_nd(ipd.oc, [&](memory::dim oc) {
43 data_t *db = &diff_bias_data[diff_bias_mdw.off_l(oc, true)];
44 *db = data_t(0);
45 for (memory::dim n = 0; n < ipd.mb; ++n) {
46 *db += diff_dst_data[diff_dst_mdw.off_l(n * ipd.oc + oc, true)];
47 }
48 });
49}
50
51template <typename data_t>
52void compute_ref_inner_product_bwd_weights(int ndims,
53 const test_inner_product_descr_t &ipd, const memory &src,
54 const memory &diff_dst, const memory &diff_weights) {
55 auto src_data = map_memory<data_t>(src);
56 auto diff_weights_data = map_memory<data_t>(diff_weights);
57 auto diff_dst_data = map_memory<data_t>(diff_dst);
58
59 const memory::desc src_d = src.get_desc();
60 const memory::desc diff_weights_d = diff_weights.get_desc();
61 const memory::desc diff_dst_d = diff_dst.get_desc();
62 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get());
63 const dnnl::impl::memory_desc_wrapper diff_weights_mdw(
64 diff_weights_d.get());
65 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get());
66
67 auto padded_ic = src_d.get_padded_dims()[1];
68
69 bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
70 if (ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
71 dnnl::impl::parallel_nd(
72 ipd.oc, ipd.ic, [&](memory::dim oc, memory::dim ic) {
73 if (has_spatial) {
74 for_(memory::dim kd = 0; kd < ipd.kd; ++kd)
75 for_(memory::dim kh = 0; kh < ipd.kh; ++kh)
76 for (memory::dim kw = 0; kw < ipd.kw; ++kw) {
77 memory::dim dwidx
78 = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw
79 + ic * ipd.kd * ipd.kh * ipd.kw
80 + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw;
81 data_t *dw = &diff_weights_data[diff_weights_mdw.off_l(
82 dwidx, true)];
83 *dw = data_t(0);
84 for (memory::dim n = 0; n < ipd.mb; ++n) {
85 memory::dim ddidx = n * ipd.oc + oc;
86 memory::dim sidx
87 = n * padded_ic * ipd.kd * ipd.kh * ipd.kw
88 + ic * ipd.kd * ipd.kh * ipd.kw
89 + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw;
90 *dw += diff_dst_data[diff_dst_mdw.off_l(
91 ddidx, true)]
92 * src_data[src_mdw.off_l(sidx, true)];
93 }
94 }
95 } else {
96 memory::dim dwidx = oc * ipd.ic + ic;
97 data_t *dw = &diff_weights_data[diff_weights_mdw.off_l(
98 dwidx, true)];
99 *dw = data_t(0);
100 for (memory::dim n = 0; n < ipd.mb; ++n) {
101 memory::dim ddidx = n * ipd.oc + oc;
102 memory::dim sidx = n * ipd.ic + ic;
103 *dw += diff_dst_data[diff_dst_mdw.off_l(ddidx, true)]
104 * src_data[src_mdw.off_l(sidx, true)];
105 }
106 }
107 });
108}
109
110struct inprod_test_params_t {
111 memory::format_tag src_format;
112 memory::format_tag diff_weights_format;
113 memory::format_tag diff_bias_format;
114 memory::format_tag diff_dst_format;
115 int ndims;
116 test_inner_product_descr_t test_ipd;
117 bool expect_to_fail;
118 dnnl_status_t expected_status;
119};
120
121template <typename data_t>
122class inner_product_test_bwd_weights_t
123 : public ::testing::TestWithParam<inprod_test_params_t> {
124protected:
125 void SetUp() override {
126 auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
127 SKIP_IF_CUDA(
128 !cuda_check_format_tags(p.src_format, p.diff_weights_format,
129 p.diff_bias_format, p.diff_dst_format),
130 "Unsupported format tag");
131 SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions");
132 catch_expected_failures(
133 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
134 }
135
136 bool cuda_check_format_tags(memory::format_tag src_format,
137 memory::format_tag diff_wei_format,
138 memory::format_tag diff_bia_format,
139 memory::format_tag diff_dst_format) {
140 bool src_ok = src_format == memory::format_tag::ncdhw
141 || src_format == memory::format_tag::ndhwc
142 || src_format == memory::format_tag::nchw
143 || src_format == memory::format_tag::nhwc
144 || src_format == memory::format_tag::ncw
145 || src_format == memory::format_tag::nwc
146 || src_format == memory::format_tag::nc
147 || src_format == memory::format_tag::any;
148 bool diff_wei_ok = diff_wei_format == memory::format_tag::oidhw
149 || diff_wei_format == memory::format_tag::odhwi
150 || diff_wei_format == memory::format_tag::dhwio
151 || diff_wei_format == memory::format_tag::oihw
152 || diff_wei_format == memory::format_tag::ohwi
153 || diff_wei_format == memory::format_tag::hwio
154 || diff_wei_format == memory::format_tag::oiw
155 || diff_wei_format == memory::format_tag::owi
156 || diff_wei_format == memory::format_tag::wio
157 || diff_wei_format == memory::format_tag::io
158 || diff_wei_format == memory::format_tag::oi
159 || diff_wei_format == memory::format_tag::any;
160 bool diff_bia_ok = diff_bia_format == memory::format_tag::undef
161 || diff_bia_format == memory::format_tag::any
162 || diff_bia_format == memory::format_tag::a
163 || diff_bia_format == memory::format_tag::x;
164 bool diff_dst_ok = diff_dst_format == memory::format_tag::any
165 || diff_dst_format == memory::format_tag::nc;
166
167 return src_ok && diff_wei_ok && diff_bia_ok && diff_dst_ok;
168 }
169
170 void Test() {
171 auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
172 test_inner_product_descr_t ipd = p.test_ipd;
173
174 bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
175 if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
176
177 bool with_bias = p.diff_bias_format != memory::format_tag::undef;
178
179 auto eng = get_test_engine();
180 auto strm = make_stream(eng);
181 memory::data_type data_type = data_traits<data_t>::data_type;
182 ASSERT_EQ(data_type, dnnl::memory::data_type::f32);
183
184 memory::dims src_dims = {ipd.mb, ipd.ic},
185 diff_wei_dims = {ipd.oc, ipd.ic};
186 if (has_spatial) {
187 if (p.ndims == 5) {
188 src_dims.push_back(ipd.kd);
189 diff_wei_dims.push_back(ipd.kd);
190 }
191 if (p.ndims >= 4) {
192 src_dims.push_back(ipd.kh);
193 diff_wei_dims.push_back(ipd.kh);
194 }
195 if (p.ndims >= 3) {
196 src_dims.push_back(ipd.kw);
197 diff_wei_dims.push_back(ipd.kw);
198 }
199 }
200 auto ip_src_desc = create_md(src_dims, data_type, p.src_format);
201 auto ip_diff_weights_desc
202 = create_md(diff_wei_dims, data_type, p.diff_weights_format);
203 auto ip_diff_dst_desc
204 = create_md({ipd.mb, ipd.oc}, data_type, p.diff_dst_format);
205 auto ip_diff_bias_desc = with_bias
206 ? create_md({ipd.oc}, data_type, p.diff_bias_format)
207 : create_md({}, data_type, p.diff_bias_format);
208
209 // Create inner product forward (hint for backward)
210 auto ip_fwd_pdesc
211 = inner_product_forward::primitive_desc(eng, prop_kind::forward,
212 ip_src_desc, ip_diff_weights_desc, ip_diff_dst_desc);
213
214 // Create inner product backward
215 auto ip_primitive_desc = with_bias
216 ? inner_product_backward_weights::primitive_desc(eng,
217 ip_src_desc, ip_diff_weights_desc, ip_diff_bias_desc,
218 ip_diff_dst_desc, ip_fwd_pdesc)
219 : inner_product_backward_weights::primitive_desc(eng,
220 ip_src_desc, ip_diff_weights_desc, ip_diff_dst_desc,
221 ip_fwd_pdesc);
222
223 ip_primitive_desc = inner_product_backward_weights::primitive_desc(
224 ip_primitive_desc.get()); // test construction from a C pd
225
226 auto ip_src = test::make_memory(ip_primitive_desc.src_desc(), eng);
227 auto ip_diff_dst
228 = test::make_memory(ip_primitive_desc.diff_dst_desc(), eng);
229 auto ip_diff_weights
230 = test::make_memory(ip_primitive_desc.diff_weights_desc(), eng);
231 auto diff_weights_ref
232 = test::make_memory(ip_primitive_desc.diff_weights_desc(), eng);
233 auto ip_diff_bias
234 = test::make_memory(ip_primitive_desc.diff_bias_desc(), eng);
235 auto diff_bias_ref
236 = test::make_memory(ip_primitive_desc.diff_bias_desc(), eng);
237
238 fill_data<data_t>(
239 ip_src.get_desc().get_size() / sizeof(data_t), ip_src);
240 fill_data<data_t>(ip_diff_dst.get_desc().get_size() / sizeof(data_t),
241 ip_diff_dst);
242
243 check_zero_tail<data_t>(1, ip_src);
244 check_zero_tail<data_t>(1, ip_diff_dst);
245
246 ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
247 == ip_primitive_desc.src_desc());
248 ASSERT_TRUE(ip_primitive_desc.query_md(
249 query::exec_arg_md, DNNL_ARG_DIFF_DST)
250 == ip_primitive_desc.diff_dst_desc());
251 ASSERT_TRUE(ip_primitive_desc.query_md(
252 query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS)
253 == ip_primitive_desc.diff_weights_desc());
254 ASSERT_TRUE(ip_primitive_desc.query_md(
255 query::exec_arg_md, DNNL_ARG_DIFF_BIAS)
256 == ip_primitive_desc.diff_bias_desc());
257
258 ASSERT_EQ(
259 ip_primitive_desc.get_prop_kind(), prop_kind::backward_weights);
260
261 EXPECT_ANY_THROW(inner_product_backward_weights(ip_primitive_desc, {}));
262 inner_product_backward_weights(ip_primitive_desc)
263 .execute(strm,
264 {{DNNL_ARG_DIFF_DST, ip_diff_dst},
265 {DNNL_ARG_SRC, ip_src},
266 {DNNL_ARG_DIFF_WEIGHTS, ip_diff_weights},
267 {DNNL_ARG_DIFF_BIAS, ip_diff_bias}});
268 strm.wait();
269
270 compute_ref_inner_product_bwd_weights<data_t>(
271 p.ndims, ipd, ip_src, ip_diff_dst, diff_weights_ref);
272 check_zero_tail<data_t>(1, diff_weights_ref);
273
274 compare_data<data_t>(diff_weights_ref, ip_diff_weights);
275
276 check_zero_tail<data_t>(0, ip_diff_weights);
277
278 if (with_bias) {
279 compute_ref_inner_product_bwd_bias<data_t>(
280 ipd, ip_diff_dst, diff_bias_ref);
281 compare_data<data_t>(diff_bias_ref, ip_diff_bias);
282 }
283 }
284};
285
// f32 instantiation of the fixture, plus an alias for its parameter type,
// used by the INSTANTIATE_TEST_SUITE_P lists below.
using inner_product_test_float = inner_product_test_bwd_weights_t<float>;
using inprod_test_params_float = inprod_test_params_t;

// Expand problem sizes into the `ndims, test_ipd` pair of
// inprod_test_params_t. The 3D form passes (mb, ic, oc, kd, kh, kw) through;
// the 2D form fixes kd = 1 and the 1D form fixes kd = kh = 1.
#define EXPAND_SIZES_3D(...) \
    5, { __VA_ARGS__ }
#define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) \
    4, { mb, ic, oc, 1, kh, kw }
#define EXPAND_SIZES_1D(mb, ic, oc, kw) \
    3, { mb, ic, oc, 1, 1, kw }

// Intentionally empty: the fixture's SetUp() creates, executes and validates
// the primitive for every parameter set.
TEST_P(inner_product_test_float, TestsInnerProduct) {}
297
// Zero minibatch (mb == 0) with `any` layouts; the case is expected to
// succeed (expect_to_fail is left value-initialized to false).
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsZeroDim,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::any,
                memory::format_tag::any, memory::format_tag::any,
                memory::format_tag::any, EXPAND_SIZES_2D(0, 32, 48, 6, 6)}));
303
// Expected-failure cases: ic == 0 with a non-zero minibatch, a negative mb,
// and a negative ic. Each is expected to fail with dnnl_invalid_arguments.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsEF,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 0, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, -1, 48, 6, 6), true,
                        dnnl_invalid_arguments}));
322
// No bias; IC is not a multiple of the channel block (16 or 8), so the
// blocked src / diff_weights layouts carry padding along IC.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsNoBias_padded,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 10, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 15, 3, 3)}));
342
// Same padded-IC problems as the NoBias_padded suite, but with a bias (x).
// NOTE(review): GPU_INSTANTIATE_TEST_SUITE_P comes from the test harness and
// presumably instantiates these only for GPU engines -- confirm.
GPU_INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights_padded,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::nChw16c,
                                  memory::format_tag::aBcd16b,
                                  memory::format_tag::x, memory::format_tag::nc,
                                  EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 10, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 15, 3, 3)}));
361
362INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsNoBias,
363 inner_product_test_float,
364 ::testing::Values(
365 inprod_test_params_float {memory::format_tag::any,
366 memory::format_tag::any, memory::format_tag::undef,
367 memory::format_tag::any,
368 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
369 inprod_test_params_float {memory::format_tag::any,
370 memory::format_tag::any, memory::format_tag::undef,
371 memory::format_tag::any,
372 EXPAND_SIZES_2D(2, 1024, 48, 2, 2)},
373 inprod_test_params_float {memory::format_tag::nwc,
374 memory::format_tag::owi, memory::format_tag::undef,
375 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
376 inprod_test_params_float {memory::format_tag::nwc,
377 memory::format_tag::wio, memory::format_tag::undef,
378 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
379 inprod_test_params_float {memory::format_tag::nwc,
380 memory::format_tag::oiw, memory::format_tag::undef,
381 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
382 inprod_test_params_float {memory::format_tag::ncw,
383 memory::format_tag::oiw, memory::format_tag::undef,
384 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
385 inprod_test_params_float {memory::format_tag::ncw,
386 memory::format_tag::wio, memory::format_tag::undef,
387 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
388 inprod_test_params_float {memory::format_tag::ncw,
389 memory::format_tag::owi, memory::format_tag::undef,
390 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
391 inprod_test_params_float {memory::format_tag::nhwc,
392 memory::format_tag::hwio, memory::format_tag::undef,
393 memory::format_tag::nc,
394 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
395 inprod_test_params_float {memory::format_tag::nhwc,
396 memory::format_tag::oihw, memory::format_tag::undef,
397 memory::format_tag::nc,
398 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
399 inprod_test_params_float {memory::format_tag::nhwc,
400 memory::format_tag::ohwi, memory::format_tag::undef,
401 memory::format_tag::nc,
402 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
403 inprod_test_params_float {memory::format_tag::nchw,
404 memory::format_tag::oihw, memory::format_tag::undef,
405 memory::format_tag::nc,
406 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
407 inprod_test_params_float {memory::format_tag::nchw,
408 memory::format_tag::ohwi, memory::format_tag::undef,
409 memory::format_tag::nc,
410 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
411 inprod_test_params_float {memory::format_tag::nchw,
412 memory::format_tag::hwio, memory::format_tag::undef,
413 memory::format_tag::nc,
414 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
415 inprod_test_params_float {memory::format_tag::nChw8c,
416 memory::format_tag::aBcd8b, memory::format_tag::undef,
417 memory::format_tag::nc,
418 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
419 inprod_test_params_float {memory::format_tag::nChw16c,
420 memory::format_tag::aBcd16b, memory::format_tag::undef,
421 memory::format_tag::nc,
422 EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
423 inprod_test_params_float {memory::format_tag::any,
424 memory::format_tag::aBcd16b, memory::format_tag::undef,
425 memory::format_tag::nc,
426 EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
427 inprod_test_params_float {memory::format_tag::nChw16c,
428 memory::format_tag::any, memory::format_tag::undef,
429 memory::format_tag::nc,
430 EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
431 inprod_test_params_float {memory::format_tag::nChw16c,
432 memory::format_tag::aBcd16b, memory::format_tag::undef,
433 memory::format_tag::nc,
434 EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
435 inprod_test_params_float {memory::format_tag::nc,
436 memory::format_tag::oi, memory::format_tag::undef,
437 memory::format_tag::nc,
438 EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
439 inprod_test_params_float {memory::format_tag::nc,
440 memory::format_tag::oi, memory::format_tag::undef,
441 memory::format_tag::nc, EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
442 inprod_test_params_float {memory::format_tag::nc,
443 memory::format_tag::io, memory::format_tag::undef,
444 memory::format_tag::nc,
445 EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
446
// With bias: `any` layouts, plain and blocked 2D-spatial layouts with an `x`
// bias, and non-spatial (kh == kw == 1, so src/diff_weights are 2D) cases.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 1024, 2, 2)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::hwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::oihw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::oihw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::x,
                        memory::format_tag::nc, EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::io, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
489
// 5D problems (three spatial dims, ndims == 5) with bias: `any`, plain
// ncdhw/ndhwc with several weight layouts, and blocked 8c/16c layouts.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights3D,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 1024, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::oidhw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::dhwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::odhwi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 1000, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::dhwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::odhwi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::oidhw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)}));
533} // namespace dnnl
534