1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
24 | struct test_inner_product_descr_t { |
25 | memory::dim mb; |
26 | memory::dim ic; |
27 | memory::dim oc; |
28 | memory::dim kd, kh, kw; |
29 | }; |
30 | |
31 | template <typename data_t> |
32 | void compute_ref_inner_product_bwd_bias(const test_inner_product_descr_t &ipd, |
33 | const memory &diff_dst, const memory &diff_bias) { |
34 | auto diff_bias_data = map_memory<data_t>(diff_bias); |
35 | auto diff_dst_data = map_memory<data_t>(diff_dst); |
36 | |
37 | const memory::desc diff_bias_d = diff_bias.get_desc(); |
38 | const memory::desc diff_dst_d = diff_dst.get_desc(); |
39 | const dnnl::impl::memory_desc_wrapper diff_bias_mdw(diff_bias_d.get()); |
40 | const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get()); |
41 | |
42 | dnnl::impl::parallel_nd(ipd.oc, [&](memory::dim oc) { |
43 | data_t *db = &diff_bias_data[diff_bias_mdw.off_l(oc, true)]; |
44 | *db = data_t(0); |
45 | for (memory::dim n = 0; n < ipd.mb; ++n) { |
46 | *db += diff_dst_data[diff_dst_mdw.off_l(n * ipd.oc + oc, true)]; |
47 | } |
48 | }); |
49 | } |
50 | |
51 | template <typename data_t> |
52 | void compute_ref_inner_product_bwd_weights(int ndims, |
53 | const test_inner_product_descr_t &ipd, const memory &src, |
54 | const memory &diff_dst, const memory &diff_weights) { |
55 | auto src_data = map_memory<data_t>(src); |
56 | auto diff_weights_data = map_memory<data_t>(diff_weights); |
57 | auto diff_dst_data = map_memory<data_t>(diff_dst); |
58 | |
59 | const memory::desc src_d = src.get_desc(); |
60 | const memory::desc diff_weights_d = diff_weights.get_desc(); |
61 | const memory::desc diff_dst_d = diff_dst.get_desc(); |
62 | const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get()); |
63 | const dnnl::impl::memory_desc_wrapper diff_weights_mdw( |
64 | diff_weights_d.get()); |
65 | const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get()); |
66 | |
67 | auto padded_ic = src_d.get_padded_dims()[1]; |
68 | |
69 | bool has_spatial = ipd.kh > 1 || ipd.kw > 1; |
70 | if (ndims == 5) has_spatial = has_spatial || ipd.kd > 1; |
71 | dnnl::impl::parallel_nd( |
72 | ipd.oc, ipd.ic, [&](memory::dim oc, memory::dim ic) { |
73 | if (has_spatial) { |
74 | for_(memory::dim kd = 0; kd < ipd.kd; ++kd) |
75 | for_(memory::dim kh = 0; kh < ipd.kh; ++kh) |
76 | for (memory::dim kw = 0; kw < ipd.kw; ++kw) { |
77 | memory::dim dwidx |
78 | = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw |
79 | + ic * ipd.kd * ipd.kh * ipd.kw |
80 | + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw; |
81 | data_t *dw = &diff_weights_data[diff_weights_mdw.off_l( |
82 | dwidx, true)]; |
83 | *dw = data_t(0); |
84 | for (memory::dim n = 0; n < ipd.mb; ++n) { |
85 | memory::dim ddidx = n * ipd.oc + oc; |
86 | memory::dim sidx |
87 | = n * padded_ic * ipd.kd * ipd.kh * ipd.kw |
88 | + ic * ipd.kd * ipd.kh * ipd.kw |
89 | + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw; |
90 | *dw += diff_dst_data[diff_dst_mdw.off_l( |
91 | ddidx, true)] |
92 | * src_data[src_mdw.off_l(sidx, true)]; |
93 | } |
94 | } |
95 | } else { |
96 | memory::dim dwidx = oc * ipd.ic + ic; |
97 | data_t *dw = &diff_weights_data[diff_weights_mdw.off_l( |
98 | dwidx, true)]; |
99 | *dw = data_t(0); |
100 | for (memory::dim n = 0; n < ipd.mb; ++n) { |
101 | memory::dim ddidx = n * ipd.oc + oc; |
102 | memory::dim sidx = n * ipd.ic + ic; |
103 | *dw += diff_dst_data[diff_dst_mdw.off_l(ddidx, true)] |
104 | * src_data[src_mdw.off_l(sidx, true)]; |
105 | } |
106 | } |
107 | }); |
108 | } |
109 | |
/// Parameters of one backward-weights inner product test case.
///
/// Test cases are brace-initialized positionally, so the member order is part
/// of the contract with every INSTANTIATE_TEST_SUITE_P list in this file.
/// Trailing members omitted in an initializer are value-initialized.
struct inprod_test_params_t {
    memory::format_tag src_format;
    memory::format_tag diff_weights_format;
    // memory::format_tag::undef means the test runs without a bias.
    memory::format_tag diff_bias_format;
    memory::format_tag diff_dst_format;
    // Spatial rank selector set by the EXPAND_SIZES_* macros: 3 (1D spatial),
    // 4 (2D spatial) or 5 (3D spatial). src/diff_weights are created 2D when
    // no kernel extent exceeds 1.
    int ndims;
    test_inner_product_descr_t test_ipd;
    // Passed to catch_expected_failures() together with expected_status.
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
120 | |
/// Parameterized gtest fixture for inner_product_backward_weights.
///
/// All of the work happens in SetUp(). It may skip unsupported CUDA cases,
/// then runs Test() through catch_expected_failures(). Test() builds the
/// primitive, executes it and compares the results against
/// compute_ref_inner_product_bwd_weights / compute_ref_inner_product_bwd_bias.
template <typename data_t>
class inner_product_test_bwd_weights_t
    : public ::testing::TestWithParam<inprod_test_params_t> {
protected:
    void SetUp() override {
        auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
        // On CUDA, skip format combinations outside the allow-list below.
        SKIP_IF_CUDA(
                !cuda_check_format_tags(p.src_format, p.diff_weights_format,
                        p.diff_bias_format, p.diff_dst_format),
                "Unsupported format tag");
        SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions");
        catch_expected_failures(
                [=]() { Test(); }, p.expect_to_fail, p.expected_status);
    }

    /// Returns true when every format tag is in the CUDA allow-list: plain
    /// activation/weights layouts, `any`, and undef/a/x for the bias.
    bool cuda_check_format_tags(memory::format_tag src_format,
            memory::format_tag diff_wei_format,
            memory::format_tag diff_bia_format,
            memory::format_tag diff_dst_format) {
        bool src_ok = src_format == memory::format_tag::ncdhw
                || src_format == memory::format_tag::ndhwc
                || src_format == memory::format_tag::nchw
                || src_format == memory::format_tag::nhwc
                || src_format == memory::format_tag::ncw
                || src_format == memory::format_tag::nwc
                || src_format == memory::format_tag::nc
                || src_format == memory::format_tag::any;
        bool diff_wei_ok = diff_wei_format == memory::format_tag::oidhw
                || diff_wei_format == memory::format_tag::odhwi
                || diff_wei_format == memory::format_tag::dhwio
                || diff_wei_format == memory::format_tag::oihw
                || diff_wei_format == memory::format_tag::ohwi
                || diff_wei_format == memory::format_tag::hwio
                || diff_wei_format == memory::format_tag::oiw
                || diff_wei_format == memory::format_tag::owi
                || diff_wei_format == memory::format_tag::wio
                || diff_wei_format == memory::format_tag::io
                || diff_wei_format == memory::format_tag::oi
                || diff_wei_format == memory::format_tag::any;
        // undef is accepted because it encodes "no bias" for this fixture.
        bool diff_bia_ok = diff_bia_format == memory::format_tag::undef
                || diff_bia_format == memory::format_tag::any
                || diff_bia_format == memory::format_tag::a
                || diff_bia_format == memory::format_tag::x;
        bool diff_dst_ok = diff_dst_format == memory::format_tag::any
                || diff_dst_format == memory::format_tag::nc;

        return src_ok && diff_wei_ok && diff_bia_ok && diff_dst_ok;
    }

    /// Creates, executes and validates one backward-weights primitive for
    /// the current test parameters.
    void Test() {
        auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
        test_inner_product_descr_t ipd = p.test_ipd;

        // Spatial dims are appended only if some kernel extent exceeds 1;
        // kd is considered only for 5D (3D spatial) cases.
        bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
        if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1;

        bool with_bias = p.diff_bias_format != memory::format_tag::undef;

        auto eng = get_test_engine();
        auto strm = make_stream(eng);
        memory::data_type data_type = data_traits<data_t>::data_type;
        // This fixture only supports f32 data.
        ASSERT_EQ(data_type, dnnl::memory::data_type::f32);

        // src is {mb, ic[, kd][, kh][, kw]}; diff_weights is
        // {oc, ic[, kd][, kh][, kw]}.
        memory::dims src_dims = {ipd.mb, ipd.ic},
                     diff_wei_dims = {ipd.oc, ipd.ic};
        if (has_spatial) {
            if (p.ndims == 5) {
                src_dims.push_back(ipd.kd);
                diff_wei_dims.push_back(ipd.kd);
            }
            if (p.ndims >= 4) {
                src_dims.push_back(ipd.kh);
                diff_wei_dims.push_back(ipd.kh);
            }
            if (p.ndims >= 3) {
                src_dims.push_back(ipd.kw);
                diff_wei_dims.push_back(ipd.kw);
            }
        }
        auto ip_src_desc = create_md(src_dims, data_type, p.src_format);
        auto ip_diff_weights_desc
                = create_md(diff_wei_dims, data_type, p.diff_weights_format);
        auto ip_diff_dst_desc
                = create_md({ipd.mb, ipd.oc}, data_type, p.diff_dst_format);
        // Without bias an empty-dims descriptor is built as a placeholder.
        auto ip_diff_bias_desc = with_bias
                ? create_md({ipd.oc}, data_type, p.diff_bias_format)
                : create_md({}, data_type, p.diff_bias_format);

        // Create inner product forward (hint for backward)
        auto ip_fwd_pdesc
                = inner_product_forward::primitive_desc(eng, prop_kind::forward,
                        ip_src_desc, ip_diff_weights_desc, ip_diff_dst_desc);

        // Create inner product backward
        auto ip_primitive_desc = with_bias
                ? inner_product_backward_weights::primitive_desc(eng,
                        ip_src_desc, ip_diff_weights_desc, ip_diff_bias_desc,
                        ip_diff_dst_desc, ip_fwd_pdesc)
                : inner_product_backward_weights::primitive_desc(eng,
                        ip_src_desc, ip_diff_weights_desc, ip_diff_dst_desc,
                        ip_fwd_pdesc);

        ip_primitive_desc = inner_product_backward_weights::primitive_desc(
                ip_primitive_desc.get()); // test construction from a C pd

        // All buffers, including the references, use the layouts chosen by
        // the primitive descriptor.
        auto ip_src = test::make_memory(ip_primitive_desc.src_desc(), eng);
        auto ip_diff_dst
                = test::make_memory(ip_primitive_desc.diff_dst_desc(), eng);
        auto ip_diff_weights
                = test::make_memory(ip_primitive_desc.diff_weights_desc(), eng);
        auto diff_weights_ref
                = test::make_memory(ip_primitive_desc.diff_weights_desc(), eng);
        auto ip_diff_bias
                = test::make_memory(ip_primitive_desc.diff_bias_desc(), eng);
        auto diff_bias_ref
                = test::make_memory(ip_primitive_desc.diff_bias_desc(), eng);

        fill_data<data_t>(
                ip_src.get_desc().get_size() / sizeof(data_t), ip_src);
        fill_data<data_t>(ip_diff_dst.get_desc().get_size() / sizeof(data_t),
                ip_diff_dst);

        // NOTE(review): check_zero_tail(1, m) appears to zero the padded tail
        // of m, and check_zero_tail(0, m) appears to verify it is zero. This
        // is inferred from usage here; confirm against dnnl_test_common.hpp.
        check_zero_tail<data_t>(1, ip_src);
        check_zero_tail<data_t>(1, ip_diff_dst);

        // The generic exec-arg query must agree with the dedicated md queries.
        ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
                == ip_primitive_desc.src_desc());
        ASSERT_TRUE(ip_primitive_desc.query_md(
                            query::exec_arg_md, DNNL_ARG_DIFF_DST)
                == ip_primitive_desc.diff_dst_desc());
        ASSERT_TRUE(ip_primitive_desc.query_md(
                            query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS)
                == ip_primitive_desc.diff_weights_desc());
        ASSERT_TRUE(ip_primitive_desc.query_md(
                            query::exec_arg_md, DNNL_ARG_DIFF_BIAS)
                == ip_primitive_desc.diff_bias_desc());

        ASSERT_EQ(
                ip_primitive_desc.get_prop_kind(), prop_kind::backward_weights);

        // Constructing the primitive with an empty `{}` second argument
        // (presumably a cache blob) is expected to throw.
        EXPECT_ANY_THROW(inner_product_backward_weights(ip_primitive_desc, {}));
        // DIFF_BIAS is always passed. Without bias, its memory was created
        // from the pd's diff_bias_desc() (presumably an empty descriptor).
        inner_product_backward_weights(ip_primitive_desc)
                .execute(strm,
                        {{DNNL_ARG_DIFF_DST, ip_diff_dst},
                                {DNNL_ARG_SRC, ip_src},
                                {DNNL_ARG_DIFF_WEIGHTS, ip_diff_weights},
                                {DNNL_ARG_DIFF_BIAS, ip_diff_bias}});
        strm.wait();

        compute_ref_inner_product_bwd_weights<data_t>(
                p.ndims, ipd, ip_src, ip_diff_dst, diff_weights_ref);
        check_zero_tail<data_t>(1, diff_weights_ref);

        compare_data<data_t>(diff_weights_ref, ip_diff_weights);

        check_zero_tail<data_t>(0, ip_diff_weights);

        if (with_bias) {
            compute_ref_inner_product_bwd_bias<data_t>(
                    ipd, ip_diff_dst, diff_bias_ref);
            compare_data<data_t>(diff_bias_ref, ip_diff_bias);
        }
    }
};
285 | |
286 | using inner_product_test_float = inner_product_test_bwd_weights_t<float>; |
287 | using inprod_test_params_float = inprod_test_params_t; |
288 | |
// Helpers that expand to the `ndims, test_ipd` members of an
// inprod_test_params_t initializer. Missing spatial extents are set to 1:
//   EXPAND_SIZES_3D(mb, ic, oc, kd, kh, kw) -> ndims 5
//   EXPAND_SIZES_2D(mb, ic, oc, kh, kw)     -> ndims 4, kd = 1
//   EXPAND_SIZES_1D(mb, ic, oc, kw)         -> ndims 3, kd = kh = 1
// The 3D variant takes named parameters, like the 1D/2D ones. Passing the
// wrong number of arguments is then a preprocessor error, instead of silently
// value-initializing trailing extents to 0.
#define EXPAND_SIZES_3D(mb, ic, oc, kd, kh, kw) \
    5, { mb, ic, oc, kd, kh, kw }
#define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) \
    4, { mb, ic, oc, 1, kh, kw }
#define EXPAND_SIZES_1D(mb, ic, oc, kw) \
    3, { mb, ic, oc, 1, 1, kw }
295 | |
// The test body is intentionally empty: the fixture's SetUp() runs Test()
// (through catch_expected_failures) for every parameter set.
TEST_P(inner_product_test_float, TestsInnerProduct) {}
297 | |
// Zero-sized minibatch (mb = 0). expect_to_fail is not set, so creation and
// execution are expected to succeed.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsZeroDim,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::any,
                memory::format_tag::any, memory::format_tag::any,
                memory::format_tag::any, EXPAND_SIZES_2D(0, 32, 48, 6, 6)}));
303 | |
// Expected-failure cases: zero or negative sizes are expected to fail with
// dnnl_invalid_arguments.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsEF,
        inner_product_test_float,
        ::testing::Values(
                // ic = 0
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 0, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                // mb = -1
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                // ic = -1
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, -1, 48, 6, 6), true,
                        dnnl_invalid_arguments}));
322 | |
// Blocked src/diff_weights layouts where ic (17, 10, 5) is not a multiple of
// the channel block (16 or 8), so the channel dimension carries padding.
// No bias (diff_bias_format == undef).
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsNoBias_padded,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 10, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 15, 3, 3)}));
342 | |
// Same padded-channel shapes and layouts as the NoBias_padded suite, but with
// a bias (format x). NOTE(review): GPU_INSTANTIATE_TEST_SUITE_P presumably
// registers these only for GPU engines; confirm against its definition.
GPU_INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights_padded,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::nChw16c,
                                  memory::format_tag::aBcd16b,
                                  memory::format_tag::x, memory::format_tag::nc,
                                  EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 10, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 15, 3, 3)}));
361 | |
// Backward weights without bias (diff_bias_format == undef). Covers:
// `any` layouts, 1D-spatial (ncw/nwc), 2D-spatial plain (nchw/nhwc) and
// blocked (nChw8c/nChw16c) layouts, and plain 2D tensors (nc with oi/io).
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsNoBias,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::undef,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::undef,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 1024, 48, 2, 2)},
                // 1D spatial
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::owi, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::wio, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::oiw, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::oiw, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::wio, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::owi, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                // 2D spatial, plain layouts
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::hwio, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::oihw, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::ohwi, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::oihw, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::ohwi, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::hwio, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                // 2D spatial, blocked layouts (optionally mixed with `any`)
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::any, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                // NOTE(review): identical to the nChw16c/aBcd16b entry three
                // cases above; kept so existing test indices do not shift.
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                // No spatial extent (kh = kw = 1): plain 2D src/weights
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::io, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
446 | |
// Backward weights with bias (`any` or `x`), across `any`, plain and blocked
// 2D-spatial layouts, and plain 2D tensors (nc with oi/io).
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 1024, 2, 2)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::hwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::oihw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::oihw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                // No spatial extent (kh = kw = 1): plain 2D src/weights
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::x,
                        memory::format_tag::nc, EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::io, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
489 | |
// Backward weights with bias for 3D-spatial (ndims = 5) problems: `any`,
// plain ncdhw/ndhwc and blocked nCdhw8c/nCdhw16c source layouts.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights3D,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 1024, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::oidhw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::dhwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::odhwi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 1000, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::dhwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::odhwi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::oidhw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)}));
533 | } // namespace dnnl |
534 | |