1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24struct test_inner_product_descr_t {
25 memory::dim mb;
26 memory::dim ic;
27 memory::dim oc;
28 memory::dim kd, kh, kw;
29};
30
31template <typename data_t>
32void compute_ref_inner_product_bwd_data(int ndims,
33 const test_inner_product_descr_t &ipd, const memory &diff_dst,
34 const memory &weights, const memory &diff_src) {
35 auto diff_dst_data = map_memory<data_t>(diff_dst);
36 auto weights_data = map_memory<data_t>(weights);
37 auto diff_src_data = map_memory<data_t>(diff_src);
38
39 const memory::desc diff_dst_d = diff_dst.get_desc();
40 const memory::desc weights_d = weights.get_desc();
41 const memory::desc diff_src_d = diff_src.get_desc();
42 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get());
43 const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get());
44 const dnnl::impl::memory_desc_wrapper diff_src_mdw(diff_src_d.get());
45
46 bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
47 if (ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
48 auto padded_ic = diff_src_d.get_padded_dims()[1];
49
50 dnnl::impl::parallel_nd(ipd.mb, ipd.ic, [&](memory::dim n, memory::dim ic) {
51 if (has_spatial) {
52 for_(memory::dim kd = 0; kd < ipd.kd; ++kd)
53 for_(memory::dim kh = 0; kh < ipd.kh; ++kh)
54 for (memory::dim kw = 0; kw < ipd.kw; ++kw) {
55 memory::dim dsidx = n * padded_ic * ipd.kd * ipd.kh * ipd.kw
56 + ic * ipd.kd * ipd.kh * ipd.kw + kd * ipd.kh * ipd.kw
57 + kh * ipd.kw + kw;
58 data_t *ds = &diff_src_data[diff_src_mdw.off_l(dsidx, true)];
59 *ds = data_t(0);
60 for (memory::dim oc = 0; oc < ipd.oc; ++oc) {
61 memory::dim ddidx = n * ipd.oc + oc;
62 memory::dim widx = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw
63 + ic * ipd.kd * ipd.kh * ipd.kw
64 + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw;
65 *ds += diff_dst_data[diff_dst_mdw.off_l(ddidx, true)]
66 * weights_data[weights_mdw.off_l(widx, true)];
67 }
68 }
69 } else {
70 memory::dim dsidx = n * ipd.ic + ic;
71 data_t *ds = &diff_src_data[diff_src_mdw.off_l(dsidx, true)];
72 *ds = data_t(0);
73 for (memory::dim oc = 0; oc < ipd.oc; ++oc) {
74 memory::dim ddidx = n * ipd.oc + oc;
75 memory::dim widx = oc * ipd.ic + ic;
76 *ds += diff_dst_data[diff_dst_mdw.off_l(ddidx, true)]
77 * weights_data[weights_mdw.off_l(widx, true)];
78 }
79 }
80 });
81}
82
83struct inprod_test_params_t {
84 memory::format_tag diff_src_format;
85 memory::format_tag weights_format;
86 memory::format_tag diff_dst_format;
87 int ndims;
88 test_inner_product_descr_t test_ipd;
89 bool expect_to_fail;
90 dnnl_status_t expected_status;
91};
92
93template <typename data_t>
94class inner_product_test_bwd_data_t
95 : public ::testing::TestWithParam<inprod_test_params_t> {
96protected:
97 void SetUp() override {
98 auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
99 SKIP_IF_CUDA(!cuda_check_format_tags(p.diff_src_format,
100 p.weights_format, p.diff_dst_format),
101 "Unsupported format tag");
102 SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions");
103 catch_expected_failures(
104 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
105 }
106
107 bool cuda_check_format_tags(memory::format_tag diff_src_format,
108 memory::format_tag wei_format, memory::format_tag diff_dst_format) {
109 bool diff_src_ok = diff_src_format == memory::format_tag::ncdhw
110 || diff_src_format == memory::format_tag::ndhwc
111 || diff_src_format == memory::format_tag::nchw
112 || diff_src_format == memory::format_tag::nhwc
113 || diff_src_format == memory::format_tag::ncw
114 || diff_src_format == memory::format_tag::nwc
115 || diff_src_format == memory::format_tag::nc
116 || diff_src_format == memory::format_tag::any;
117 bool wei_ok = wei_format == memory::format_tag::oidhw
118 || wei_format == memory::format_tag::odhwi
119 || wei_format == memory::format_tag::dhwio
120 || wei_format == memory::format_tag::oihw
121 || wei_format == memory::format_tag::hwio
122 || wei_format == memory::format_tag::ohwi
123 || wei_format == memory::format_tag::oiw
124 || wei_format == memory::format_tag::owi
125 || wei_format == memory::format_tag::wio
126 || wei_format == memory::format_tag::io
127 || wei_format == memory::format_tag::oi
128 || wei_format == memory::format_tag::any;
129 bool diff_dst_ok = diff_dst_format == memory::format_tag::any
130 || diff_dst_format == memory::format_tag::nc;
131
132 return diff_src_ok && wei_ok && diff_dst_ok;
133 }
134
135 void Test() {
136 auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
137 test_inner_product_descr_t ipd = p.test_ipd;
138 bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
139 if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
140
141 auto eng = get_test_engine();
142 auto strm = make_stream(eng);
143 memory::data_type data_type = data_traits<data_t>::data_type;
144 ASSERT_EQ(data_type, dnnl::memory::data_type::f32);
145
146 memory::dims diff_src_dims = {ipd.mb, ipd.ic},
147 wei_dims = {ipd.oc, ipd.ic};
148 if (has_spatial) {
149 if (p.ndims == 5) {
150 diff_src_dims.push_back(ipd.kd);
151 wei_dims.push_back(ipd.kd);
152 }
153 if (p.ndims >= 4) {
154 diff_src_dims.push_back(ipd.kh);
155 wei_dims.push_back(ipd.kh);
156 }
157 if (p.ndims >= 3) {
158 diff_src_dims.push_back(ipd.kw);
159 wei_dims.push_back(ipd.kw);
160 }
161 }
162 auto ip_diff_src_desc
163 = create_md(diff_src_dims, data_type, p.diff_src_format);
164 auto ip_weights_desc = create_md(wei_dims, data_type, p.weights_format);
165 auto ip_diff_dst_desc
166 = create_md({ipd.mb, ipd.oc}, data_type, p.diff_dst_format);
167
168 // Create inner product forward (hint for backward)
169 auto ip_fwd_pdesc
170 = inner_product_forward::primitive_desc(eng, prop_kind::forward,
171 ip_diff_src_desc, ip_weights_desc, ip_diff_dst_desc);
172
173 // Create inner product backward
174 auto ip_primitive_desc = inner_product_backward_data::primitive_desc(
175 eng, ip_diff_src_desc, ip_weights_desc, ip_diff_dst_desc,
176 ip_fwd_pdesc);
177
178 ip_primitive_desc = inner_product_backward_data::primitive_desc(
179 ip_primitive_desc.get()); // test construction from a C pd
180
181 auto ip_diff_src
182 = test::make_memory(ip_primitive_desc.diff_src_desc(), eng);
183 auto ip_weights
184 = test::make_memory(ip_primitive_desc.weights_desc(), eng);
185 auto ip_diff_dst
186 = test::make_memory(ip_primitive_desc.diff_dst_desc(), eng);
187 auto diff_src_ref
188 = test::make_memory(ip_primitive_desc.diff_src_desc(), eng);
189
190 fill_data<data_t>(ip_diff_dst.get_desc().get_size() / sizeof(data_t),
191 ip_diff_dst);
192 fill_data<data_t>(
193 ip_weights.get_desc().get_size() / sizeof(data_t), ip_weights);
194
195 check_zero_tail<data_t>(1, ip_diff_dst);
196 check_zero_tail<data_t>(1, ip_weights);
197
198 ASSERT_TRUE(ip_primitive_desc.query_md(
199 query::exec_arg_md, DNNL_ARG_DIFF_SRC)
200 == ip_primitive_desc.diff_src_desc());
201 ASSERT_TRUE(ip_primitive_desc.query_md(
202 query::exec_arg_md, DNNL_ARG_DIFF_DST)
203 == ip_primitive_desc.diff_dst_desc());
204 ASSERT_TRUE(
205 ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS)
206 == ip_primitive_desc.weights_desc());
207
208 ASSERT_EQ(ip_primitive_desc.get_prop_kind(), prop_kind::backward_data);
209
210 EXPECT_ANY_THROW(inner_product_backward_data(ip_primitive_desc, {}));
211 inner_product_backward_data(ip_primitive_desc)
212 .execute(strm,
213 {{DNNL_ARG_DIFF_DST, ip_diff_dst},
214 {DNNL_ARG_WEIGHTS, ip_weights},
215 {DNNL_ARG_DIFF_SRC, ip_diff_src}});
216 strm.wait();
217
218 compute_ref_inner_product_bwd_data<data_t>(
219 p.ndims == 5, ipd, ip_diff_dst, ip_weights, diff_src_ref);
220 check_zero_tail<data_t>(1, diff_src_ref);
221 compare_data<data_t>(diff_src_ref, ip_diff_src);
222 check_zero_tail<data_t>(0, ip_diff_src);
223 }
224};
225
226using inner_product_test_float = inner_product_test_bwd_data_t<float>;
227using inprod_test_params_float = inprod_test_params_t;
228
// Helpers that expand to the `ndims, {mb, ic, oc, kd, kh, kw}` pair of an
// inprod_test_params_t initializer. Spatial sizes that do not exist for the
// given rank are filled with 1.
//   3D: EXPAND_SIZES_3D(mb, ic, oc, kd, kh, kw) -> 5, {mb, ic, oc, kd, kh, kw}
//   2D: EXPAND_SIZES_2D(mb, ic, oc, kh, kw)     -> 4, {mb, ic, oc, 1, kh, kw}
//   1D: EXPAND_SIZES_1D(mb, ic, oc, kw)         -> 3, {mb, ic, oc, 1, 1, kw}
#define EXPAND_SIZES_3D(...) 5, { __VA_ARGS__ }
#define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) 4, { mb, ic, oc, 1, kh, kw }
#define EXPAND_SIZES_1D(mb, ic, oc, kw) 3, { mb, ic, oc, 1, 1, kw }
235
236TEST_P(inner_product_test_float, TestsInnerProduct) {}
237
238INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardZeroDim,
239 inner_product_test_float,
240 ::testing::Values(inprod_test_params_float {memory::format_tag::any,
241 memory::format_tag::any, memory::format_tag::any,
242 EXPAND_SIZES_2D(0, 32, 48, 6, 6)}));
243
244INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardDataEF,
245 inner_product_test_float,
246 ::testing::Values(
247 inprod_test_params_float {memory::format_tag::any,
248 memory::format_tag::any, memory::format_tag::any,
249 EXPAND_SIZES_2D(2, 0, 48, 6, 6), true,
250 dnnl_invalid_arguments},
251 inprod_test_params_float {memory::format_tag::any,
252 memory::format_tag::any, memory::format_tag::any,
253 EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true,
254 dnnl_invalid_arguments},
255 inprod_test_params_float {memory::format_tag::any,
256 memory::format_tag::any, memory::format_tag::any,
257 EXPAND_SIZES_2D(2, -1, 48, 6, 6), true,
258 dnnl_invalid_arguments}));
259
260INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData_nCdhw8c,
261 inner_product_test_float,
262 ::testing::Values(
263 inprod_test_params_float {memory::format_tag::nCdhw8c,
264 memory::format_tag::aBcde8b, memory::format_tag::nc,
265 EXPAND_SIZES_3D(2, 9, 4, 2, 2, 2)},
266 inprod_test_params_float {memory::format_tag::nCdhw8c,
267 memory::format_tag::aBcde8b, memory::format_tag::nc,
268 EXPAND_SIZES_3D(2, 17, 16, 2, 2, 2)},
269 inprod_test_params_float {memory::format_tag::nCdhw8c,
270 memory::format_tag::aBcde8b, memory::format_tag::nc,
271 EXPAND_SIZES_3D(2, 29, 7, 2, 2, 2)}));
272
273INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData_nCdhw16c,
274 inner_product_test_float,
275 ::testing::Values(
276 inprod_test_params_float {memory::format_tag::nCdhw16c,
277 memory::format_tag::aBcde16b, memory::format_tag::nc,
278 EXPAND_SIZES_3D(2, 9, 4, 2, 2, 2)},
279 inprod_test_params_float {memory::format_tag::nCdhw16c,
280 memory::format_tag::aBcde16b, memory::format_tag::nc,
281 EXPAND_SIZES_3D(2, 17, 16, 2, 2, 2)},
282 inprod_test_params_float {memory::format_tag::nCdhw16c,
283 memory::format_tag::aBcde16b, memory::format_tag::nc,
284 EXPAND_SIZES_3D(2, 29, 7, 2, 2, 2)}));
285
286INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData_padded,
287 inner_product_test_float,
288 ::testing::Values(
289 inprod_test_params_float {memory::format_tag::nChw16c,
290 memory::format_tag::aBcd16b, memory::format_tag::nc,
291 EXPAND_SIZES_2D(2, 9, 4, 2, 2)},
292 inprod_test_params_float {memory::format_tag::nChw16c,
293 memory::format_tag::aBcd16b, memory::format_tag::nc,
294 EXPAND_SIZES_2D(2, 17, 16, 2, 2)},
295 inprod_test_params_float {memory::format_tag::nChw16c,
296 memory::format_tag::aBcd16b, memory::format_tag::nc,
297 EXPAND_SIZES_2D(2, 29, 7, 2, 2)},
298 inprod_test_params_float {memory::format_tag::nChw8c,
299 memory::format_tag::aBcd8b, memory::format_tag::nc,
300 EXPAND_SIZES_2D(2, 5, 4, 2, 2)},
301 inprod_test_params_float {memory::format_tag::nChw8c,
302 memory::format_tag::aBcd8b, memory::format_tag::nc,
303 EXPAND_SIZES_2D(2, 14, 16, 2, 2)},
304 inprod_test_params_float {memory::format_tag::nChw8c,
305 memory::format_tag::aBcd8b, memory::format_tag::nc,
306 EXPAND_SIZES_2D(2, 33, 7, 2, 2)}));
307
308INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData, inner_product_test_float,
309 ::testing::Values(
310 inprod_test_params_float {memory::format_tag::any,
311 memory::format_tag::any, memory::format_tag::any,
312 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
313 inprod_test_params_float {memory::format_tag::any,
314 memory::format_tag::any, memory::format_tag::any,
315 EXPAND_SIZES_2D(2, 1024, 48, 2, 2)},
316 inprod_test_params_float {memory::format_tag::nwc,
317 memory::format_tag::oiw, memory::format_tag::nc,
318 EXPAND_SIZES_1D(2, 32, 48, 6)},
319 inprod_test_params_float {memory::format_tag::nwc,
320 memory::format_tag::wio, memory::format_tag::nc,
321 EXPAND_SIZES_1D(2, 32, 48, 5)},
322 inprod_test_params_float {memory::format_tag::nwc,
323 memory::format_tag::owi, memory::format_tag::nc,
324 EXPAND_SIZES_1D(2, 32, 48, 5)},
325 inprod_test_params_float {memory::format_tag::ncw,
326 memory::format_tag::oiw, memory::format_tag::nc,
327 EXPAND_SIZES_1D(2, 32, 48, 5)},
328 inprod_test_params_float {memory::format_tag::ncw,
329 memory::format_tag::wio, memory::format_tag::nc,
330 EXPAND_SIZES_1D(2, 32, 48, 5)},
331 inprod_test_params_float {memory::format_tag::ncw,
332 memory::format_tag::owi, memory::format_tag::nc,
333 EXPAND_SIZES_1D(2, 32, 48, 5)},
334 inprod_test_params_float {memory::format_tag::nhwc,
335 memory::format_tag::hwio, memory::format_tag::nc,
336 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
337 inprod_test_params_float {memory::format_tag::nhwc,
338 memory::format_tag::oihw, memory::format_tag::nc,
339 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
340 inprod_test_params_float {memory::format_tag::nhwc,
341 memory::format_tag::iohw, memory::format_tag::nc,
342 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
343 inprod_test_params_float {memory::format_tag::nchw,
344 memory::format_tag::oihw, memory::format_tag::nc,
345 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
346 inprod_test_params_float {memory::format_tag::nchw,
347 memory::format_tag::hwio, memory::format_tag::nc,
348 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
349 inprod_test_params_float {memory::format_tag::nchw,
350 memory::format_tag::ohwi, memory::format_tag::nc,
351 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
352 inprod_test_params_float {memory::format_tag::nChw8c,
353 memory::format_tag::aBcd8b, memory::format_tag::nc,
354 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
355 inprod_test_params_float {memory::format_tag::any,
356 memory::format_tag::aBcd8b, memory::format_tag::nc,
357 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
358 inprod_test_params_float {memory::format_tag::nChw8c,
359 memory::format_tag::any, memory::format_tag::nc,
360 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
361 inprod_test_params_float {memory::format_tag::nChw8c,
362 memory::format_tag::aBcd8b, memory::format_tag::nc,
363 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
364 inprod_test_params_float {memory::format_tag::nChw16c,
365 memory::format_tag::aBcd16b, memory::format_tag::nc,
366 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
367 inprod_test_params_float {memory::format_tag::nc,
368 memory::format_tag::oi, memory::format_tag::nc,
369 EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
370 inprod_test_params_float {memory::format_tag::nc,
371 memory::format_tag::oi, memory::format_tag::nc,
372 EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
373 inprod_test_params_float {memory::format_tag::nc,
374 memory::format_tag::io, memory::format_tag::nc,
375 EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
376
377INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardData3D,
378 inner_product_test_float,
379 ::testing::Values(
380 inprod_test_params_float {memory::format_tag::any,
381 memory::format_tag::any, memory::format_tag::any,
382 EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
383 inprod_test_params_float {memory::format_tag::any,
384 memory::format_tag::any, memory::format_tag::any,
385 EXPAND_SIZES_3D(2, 1024, 48, 2, 2, 2)},
386 inprod_test_params_float {memory::format_tag::ncdhw,
387 memory::format_tag::oidhw, memory::format_tag::nc,
388 EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
389 inprod_test_params_float {memory::format_tag::ncdhw,
390 memory::format_tag::odhwi, memory::format_tag::nc,
391 EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
392 inprod_test_params_float {memory::format_tag::ncdhw,
393 memory::format_tag::dhwio, memory::format_tag::nc,
394 EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
395 inprod_test_params_float {memory::format_tag::ndhwc,
396 memory::format_tag::oidhw, memory::format_tag::nc,
397 EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
398 inprod_test_params_float {memory::format_tag::ndhwc,
399 memory::format_tag::odhwi, memory::format_tag::nc,
400 EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
401 inprod_test_params_float {memory::format_tag::ndhwc,
402 memory::format_tag::dhwio, memory::format_tag::nc,
403 EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)}));
404
405} // namespace dnnl
406