1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
24 | struct test_inner_product_descr_t { |
25 | memory::dim mb; |
26 | memory::dim ic; |
27 | memory::dim oc; |
28 | memory::dim kd, kh, kw; |
29 | }; |
30 | |
31 | template <typename data_t> |
32 | void compute_ref_inner_product_bwd_data(int ndims, |
33 | const test_inner_product_descr_t &ipd, const memory &diff_dst, |
34 | const memory &weights, const memory &diff_src) { |
35 | auto diff_dst_data = map_memory<data_t>(diff_dst); |
36 | auto weights_data = map_memory<data_t>(weights); |
37 | auto diff_src_data = map_memory<data_t>(diff_src); |
38 | |
39 | const memory::desc diff_dst_d = diff_dst.get_desc(); |
40 | const memory::desc weights_d = weights.get_desc(); |
41 | const memory::desc diff_src_d = diff_src.get_desc(); |
42 | const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get()); |
43 | const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get()); |
44 | const dnnl::impl::memory_desc_wrapper diff_src_mdw(diff_src_d.get()); |
45 | |
46 | bool has_spatial = ipd.kh > 1 || ipd.kw > 1; |
47 | if (ndims == 5) has_spatial = has_spatial || ipd.kd > 1; |
48 | auto padded_ic = diff_src_d.get_padded_dims()[1]; |
49 | |
50 | dnnl::impl::parallel_nd(ipd.mb, ipd.ic, [&](memory::dim n, memory::dim ic) { |
51 | if (has_spatial) { |
52 | for_(memory::dim kd = 0; kd < ipd.kd; ++kd) |
53 | for_(memory::dim kh = 0; kh < ipd.kh; ++kh) |
54 | for (memory::dim kw = 0; kw < ipd.kw; ++kw) { |
55 | memory::dim dsidx = n * padded_ic * ipd.kd * ipd.kh * ipd.kw |
56 | + ic * ipd.kd * ipd.kh * ipd.kw + kd * ipd.kh * ipd.kw |
57 | + kh * ipd.kw + kw; |
58 | data_t *ds = &diff_src_data[diff_src_mdw.off_l(dsidx, true)]; |
59 | *ds = data_t(0); |
60 | for (memory::dim oc = 0; oc < ipd.oc; ++oc) { |
61 | memory::dim ddidx = n * ipd.oc + oc; |
62 | memory::dim widx = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw |
63 | + ic * ipd.kd * ipd.kh * ipd.kw |
64 | + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw; |
65 | *ds += diff_dst_data[diff_dst_mdw.off_l(ddidx, true)] |
66 | * weights_data[weights_mdw.off_l(widx, true)]; |
67 | } |
68 | } |
69 | } else { |
70 | memory::dim dsidx = n * ipd.ic + ic; |
71 | data_t *ds = &diff_src_data[diff_src_mdw.off_l(dsidx, true)]; |
72 | *ds = data_t(0); |
73 | for (memory::dim oc = 0; oc < ipd.oc; ++oc) { |
74 | memory::dim ddidx = n * ipd.oc + oc; |
75 | memory::dim widx = oc * ipd.ic + ic; |
76 | *ds += diff_dst_data[diff_dst_mdw.off_l(ddidx, true)] |
77 | * weights_data[weights_mdw.off_l(widx, true)]; |
78 | } |
79 | } |
80 | }); |
81 | } |
82 | |
// One parameter set of the inner product backward-data test.
//
// Instances are brace-initialized positionally (see the EXPAND_SIZES_* macros,
// which provide `ndims` and `test_ipd`), so the member order must not change.
// Trailing members that are omitted from an initializer are value-initialized.
struct inprod_test_params_t {
    // Format tags used to create the diff_src, weights and diff_dst memory
    // descriptors respectively.
    memory::format_tag diff_src_format;
    memory::format_tag weights_format;
    memory::format_tag diff_dst_format;
    // Number of dimensions of diff_src/weights when the case has spatial
    // extents (3, 4 or 5 as produced by the EXPAND_SIZES_* macros).
    int ndims;
    // Logical problem sizes.
    test_inner_product_descr_t test_ipd;
    // Forwarded to catch_expected_failures() together with the test body.
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
92 | |
93 | template <typename data_t> |
94 | class inner_product_test_bwd_data_t |
95 | : public ::testing::TestWithParam<inprod_test_params_t> { |
96 | protected: |
97 | void SetUp() override { |
98 | auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam(); |
99 | SKIP_IF_CUDA(!cuda_check_format_tags(p.diff_src_format, |
100 | p.weights_format, p.diff_dst_format), |
101 | "Unsupported format tag" ); |
102 | SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions" ); |
103 | catch_expected_failures( |
104 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
105 | } |
106 | |
107 | bool cuda_check_format_tags(memory::format_tag diff_src_format, |
108 | memory::format_tag wei_format, memory::format_tag diff_dst_format) { |
109 | bool diff_src_ok = diff_src_format == memory::format_tag::ncdhw |
110 | || diff_src_format == memory::format_tag::ndhwc |
111 | || diff_src_format == memory::format_tag::nchw |
112 | || diff_src_format == memory::format_tag::nhwc |
113 | || diff_src_format == memory::format_tag::ncw |
114 | || diff_src_format == memory::format_tag::nwc |
115 | || diff_src_format == memory::format_tag::nc |
116 | || diff_src_format == memory::format_tag::any; |
117 | bool wei_ok = wei_format == memory::format_tag::oidhw |
118 | || wei_format == memory::format_tag::odhwi |
119 | || wei_format == memory::format_tag::dhwio |
120 | || wei_format == memory::format_tag::oihw |
121 | || wei_format == memory::format_tag::hwio |
122 | || wei_format == memory::format_tag::ohwi |
123 | || wei_format == memory::format_tag::oiw |
124 | || wei_format == memory::format_tag::owi |
125 | || wei_format == memory::format_tag::wio |
126 | || wei_format == memory::format_tag::io |
127 | || wei_format == memory::format_tag::oi |
128 | || wei_format == memory::format_tag::any; |
129 | bool diff_dst_ok = diff_dst_format == memory::format_tag::any |
130 | || diff_dst_format == memory::format_tag::nc; |
131 | |
132 | return diff_src_ok && wei_ok && diff_dst_ok; |
133 | } |
134 | |
135 | void Test() { |
136 | auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam(); |
137 | test_inner_product_descr_t ipd = p.test_ipd; |
138 | bool has_spatial = ipd.kh > 1 || ipd.kw > 1; |
139 | if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1; |
140 | |
141 | auto eng = get_test_engine(); |
142 | auto strm = make_stream(eng); |
143 | memory::data_type data_type = data_traits<data_t>::data_type; |
144 | ASSERT_EQ(data_type, dnnl::memory::data_type::f32); |
145 | |
146 | memory::dims diff_src_dims = {ipd.mb, ipd.ic}, |
147 | wei_dims = {ipd.oc, ipd.ic}; |
148 | if (has_spatial) { |
149 | if (p.ndims == 5) { |
150 | diff_src_dims.push_back(ipd.kd); |
151 | wei_dims.push_back(ipd.kd); |
152 | } |
153 | if (p.ndims >= 4) { |
154 | diff_src_dims.push_back(ipd.kh); |
155 | wei_dims.push_back(ipd.kh); |
156 | } |
157 | if (p.ndims >= 3) { |
158 | diff_src_dims.push_back(ipd.kw); |
159 | wei_dims.push_back(ipd.kw); |
160 | } |
161 | } |
162 | auto ip_diff_src_desc |
163 | = create_md(diff_src_dims, data_type, p.diff_src_format); |
164 | auto ip_weights_desc = create_md(wei_dims, data_type, p.weights_format); |
165 | auto ip_diff_dst_desc |
166 | = create_md({ipd.mb, ipd.oc}, data_type, p.diff_dst_format); |
167 | |
168 | // Create inner product forward (hint for backward) |
169 | auto ip_fwd_pdesc |
170 | = inner_product_forward::primitive_desc(eng, prop_kind::forward, |
171 | ip_diff_src_desc, ip_weights_desc, ip_diff_dst_desc); |
172 | |
173 | // Create inner product backward |
174 | auto ip_primitive_desc = inner_product_backward_data::primitive_desc( |
175 | eng, ip_diff_src_desc, ip_weights_desc, ip_diff_dst_desc, |
176 | ip_fwd_pdesc); |
177 | |
178 | ip_primitive_desc = inner_product_backward_data::primitive_desc( |
179 | ip_primitive_desc.get()); // test construction from a C pd |
180 | |
181 | auto ip_diff_src |
182 | = test::make_memory(ip_primitive_desc.diff_src_desc(), eng); |
183 | auto ip_weights |
184 | = test::make_memory(ip_primitive_desc.weights_desc(), eng); |
185 | auto ip_diff_dst |
186 | = test::make_memory(ip_primitive_desc.diff_dst_desc(), eng); |
187 | auto diff_src_ref |
188 | = test::make_memory(ip_primitive_desc.diff_src_desc(), eng); |
189 | |
190 | fill_data<data_t>(ip_diff_dst.get_desc().get_size() / sizeof(data_t), |
191 | ip_diff_dst); |
192 | fill_data<data_t>( |
193 | ip_weights.get_desc().get_size() / sizeof(data_t), ip_weights); |
194 | |
195 | check_zero_tail<data_t>(1, ip_diff_dst); |
196 | check_zero_tail<data_t>(1, ip_weights); |
197 | |
198 | ASSERT_TRUE(ip_primitive_desc.query_md( |
199 | query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
200 | == ip_primitive_desc.diff_src_desc()); |
201 | ASSERT_TRUE(ip_primitive_desc.query_md( |
202 | query::exec_arg_md, DNNL_ARG_DIFF_DST) |
203 | == ip_primitive_desc.diff_dst_desc()); |
204 | ASSERT_TRUE( |
205 | ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS) |
206 | == ip_primitive_desc.weights_desc()); |
207 | |
208 | ASSERT_EQ(ip_primitive_desc.get_prop_kind(), prop_kind::backward_data); |
209 | |
210 | EXPECT_ANY_THROW(inner_product_backward_data(ip_primitive_desc, {})); |
211 | inner_product_backward_data(ip_primitive_desc) |
212 | .execute(strm, |
213 | {{DNNL_ARG_DIFF_DST, ip_diff_dst}, |
214 | {DNNL_ARG_WEIGHTS, ip_weights}, |
215 | {DNNL_ARG_DIFF_SRC, ip_diff_src}}); |
216 | strm.wait(); |
217 | |
218 | compute_ref_inner_product_bwd_data<data_t>( |
219 | p.ndims == 5, ipd, ip_diff_dst, ip_weights, diff_src_ref); |
220 | check_zero_tail<data_t>(1, diff_src_ref); |
221 | compare_data<data_t>(diff_src_ref, ip_diff_src); |
222 | check_zero_tail<data_t>(0, ip_diff_src); |
223 | } |
224 | }; |
225 | |
// f32 instantiation of the fixture and a matching alias for its parameters.
using inner_product_test_float = inner_product_test_bwd_data_t<float>;
using inprod_test_params_float = inprod_test_params_t;

// Helpers that expand to the `ndims, {mb, ic, oc, kd, kh, kw}` part of an
// inprod_test_params_t initializer. The 2D and 1D variants fix the missing
// leading spatial extents to 1.
#define EXPAND_SIZES_3D(...) \
    5, { __VA_ARGS__ }
#define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) \
    4, { mb, ic, oc, 1, kh, kw }
#define EXPAND_SIZES_1D(mb, ic, oc, kw) \
    3, { mb, ic, oc, 1, 1, kw }

// The body is intentionally empty: the fixture's SetUp() runs the whole test.
TEST_P(inner_product_test_float, TestsInnerProduct) {}
237 | |
// Zero-sized problem (mb == 0). expect_to_fail is left value-initialized
// (false), so the case is expected to run without an error.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardZeroDim,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::any,
                memory::format_tag::any, memory::format_tag::any,
                EXPAND_SIZES_2D(0, 32, 48, 6, 6)}));
243 | |
// Expected-failure cases: ic == 0, negative mb and negative ic. Each one is
// flagged expect_to_fail with status dnnl_invalid_arguments.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardDataEF,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 0, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, -1, 48, 6, 6), true,
                        dnnl_invalid_arguments}));
259 | |
// 5D cases with nCdhw8c / aBcde8b layouts. None of the ic values (9, 17, 29)
// is a multiple of 8 -- presumably to exercise padded channels.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData_nCdhw8c,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 9, 4, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 17, 16, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 29, 7, 2, 2, 2)}));
272 | |
// 5D cases with nCdhw16c / aBcde16b layouts. None of the ic values (9, 17, 29)
// is a multiple of 16 -- presumably to exercise padded channels.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData_nCdhw16c,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 9, 4, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 17, 16, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 29, 7, 2, 2, 2)}));
285 | |
// 4D cases with nChw16c / nChw8c layouts whose ic is not a multiple of the
// block size named in the format tag (16 or 8).
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData_padded,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 9, 4, 2, 2)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 16, 2, 2)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 29, 7, 2, 2)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 4, 2, 2)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 14, 16, 2, 2)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 33, 7, 2, 2)}));
307 | |
// Main set of 1D-spatial, 2D-spatial and non-spatial (kh == kw == 1) cases
// over plain, blocked and `any` format tags.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData, inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 1024, 48, 2, 2)},
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::oiw, memory::format_tag::nc,
                        EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::wio, memory::format_tag::nc,
                        EXPAND_SIZES_1D(2, 32, 48, 5)},
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::owi, memory::format_tag::nc,
                        EXPAND_SIZES_1D(2, 32, 48, 5)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::oiw, memory::format_tag::nc,
                        EXPAND_SIZES_1D(2, 32, 48, 5)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::wio, memory::format_tag::nc,
                        EXPAND_SIZES_1D(2, 32, 48, 5)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::owi, memory::format_tag::nc,
                        EXPAND_SIZES_1D(2, 32, 48, 5)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::hwio, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::oihw, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::iohw, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::oihw, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::hwio, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::ohwi, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::aBcd8b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::any, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                // NOTE(review): identical to the nChw8c / aBcd8b / nc case
                // three entries above; kept so gtest case indices stay stable.
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                // Non-spatial cases: kh == kw == 1, so the tensors stay 2D.
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::io, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
376 | |
// 5D (3D-spatial) cases over `any` and plain ncdhw / ndhwc source layouts
// combined with oidhw / odhwi / dhwio weights.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData3D,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 1024, 48, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::oidhw, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::odhwi, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::dhwio, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::oidhw, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::odhwi, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::dhwio, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)}));
404 | |
405 | } // namespace dnnl |
406 | |