1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <memory>
18
19#include "dnnl_test_common.hpp"
20#include "gtest/gtest.h"
21
22#include "oneapi/dnnl/dnnl.hpp"
23
24namespace dnnl {
25
26struct test_resampling_desc_t {
27 memory::dim mb, c;
28 memory::dim id, ih, iw;
29 memory::dim od, oh, ow;
30 float fd, fh, fw;
31};
32
33struct resampling_test_params_t {
34 prop_kind aprop_kind;
35 algorithm aalgorithm;
36 memory::format_tag src_format;
37 int ndims;
38 test_resampling_desc_t test_pd;
39 bool expect_to_fail;
40 dnnl_status_t expected_status;
41};
42
43float linear_map(memory::dim y, memory::dim y_max, memory::dim x_max) {
44 const float s = (y + 0.5f) * x_max / y_max;
45 return s - 0.5f;
46}
47memory::dim left_edge(memory::dim y, memory::dim y_max, memory::dim x_max) {
48 return std::max((int64_t)floor(linear_map(y, y_max, x_max)), (int64_t)0);
49}
50memory::dim right_edge(memory::dim y, memory::dim y_max, memory::dim x_max) {
51 return std::min((int64_t)ceil(linear_map(y, y_max, x_max)), x_max - 1);
52}
53memory::dim nearest_edge(memory::dim y, memory::dim y_max, memory::dim x_max) {
54 return std::round(linear_map(y, y_max, x_max));
55}
56float linear_weight(memory::dim y, memory::dim y_max, memory::dim x_max) {
57 return fabs(linear_map(y, y_max, x_max) - left_edge(y, y_max, x_max));
58}
59
60template <typename data_t>
61void compute_ref_resampling_fwd(const resampling_test_params_t &p,
62 const memory &src_m, const memory &dst_m) {
63 auto src_data = map_memory<data_t>(src_m);
64 auto dst_data = map_memory<data_t>(dst_m);
65
66 const memory::desc src_d = src_m.get_desc();
67 const memory::desc dst_d = dst_m.get_desc();
68
69 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get());
70 const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get());
71
72 auto pd = p.test_pd;
73 auto padded_c = src_mdw.padded_dims()[1];
74
75 auto src = [&](memory::dim n, memory::dim c, memory::dim d, memory::dim h,
76 memory::dim w) {
77 memory::dim idx = n * padded_c * pd.id * pd.ih * pd.iw
78 + c * pd.id * pd.ih * pd.iw + d * pd.ih * pd.iw + h * pd.iw + w;
79 return src_data[src_mdw.off_l(idx, true)];
80 };
81
82 dnnl::impl::parallel_nd(pd.mb, pd.c, [&](memory::dim n, memory::dim c) {
83 for_(memory::dim od = 0; od < pd.od; od++)
84 for_(memory::dim oh = 0; oh < pd.oh; oh++)
85 for (memory::dim ow = 0; ow < pd.ow; ow++) {
86 memory::dim oidx = n * padded_c * pd.od * pd.oh * pd.ow
87 + c * pd.od * pd.oh * pd.ow + od * pd.oh * pd.ow
88 + oh * pd.ow + ow;
89
90 if (p.aalgorithm == algorithm::resampling_nearest) {
91 memory::dim id = nearest_edge(od, pd.od, pd.id),
92 ih = nearest_edge(oh, pd.oh, pd.ih),
93 iw = nearest_edge(ow, pd.ow, pd.iw);
94 memory::dim iidx = n * padded_c * pd.id * pd.ih * pd.iw
95 + c * pd.id * pd.ih * pd.iw + id * pd.ih * pd.iw
96 + ih * pd.iw + iw;
97 dst_data[dst_mdw.off_l(oidx, true)]
98 = src_data[src_mdw.off_l(iidx, true)];
99 } else if (p.aalgorithm == algorithm::resampling_linear) {
100 memory::dim id_left = left_edge(od, pd.od, pd.id),
101 id_right = right_edge(od, pd.od, pd.id),
102 ih_left = left_edge(oh, pd.oh, pd.ih),
103 ih_right = right_edge(oh, pd.oh, pd.ih),
104 iw_left = left_edge(ow, pd.ow, pd.iw),
105 iw_right = right_edge(ow, pd.ow, pd.iw);
106 float w_d = linear_weight(od, pd.od, pd.id),
107 w_h = linear_weight(oh, pd.oh, pd.ih),
108 w_w = linear_weight(ow, pd.ow, pd.iw);
109 float c00 = src(n, c, id_left, ih_left, iw_left) * (1 - w_d)
110 + src(n, c, id_right, ih_left, iw_left) * w_d;
111 float c01 = src(n, c, id_left, ih_left, iw_right) * (1 - w_d)
112 + src(n, c, id_right, ih_left, iw_right) * w_d;
113 float c10 = src(n, c, id_left, ih_right, iw_left) * (1 - w_d)
114 + src(n, c, id_right, ih_right, iw_left) * w_d;
115 float c11 = src(n, c, id_left, ih_right, iw_right) * (1 - w_d)
116 + src(n, c, id_right, ih_right, iw_right) * w_d;
117 float c0 = c00 * (1 - w_h) + c10 * w_h;
118 float c1 = c01 * (1 - w_h) + c11 * w_h;
119 dst_data[dst_mdw.off_l(oidx, true)] = c0 * (1 - w_w) + c1 * w_w;
120 }
121 }
122 });
123}
124
125template <typename data_t>
126void compute_ref_resampling_bwd(const resampling_test_params_t &p,
127 const memory &diff_dst_m, const memory &diff_src_m) {
128 auto diff_src_data = map_memory<data_t>(diff_src_m);
129 auto diff_dst_data = map_memory<data_t>(diff_dst_m);
130
131 const memory::desc diff_src_d = diff_src_m.get_desc();
132 const memory::desc diff_dst_d = diff_dst_m.get_desc();
133
134 const dnnl::impl::memory_desc_wrapper diff_src_mdw(diff_src_d.get());
135 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get());
136
137 auto pd = p.test_pd;
138 auto padded_c = diff_src_mdw.padded_dims()[1];
139
140 auto off = [&](memory::dim n, memory::dim c, memory::dim d, memory::dim h,
141 memory::dim w) {
142 return diff_src_mdw.off_l(n * padded_c * pd.id * pd.ih * pd.iw
143 + c * pd.id * pd.ih * pd.iw + d * pd.ih * pd.iw
144 + h * pd.iw + w,
145 true);
146 };
147 dnnl::impl::parallel_nd(pd.mb, pd.c, [&](memory::dim n, memory::dim c) {
148 for_(memory::dim id = 0; id < pd.id; id++)
149 for_(memory::dim ih = 0; ih < pd.ih; ih++)
150 for (memory::dim iw = 0; iw < pd.iw; iw++) {
151 memory::dim iidx = n * padded_c * pd.id * pd.ih * pd.iw
152 + c * pd.id * pd.ih * pd.iw + id * pd.ih * pd.iw
153 + ih * pd.iw + iw;
154
155 diff_src_data[diff_src_mdw.off_l(iidx, true)] = 0.f;
156 }
157 for_(memory::dim od = 0; od < pd.od; od++)
158 for_(memory::dim oh = 0; oh < pd.oh; oh++)
159 for (memory::dim ow = 0; ow < pd.ow; ow++) {
160 memory::dim oidx = n * padded_c * pd.od * pd.oh * pd.ow
161 + c * pd.od * pd.oh * pd.ow + od * pd.oh * pd.ow
162 + oh * pd.ow + ow;
163
164 if (p.aalgorithm == algorithm::resampling_nearest) {
165 memory::dim id = nearest_edge(od, pd.od, pd.id),
166 ih = nearest_edge(oh, pd.oh, pd.ih),
167 iw = nearest_edge(ow, pd.ow, pd.iw);
168 memory::dim iidx = n * padded_c * pd.id * pd.ih * pd.iw
169 + c * pd.id * pd.ih * pd.iw + id * pd.ih * pd.iw
170 + ih * pd.iw + iw;
171 diff_src_data[diff_src_mdw.off_l(iidx, true)]
172 += diff_dst_data[diff_dst_mdw.off_l(oidx, true)];
173 } else if (p.aalgorithm == algorithm::resampling_linear) {
174 memory::dim id_left = left_edge(od, pd.od, pd.id),
175 id_right = right_edge(od, pd.od, pd.id),
176 ih_left = left_edge(oh, pd.oh, pd.ih),
177 ih_right = right_edge(oh, pd.oh, pd.ih),
178 iw_left = left_edge(ow, pd.ow, pd.iw),
179 iw_right = right_edge(ow, pd.ow, pd.iw);
180 float w_d = linear_weight(od, pd.od, pd.id),
181 w_h = linear_weight(oh, pd.oh, pd.ih),
182 w_w = linear_weight(ow, pd.ow, pd.iw);
183 float dd = diff_dst_data[diff_dst_mdw.off_l(oidx, true)];
184
185 diff_src_data[off(n, c, id_left, ih_left, iw_left)]
186 += (1 - w_d) * (1 - w_h) * (1 - w_w) * dd;
187 diff_src_data[off(n, c, id_right, ih_left, iw_left)]
188 += w_d * (1 - w_h) * (1 - w_w) * dd;
189 diff_src_data[off(n, c, id_left, ih_right, iw_left)]
190 += (1 - w_d) * w_h * (1 - w_w) * dd;
191 diff_src_data[off(n, c, id_left, ih_left, iw_right)]
192 += (1 - w_d) * (1 - w_h) * w_w * dd;
193 diff_src_data[off(n, c, id_right, ih_right, iw_left)]
194 += w_d * w_h * (1 - w_w) * dd;
195 diff_src_data[off(n, c, id_left, ih_right, iw_right)]
196 += (1 - w_d) * w_h * w_w * dd;
197 diff_src_data[off(n, c, id_right, ih_left, iw_right)]
198 += w_d * (1 - w_h) * w_w * dd;
199 diff_src_data[off(n, c, id_right, ih_right, iw_right)]
200 += w_d * w_h * w_w * dd;
201 }
202 }
203 });
204}
205
206template <typename data_t>
207class resampling_test_t
208 : public ::testing::TestWithParam<resampling_test_params_t> {
209private:
210 std::shared_ptr<test_memory> src, dst, diff_src, diff_dst;
211 std::shared_ptr<memory::desc> src_desc, dst_desc;
212 std::vector<float> factors;
213 std::vector<float> expected_factors;
214 resampling_forward::primitive_desc resampling_pd;
215
216 resampling_test_params_t p;
217 engine eng;
218 stream strm;
219
220protected:
221 bool cuda_supported_format_tag(memory::format_tag tag) {
222 return impl::utils::one_of(
223 tag, dnnl_abc, dnnl_abcd, dnnl_acb, dnnl_acdb);
224 }
225 void SetUp() override {
226 p = ::testing::TestWithParam<decltype(p)>::GetParam();
227 SKIP_IF_CUDA(p.aalgorithm == algorithm::resampling_nearest,
228 "nearet algorithm is not supported for cudnn backend");
229 SKIP_IF_CUDA(p.ndims == 5,
230 "cudnn resampling backend does not support 5d tensor");
231 SKIP_IF_CUDA(!cuda_supported_format_tag(p.src_format),
232 "Unsupported format tag");
233
234 catch_expected_failures(
235 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
236 }
237
238 void Test() {
239 p = ::testing::TestWithParam<decltype(p)>::GetParam();
240
241 eng = get_test_engine();
242 strm = make_stream(eng);
243
244 test_resampling_desc_t pd = p.test_pd;
245
246 memory::dims src_dims = {pd.mb, pd.c}, dst_dims = {pd.mb, pd.c};
247 // When `out_of_memory` testing is enabled, factors is expanded each
248 // time `Test()` is executed for any test. Clear the vector to avoid
249 // having its size > DNNL_MAX_NDIMS.
250 factors.clear();
251 if (p.ndims == 5) {
252 factors.push_back(pd.fd);
253 src_dims.push_back(pd.id);
254 dst_dims.push_back(pd.od);
255 }
256 if (p.ndims >= 4) {
257 factors.push_back(pd.fh);
258 src_dims.push_back(pd.ih);
259 dst_dims.push_back(pd.oh);
260 }
261 if (p.ndims >= 3) {
262 factors.push_back(pd.fw);
263 src_dims.push_back(pd.iw);
264 dst_dims.push_back(pd.ow);
265 }
266
267 memory::data_type data_type = data_traits<data_t>::data_type;
268 src_desc = std::make_shared<memory::desc>(
269 src_dims, data_type, p.src_format);
270 dst_desc = std::make_shared<memory::desc>(
271 dst_dims, data_type, p.src_format);
272
273 for (int i = 0; i < src_desc->get_ndims() - 2; i++) {
274 expected_factors.push_back((double)dst_desc->get_dims()[2 + i]
275 / src_desc->get_dims()[2 + i]);
276 }
277
278 Forward();
279 Backward();
280 }
281
282 void Forward() {
283 resampling_pd = resampling_forward::primitive_desc(
284 eng, p.aprop_kind, p.aalgorithm, *src_desc, *dst_desc);
285 resampling_pd = resampling_forward::primitive_desc(
286 resampling_pd.get()); // test construction from a C pd
287
288 {
289 auto resampling_desc_no_dst
290 = resampling_forward::primitive_desc(eng, p.aprop_kind,
291 p.aalgorithm, factors, resampling_pd.src_desc());
292 auto resampling_pd_no_dst
293 = resampling_forward::primitive_desc(eng, p.aprop_kind,
294 p.aalgorithm, factors, resampling_pd.src_desc());
295 ASSERT_EQ(
296 resampling_pd.dst_desc(), resampling_pd_no_dst.dst_desc());
297 ASSERT_EQ(resampling_pd_no_dst.get_factors(), expected_factors);
298 }
299
300 ASSERT_EQ(resampling_pd.get_prop_kind(), p.aprop_kind);
301 ASSERT_EQ(resampling_pd.get_algorithm(), p.aalgorithm);
302 ASSERT_EQ(resampling_pd.get_factors(), expected_factors);
303
304 auto src = test::make_memory(resampling_pd.src_desc(), eng);
305 auto dst = test::make_memory(resampling_pd.dst_desc(), eng);
306 auto dst_ref = test::make_memory(resampling_pd.dst_desc(), eng);
307
308 fill_data<data_t>(src.get_desc().get_size() / sizeof(data_t), src);
309 check_zero_tail<data_t>(1, src);
310
311 EXPECT_ANY_THROW(resampling_forward(resampling_pd, {}));
312 resampling_forward(resampling_pd)
313 .execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
314 strm.wait();
315
316 compute_ref_resampling_fwd<data_t>(p, src, dst_ref);
317 check_zero_tail<data_t>(1, dst_ref);
318 compare_data<data_t>(dst_ref, dst);
319
320 check_zero_tail<data_t>(0, dst);
321 }
322
323 void Backward() {
324 auto resampling_bwd_pd = resampling_backward::primitive_desc(eng,
325 p.aalgorithm, factors, *src_desc, *dst_desc, resampling_pd);
326
327 auto diff_src
328 = test::make_memory(resampling_bwd_pd.diff_src_desc(), eng);
329 auto diff_dst
330 = test::make_memory(resampling_bwd_pd.diff_dst_desc(), eng);
331 auto diff_src_ref
332 = test::make_memory(resampling_bwd_pd.diff_src_desc(), eng);
333
334 ASSERT_EQ(resampling_bwd_pd.get_prop_kind(), prop_kind::backward_data);
335 ASSERT_EQ(resampling_bwd_pd.get_algorithm(), p.aalgorithm);
336 ASSERT_EQ(resampling_bwd_pd.get_factors(), expected_factors);
337
338 fill_data<data_t>(
339 diff_dst.get_desc().get_size() / sizeof(data_t), diff_dst);
340 check_zero_tail<data_t>(1, diff_dst);
341 check_zero_tail<data_t>(1, diff_src);
342
343 EXPECT_ANY_THROW(resampling_backward(resampling_bwd_pd, {}));
344 resampling_backward(resampling_bwd_pd)
345 .execute(strm,
346 {{DNNL_ARG_DIFF_SRC, diff_src},
347 {DNNL_ARG_DIFF_DST, diff_dst}});
348 strm.wait();
349
350 compute_ref_resampling_bwd<data_t>(p, diff_dst, diff_src_ref);
351 check_zero_tail<data_t>(1, diff_src_ref);
352 compare_data<data_t>(diff_src_ref, diff_src);
353 check_zero_tail<data_t>(0, diff_src);
354 }
355};
356
357using resampling_test_float = resampling_test_t<float>;
358
359#define EXPAND_SIZES_3D(...) \
360 5, { __VA_ARGS__ }
361#define EXPAND_SIZES_2D(mb, c, ih, iw, oh, ow, fh, fw) \
362 4, { mb, c, 1, ih, iw, 1, oh, ow, 1.f, fh, fw }
363#define EXPAND_SIZES_1D(mb, c, iw, ow, fw) \
364 3, { mb, c, 1, 1, iw, 1, 1, ow, 1.f, 1.f, fw }
365
366TEST_P(resampling_test_float, TestsResampleF32) {}
367
368INSTANTIATE_TEST_SUITE_P(TestResampleEF, resampling_test_float,
369 ::testing::Values(resampling_test_params_t {prop_kind::forward,
370 algorithm::resampling_linear, memory::format_tag::any,
371 EXPAND_SIZES_1D(1, 1, 5, 10, 2.f), true,
372 dnnl_invalid_arguments}));
373
374INSTANTIATE_TEST_SUITE_P(TestResampleForwardPlainLinear, resampling_test_float,
375 ::testing::Values(
376 resampling_test_params_t {prop_kind::forward,
377 algorithm::resampling_linear, memory::format_tag::ncw,
378 EXPAND_SIZES_1D(1, 1, 5, 10, 2.f)},
379 resampling_test_params_t {prop_kind::forward,
380 algorithm::resampling_linear, memory::format_tag::ncw,
381 EXPAND_SIZES_1D(1, 1, 525, 5, 0.01f)},
382 resampling_test_params_t {prop_kind::forward,
383 algorithm::resampling_linear, memory::format_tag::ncw,
384 EXPAND_SIZES_1D(13, 10, 7, 13, 1.99f)},
385 resampling_test_params_t {prop_kind::forward,
386 algorithm::resampling_linear, memory::format_tag::ncw,
387 EXPAND_SIZES_1D(10, 16, 7, 13, 1.9f)},
388 resampling_test_params_t {prop_kind::forward,
389 algorithm::resampling_linear, memory::format_tag::nchw,
390 EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
391 resampling_test_params_t {prop_kind::forward,
392 algorithm::resampling_linear, memory::format_tag::nhwc,
393 EXPAND_SIZES_2D(2, 14, 5, 5, 2, 3, 0.5f, 0.6f)},
394 resampling_test_params_t {prop_kind::forward,
395 algorithm::resampling_linear, memory::format_tag::ndhwc,
396 EXPAND_SIZES_3D(
397 1, 16, 5, 10, 1, 10, 5, 1, 2.f, 0.5f, 1.f)}));
398
399GPU_INSTANTIATE_TEST_SUITE_P(TestResamplePlainLinear, resampling_test_float,
400 ::testing::Values(
401 resampling_test_params_t {prop_kind::forward,
402 algorithm::resampling_linear, memory::format_tag::ncw,
403 EXPAND_SIZES_1D(1, 1, 5, 10, 2.f)},
404 resampling_test_params_t {prop_kind::forward,
405 algorithm::resampling_linear, memory::format_tag::ncw,
406 EXPAND_SIZES_1D(1, 1, 525, 5, 0.01f)},
407 resampling_test_params_t {prop_kind::forward,
408 algorithm::resampling_linear, memory::format_tag::ncw,
409 EXPAND_SIZES_1D(13, 10, 7, 13, 1.99f)},
410 resampling_test_params_t {prop_kind::forward,
411 algorithm::resampling_linear, memory::format_tag::ncw,
412 EXPAND_SIZES_1D(10, 16, 7, 13, 1.9f)},
413 resampling_test_params_t {prop_kind::forward,
414 algorithm::resampling_linear, memory::format_tag::nchw,
415 EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
416 resampling_test_params_t {prop_kind::forward,
417 algorithm::resampling_linear, memory::format_tag::nhwc,
418 EXPAND_SIZES_2D(2, 14, 5, 5, 2, 3, 0.5f, 0.6f)}));
419INSTANTIATE_TEST_SUITE_P(TestResampleForwardBlockedLinear,
420 resampling_test_float,
421 ::testing::Values(
422 resampling_test_params_t {prop_kind::forward,
423 algorithm::resampling_linear,
424 memory::format_tag::nChw8c,
425 EXPAND_SIZES_2D(32, 16, 14, 6, 28, 3, 2, 0.5f)},
426 resampling_test_params_t {prop_kind::forward,
427 algorithm::resampling_linear,
428 memory::format_tag::nChw16c,
429 EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
430 resampling_test_params_t {prop_kind::forward,
431 algorithm::resampling_linear, memory::format_tag::ncdhw,
432 EXPAND_SIZES_3D(
433 1, 1, 5, 10, 15, 10, 5, 7, 2.f, 0.5f, 0.5f)}));
434
435INSTANTIATE_TEST_SUITE_P(TestResampleForwardPlainNN, resampling_test_float,
436 ::testing::Values(
437 resampling_test_params_t {prop_kind::forward,
438 algorithm::resampling_nearest, memory::format_tag::ncw,
439 EXPAND_SIZES_1D(10, 16, 5, 10, 2.f)},
440 resampling_test_params_t {prop_kind::forward,
441 algorithm::resampling_nearest, memory::format_tag::ncw,
442 EXPAND_SIZES_1D(13, 10, 7, 13, 1.99f)},
443 resampling_test_params_t {prop_kind::forward,
444 algorithm::resampling_nearest, memory::format_tag::ncw,
445 EXPAND_SIZES_1D(10, 16, 7, 13, 1.9f)},
446 resampling_test_params_t {prop_kind::forward,
447 algorithm::resampling_nearest, memory::format_tag::nchw,
448 EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
449 resampling_test_params_t {prop_kind::forward,
450 algorithm::resampling_nearest, memory::format_tag::nhwc,
451 EXPAND_SIZES_2D(64, 32, 5, 5, 2, 3, 0.5f, 0.6f)},
452 resampling_test_params_t {prop_kind::forward,
453 algorithm::resampling_nearest,
454 memory::format_tag::ndhwc,
455 EXPAND_SIZES_3D(
456 5, 5, 5, 10, 15, 10, 5, 7, 2.f, 0.5f, 0.5f)}));
457
458INSTANTIATE_TEST_SUITE_P(TestResampleForwardBlockedNN, resampling_test_float,
459 ::testing::Values(
460 resampling_test_params_t {prop_kind::forward,
461 algorithm::resampling_nearest,
462 memory::format_tag::nChw8c,
463 EXPAND_SIZES_2D(32, 16, 14, 6, 28, 3, 2, 0.5f)},
464 resampling_test_params_t {prop_kind::forward,
465 algorithm::resampling_nearest,
466 memory::format_tag::nChw16c,
467 EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
468 resampling_test_params_t {prop_kind::forward,
469 algorithm::resampling_nearest,
470 memory::format_tag::nCdhw16c,
471 EXPAND_SIZES_3D(
472 5, 5, 5, 10, 15, 10, 5, 7, 2.f, 0.5f, 0.5f)}));
473} // namespace dnnl
474