1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"
23 | |
24 | namespace dnnl { |
25 | |
26 | struct test_resampling_desc_t { |
27 | memory::dim mb, c; |
28 | memory::dim id, ih, iw; |
29 | memory::dim od, oh, ow; |
30 | float fd, fh, fw; |
31 | }; |
32 | |
// Full parameter set of one resampling test instance.
struct resampling_test_params_t {
    // Propagation kind used to create the forward primitive descriptor.
    prop_kind aprop_kind;
    // Resampling algorithm; the reference implementations handle
    // resampling_nearest and resampling_linear.
    algorithm aalgorithm;
    // Memory format used for both the src and the dst descriptors.
    memory::format_tag src_format;
    // Tensor rank (3, 4 or 5); selects which spatial sizes/factors are used.
    int ndims;
    // Sizes and factors of the case.
    test_resampling_desc_t test_pd;
    // Passed, with expected_status, to catch_expected_failures().
    bool expect_to_fail;
    // Status expected when expect_to_fail is set.
    dnnl_status_t expected_status;
};
42 | |
43 | float linear_map(memory::dim y, memory::dim y_max, memory::dim x_max) { |
44 | const float s = (y + 0.5f) * x_max / y_max; |
45 | return s - 0.5f; |
46 | } |
47 | memory::dim left_edge(memory::dim y, memory::dim y_max, memory::dim x_max) { |
48 | return std::max((int64_t)floor(linear_map(y, y_max, x_max)), (int64_t)0); |
49 | } |
50 | memory::dim right_edge(memory::dim y, memory::dim y_max, memory::dim x_max) { |
51 | return std::min((int64_t)ceil(linear_map(y, y_max, x_max)), x_max - 1); |
52 | } |
53 | memory::dim nearest_edge(memory::dim y, memory::dim y_max, memory::dim x_max) { |
54 | return std::round(linear_map(y, y_max, x_max)); |
55 | } |
56 | float linear_weight(memory::dim y, memory::dim y_max, memory::dim x_max) { |
57 | return fabs(linear_map(y, y_max, x_max) - left_edge(y, y_max, x_max)); |
58 | } |
59 | |
60 | template <typename data_t> |
61 | void compute_ref_resampling_fwd(const resampling_test_params_t &p, |
62 | const memory &src_m, const memory &dst_m) { |
63 | auto src_data = map_memory<data_t>(src_m); |
64 | auto dst_data = map_memory<data_t>(dst_m); |
65 | |
66 | const memory::desc src_d = src_m.get_desc(); |
67 | const memory::desc dst_d = dst_m.get_desc(); |
68 | |
69 | const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get()); |
70 | const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get()); |
71 | |
72 | auto pd = p.test_pd; |
73 | auto padded_c = src_mdw.padded_dims()[1]; |
74 | |
75 | auto src = [&](memory::dim n, memory::dim c, memory::dim d, memory::dim h, |
76 | memory::dim w) { |
77 | memory::dim idx = n * padded_c * pd.id * pd.ih * pd.iw |
78 | + c * pd.id * pd.ih * pd.iw + d * pd.ih * pd.iw + h * pd.iw + w; |
79 | return src_data[src_mdw.off_l(idx, true)]; |
80 | }; |
81 | |
82 | dnnl::impl::parallel_nd(pd.mb, pd.c, [&](memory::dim n, memory::dim c) { |
83 | for_(memory::dim od = 0; od < pd.od; od++) |
84 | for_(memory::dim oh = 0; oh < pd.oh; oh++) |
85 | for (memory::dim ow = 0; ow < pd.ow; ow++) { |
86 | memory::dim oidx = n * padded_c * pd.od * pd.oh * pd.ow |
87 | + c * pd.od * pd.oh * pd.ow + od * pd.oh * pd.ow |
88 | + oh * pd.ow + ow; |
89 | |
90 | if (p.aalgorithm == algorithm::resampling_nearest) { |
91 | memory::dim id = nearest_edge(od, pd.od, pd.id), |
92 | ih = nearest_edge(oh, pd.oh, pd.ih), |
93 | iw = nearest_edge(ow, pd.ow, pd.iw); |
94 | memory::dim iidx = n * padded_c * pd.id * pd.ih * pd.iw |
95 | + c * pd.id * pd.ih * pd.iw + id * pd.ih * pd.iw |
96 | + ih * pd.iw + iw; |
97 | dst_data[dst_mdw.off_l(oidx, true)] |
98 | = src_data[src_mdw.off_l(iidx, true)]; |
99 | } else if (p.aalgorithm == algorithm::resampling_linear) { |
100 | memory::dim id_left = left_edge(od, pd.od, pd.id), |
101 | id_right = right_edge(od, pd.od, pd.id), |
102 | ih_left = left_edge(oh, pd.oh, pd.ih), |
103 | ih_right = right_edge(oh, pd.oh, pd.ih), |
104 | iw_left = left_edge(ow, pd.ow, pd.iw), |
105 | iw_right = right_edge(ow, pd.ow, pd.iw); |
106 | float w_d = linear_weight(od, pd.od, pd.id), |
107 | w_h = linear_weight(oh, pd.oh, pd.ih), |
108 | w_w = linear_weight(ow, pd.ow, pd.iw); |
109 | float c00 = src(n, c, id_left, ih_left, iw_left) * (1 - w_d) |
110 | + src(n, c, id_right, ih_left, iw_left) * w_d; |
111 | float c01 = src(n, c, id_left, ih_left, iw_right) * (1 - w_d) |
112 | + src(n, c, id_right, ih_left, iw_right) * w_d; |
113 | float c10 = src(n, c, id_left, ih_right, iw_left) * (1 - w_d) |
114 | + src(n, c, id_right, ih_right, iw_left) * w_d; |
115 | float c11 = src(n, c, id_left, ih_right, iw_right) * (1 - w_d) |
116 | + src(n, c, id_right, ih_right, iw_right) * w_d; |
117 | float c0 = c00 * (1 - w_h) + c10 * w_h; |
118 | float c1 = c01 * (1 - w_h) + c11 * w_h; |
119 | dst_data[dst_mdw.off_l(oidx, true)] = c0 * (1 - w_w) + c1 * w_w; |
120 | } |
121 | } |
122 | }); |
123 | } |
124 | |
125 | template <typename data_t> |
126 | void compute_ref_resampling_bwd(const resampling_test_params_t &p, |
127 | const memory &diff_dst_m, const memory &diff_src_m) { |
128 | auto diff_src_data = map_memory<data_t>(diff_src_m); |
129 | auto diff_dst_data = map_memory<data_t>(diff_dst_m); |
130 | |
131 | const memory::desc diff_src_d = diff_src_m.get_desc(); |
132 | const memory::desc diff_dst_d = diff_dst_m.get_desc(); |
133 | |
134 | const dnnl::impl::memory_desc_wrapper diff_src_mdw(diff_src_d.get()); |
135 | const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get()); |
136 | |
137 | auto pd = p.test_pd; |
138 | auto padded_c = diff_src_mdw.padded_dims()[1]; |
139 | |
140 | auto off = [&](memory::dim n, memory::dim c, memory::dim d, memory::dim h, |
141 | memory::dim w) { |
142 | return diff_src_mdw.off_l(n * padded_c * pd.id * pd.ih * pd.iw |
143 | + c * pd.id * pd.ih * pd.iw + d * pd.ih * pd.iw |
144 | + h * pd.iw + w, |
145 | true); |
146 | }; |
147 | dnnl::impl::parallel_nd(pd.mb, pd.c, [&](memory::dim n, memory::dim c) { |
148 | for_(memory::dim id = 0; id < pd.id; id++) |
149 | for_(memory::dim ih = 0; ih < pd.ih; ih++) |
150 | for (memory::dim iw = 0; iw < pd.iw; iw++) { |
151 | memory::dim iidx = n * padded_c * pd.id * pd.ih * pd.iw |
152 | + c * pd.id * pd.ih * pd.iw + id * pd.ih * pd.iw |
153 | + ih * pd.iw + iw; |
154 | |
155 | diff_src_data[diff_src_mdw.off_l(iidx, true)] = 0.f; |
156 | } |
157 | for_(memory::dim od = 0; od < pd.od; od++) |
158 | for_(memory::dim oh = 0; oh < pd.oh; oh++) |
159 | for (memory::dim ow = 0; ow < pd.ow; ow++) { |
160 | memory::dim oidx = n * padded_c * pd.od * pd.oh * pd.ow |
161 | + c * pd.od * pd.oh * pd.ow + od * pd.oh * pd.ow |
162 | + oh * pd.ow + ow; |
163 | |
164 | if (p.aalgorithm == algorithm::resampling_nearest) { |
165 | memory::dim id = nearest_edge(od, pd.od, pd.id), |
166 | ih = nearest_edge(oh, pd.oh, pd.ih), |
167 | iw = nearest_edge(ow, pd.ow, pd.iw); |
168 | memory::dim iidx = n * padded_c * pd.id * pd.ih * pd.iw |
169 | + c * pd.id * pd.ih * pd.iw + id * pd.ih * pd.iw |
170 | + ih * pd.iw + iw; |
171 | diff_src_data[diff_src_mdw.off_l(iidx, true)] |
172 | += diff_dst_data[diff_dst_mdw.off_l(oidx, true)]; |
173 | } else if (p.aalgorithm == algorithm::resampling_linear) { |
174 | memory::dim id_left = left_edge(od, pd.od, pd.id), |
175 | id_right = right_edge(od, pd.od, pd.id), |
176 | ih_left = left_edge(oh, pd.oh, pd.ih), |
177 | ih_right = right_edge(oh, pd.oh, pd.ih), |
178 | iw_left = left_edge(ow, pd.ow, pd.iw), |
179 | iw_right = right_edge(ow, pd.ow, pd.iw); |
180 | float w_d = linear_weight(od, pd.od, pd.id), |
181 | w_h = linear_weight(oh, pd.oh, pd.ih), |
182 | w_w = linear_weight(ow, pd.ow, pd.iw); |
183 | float dd = diff_dst_data[diff_dst_mdw.off_l(oidx, true)]; |
184 | |
185 | diff_src_data[off(n, c, id_left, ih_left, iw_left)] |
186 | += (1 - w_d) * (1 - w_h) * (1 - w_w) * dd; |
187 | diff_src_data[off(n, c, id_right, ih_left, iw_left)] |
188 | += w_d * (1 - w_h) * (1 - w_w) * dd; |
189 | diff_src_data[off(n, c, id_left, ih_right, iw_left)] |
190 | += (1 - w_d) * w_h * (1 - w_w) * dd; |
191 | diff_src_data[off(n, c, id_left, ih_left, iw_right)] |
192 | += (1 - w_d) * (1 - w_h) * w_w * dd; |
193 | diff_src_data[off(n, c, id_right, ih_right, iw_left)] |
194 | += w_d * w_h * (1 - w_w) * dd; |
195 | diff_src_data[off(n, c, id_left, ih_right, iw_right)] |
196 | += (1 - w_d) * w_h * w_w * dd; |
197 | diff_src_data[off(n, c, id_right, ih_left, iw_right)] |
198 | += w_d * (1 - w_h) * w_w * dd; |
199 | diff_src_data[off(n, c, id_right, ih_right, iw_right)] |
200 | += w_d * w_h * w_w * dd; |
201 | } |
202 | } |
203 | }); |
204 | } |
205 | |
206 | template <typename data_t> |
207 | class resampling_test_t |
208 | : public ::testing::TestWithParam<resampling_test_params_t> { |
209 | private: |
210 | std::shared_ptr<test_memory> src, dst, diff_src, diff_dst; |
211 | std::shared_ptr<memory::desc> src_desc, dst_desc; |
212 | std::vector<float> factors; |
213 | std::vector<float> expected_factors; |
214 | resampling_forward::primitive_desc resampling_pd; |
215 | |
216 | resampling_test_params_t p; |
217 | engine eng; |
218 | stream strm; |
219 | |
220 | protected: |
221 | bool cuda_supported_format_tag(memory::format_tag tag) { |
222 | return impl::utils::one_of( |
223 | tag, dnnl_abc, dnnl_abcd, dnnl_acb, dnnl_acdb); |
224 | } |
225 | void SetUp() override { |
226 | p = ::testing::TestWithParam<decltype(p)>::GetParam(); |
227 | SKIP_IF_CUDA(p.aalgorithm == algorithm::resampling_nearest, |
228 | "nearet algorithm is not supported for cudnn backend" ); |
229 | SKIP_IF_CUDA(p.ndims == 5, |
230 | "cudnn resampling backend does not support 5d tensor" ); |
231 | SKIP_IF_CUDA(!cuda_supported_format_tag(p.src_format), |
232 | "Unsupported format tag" ); |
233 | |
234 | catch_expected_failures( |
235 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
236 | } |
237 | |
238 | void Test() { |
239 | p = ::testing::TestWithParam<decltype(p)>::GetParam(); |
240 | |
241 | eng = get_test_engine(); |
242 | strm = make_stream(eng); |
243 | |
244 | test_resampling_desc_t pd = p.test_pd; |
245 | |
246 | memory::dims src_dims = {pd.mb, pd.c}, dst_dims = {pd.mb, pd.c}; |
247 | // When `out_of_memory` testing is enabled, factors is expanded each |
248 | // time `Test()` is executed for any test. Clear the vector to avoid |
249 | // having its size > DNNL_MAX_NDIMS. |
250 | factors.clear(); |
251 | if (p.ndims == 5) { |
252 | factors.push_back(pd.fd); |
253 | src_dims.push_back(pd.id); |
254 | dst_dims.push_back(pd.od); |
255 | } |
256 | if (p.ndims >= 4) { |
257 | factors.push_back(pd.fh); |
258 | src_dims.push_back(pd.ih); |
259 | dst_dims.push_back(pd.oh); |
260 | } |
261 | if (p.ndims >= 3) { |
262 | factors.push_back(pd.fw); |
263 | src_dims.push_back(pd.iw); |
264 | dst_dims.push_back(pd.ow); |
265 | } |
266 | |
267 | memory::data_type data_type = data_traits<data_t>::data_type; |
268 | src_desc = std::make_shared<memory::desc>( |
269 | src_dims, data_type, p.src_format); |
270 | dst_desc = std::make_shared<memory::desc>( |
271 | dst_dims, data_type, p.src_format); |
272 | |
273 | for (int i = 0; i < src_desc->get_ndims() - 2; i++) { |
274 | expected_factors.push_back((double)dst_desc->get_dims()[2 + i] |
275 | / src_desc->get_dims()[2 + i]); |
276 | } |
277 | |
278 | Forward(); |
279 | Backward(); |
280 | } |
281 | |
282 | void Forward() { |
283 | resampling_pd = resampling_forward::primitive_desc( |
284 | eng, p.aprop_kind, p.aalgorithm, *src_desc, *dst_desc); |
285 | resampling_pd = resampling_forward::primitive_desc( |
286 | resampling_pd.get()); // test construction from a C pd |
287 | |
288 | { |
289 | auto resampling_desc_no_dst |
290 | = resampling_forward::primitive_desc(eng, p.aprop_kind, |
291 | p.aalgorithm, factors, resampling_pd.src_desc()); |
292 | auto resampling_pd_no_dst |
293 | = resampling_forward::primitive_desc(eng, p.aprop_kind, |
294 | p.aalgorithm, factors, resampling_pd.src_desc()); |
295 | ASSERT_EQ( |
296 | resampling_pd.dst_desc(), resampling_pd_no_dst.dst_desc()); |
297 | ASSERT_EQ(resampling_pd_no_dst.get_factors(), expected_factors); |
298 | } |
299 | |
300 | ASSERT_EQ(resampling_pd.get_prop_kind(), p.aprop_kind); |
301 | ASSERT_EQ(resampling_pd.get_algorithm(), p.aalgorithm); |
302 | ASSERT_EQ(resampling_pd.get_factors(), expected_factors); |
303 | |
304 | auto src = test::make_memory(resampling_pd.src_desc(), eng); |
305 | auto dst = test::make_memory(resampling_pd.dst_desc(), eng); |
306 | auto dst_ref = test::make_memory(resampling_pd.dst_desc(), eng); |
307 | |
308 | fill_data<data_t>(src.get_desc().get_size() / sizeof(data_t), src); |
309 | check_zero_tail<data_t>(1, src); |
310 | |
311 | EXPECT_ANY_THROW(resampling_forward(resampling_pd, {})); |
312 | resampling_forward(resampling_pd) |
313 | .execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); |
314 | strm.wait(); |
315 | |
316 | compute_ref_resampling_fwd<data_t>(p, src, dst_ref); |
317 | check_zero_tail<data_t>(1, dst_ref); |
318 | compare_data<data_t>(dst_ref, dst); |
319 | |
320 | check_zero_tail<data_t>(0, dst); |
321 | } |
322 | |
323 | void Backward() { |
324 | auto resampling_bwd_pd = resampling_backward::primitive_desc(eng, |
325 | p.aalgorithm, factors, *src_desc, *dst_desc, resampling_pd); |
326 | |
327 | auto diff_src |
328 | = test::make_memory(resampling_bwd_pd.diff_src_desc(), eng); |
329 | auto diff_dst |
330 | = test::make_memory(resampling_bwd_pd.diff_dst_desc(), eng); |
331 | auto diff_src_ref |
332 | = test::make_memory(resampling_bwd_pd.diff_src_desc(), eng); |
333 | |
334 | ASSERT_EQ(resampling_bwd_pd.get_prop_kind(), prop_kind::backward_data); |
335 | ASSERT_EQ(resampling_bwd_pd.get_algorithm(), p.aalgorithm); |
336 | ASSERT_EQ(resampling_bwd_pd.get_factors(), expected_factors); |
337 | |
338 | fill_data<data_t>( |
339 | diff_dst.get_desc().get_size() / sizeof(data_t), diff_dst); |
340 | check_zero_tail<data_t>(1, diff_dst); |
341 | check_zero_tail<data_t>(1, diff_src); |
342 | |
343 | EXPECT_ANY_THROW(resampling_backward(resampling_bwd_pd, {})); |
344 | resampling_backward(resampling_bwd_pd) |
345 | .execute(strm, |
346 | {{DNNL_ARG_DIFF_SRC, diff_src}, |
347 | {DNNL_ARG_DIFF_DST, diff_dst}}); |
348 | strm.wait(); |
349 | |
350 | compute_ref_resampling_bwd<data_t>(p, diff_dst, diff_src_ref); |
351 | check_zero_tail<data_t>(1, diff_src_ref); |
352 | compare_data<data_t>(diff_src_ref, diff_src); |
353 | check_zero_tail<data_t>(0, diff_src); |
354 | } |
355 | }; |
356 | |
// f32 instantiation of the resampling fixture; every test suite below uses it.
using resampling_test_float = resampling_test_t<float>;
358 | |
// Helpers producing the trailing `ndims, test_resampling_desc_t{...}` part of a
// resampling_test_params_t initializer. The full descriptor order is
//     mb, c, id, ih, iw, od, oh, ow, fd, fh, fw
// Lower-rank variants pad the missing spatial sizes with 1 and the missing
// factors with 1.f.
#define EXPAND_SIZES_ND_(ndims, ...) ndims, { __VA_ARGS__ }
#define EXPAND_SIZES_3D(...) EXPAND_SIZES_ND_(5, __VA_ARGS__)
#define EXPAND_SIZES_2D(mb, c, ih, iw, oh, ow, fh, fw) \
    EXPAND_SIZES_ND_(4, mb, c, 1, ih, iw, 1, oh, ow, 1.f, fh, fw)
#define EXPAND_SIZES_1D(mb, c, iw, ow, fw) \
    EXPAND_SIZES_ND_(3, mb, c, 1, 1, iw, 1, 1, ow, 1.f, 1.f, fw)
365 | |
// The body is intentionally empty: all checks run from the fixture's SetUp()
// (Test() -> Forward() -> Backward()).
TEST_P(resampling_test_float, TestsResampleF32) {}
367 | |
// Expected-failure case: src/dst use format_tag::any and creation is expected
// to fail with dnnl_invalid_arguments (presumably because a forward src
// layout cannot be left as `any` -- confirm against the library checks).
INSTANTIATE_TEST_SUITE_P(TestResampleEF, resampling_test_float,
        ::testing::Values(resampling_test_params_t {prop_kind::forward,
                algorithm::resampling_linear, memory::format_tag::any,
                EXPAND_SIZES_1D(1, 1, 5, 10, 2.f), true,
                dnnl_invalid_arguments}));
373 | |
// Linear resampling on plain layouts: 3D (ncw), 4D (nchw, nhwc) and
// 5D (ndhwc), covering both up- and down-sampling. prop_kind::forward only
// selects the forward pd's propagation kind; the fixture also runs the
// backward check for every case.
INSTANTIATE_TEST_SUITE_P(TestResampleForwardPlainLinear, resampling_test_float,
        ::testing::Values(
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(1, 1, 5, 10, 2.f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(1, 1, 525, 5, 0.01f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(13, 10, 7, 13, 1.99f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(10, 16, 7, 13, 1.9f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::nchw,
                        EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::nhwc,
                        EXPAND_SIZES_2D(2, 14, 5, 5, 2, 3, 0.5f, 0.6f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ndhwc,
                        EXPAND_SIZES_3D(
                                1, 16, 5, 10, 1, 10, 5, 1, 2.f, 0.5f, 1.f)}));
398 | |
// GPU-only suite (presumably GPU_INSTANTIATE_TEST_SUITE_P instantiates only for
// GPU engines -- confirm in dnnl_test_common.hpp). It holds the same plain
// layout linear cases as TestResampleForwardPlainLinear, minus the 5D one.
GPU_INSTANTIATE_TEST_SUITE_P(TestResamplePlainLinear, resampling_test_float,
        ::testing::Values(
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(1, 1, 5, 10, 2.f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(1, 1, 525, 5, 0.01f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(13, 10, 7, 13, 1.99f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(10, 16, 7, 13, 1.9f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::nchw,
                        EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::nhwc,
                        EXPAND_SIZES_2D(2, 14, 5, 5, 2, 3, 0.5f, 0.6f)}));
// Linear resampling on channel-blocked layouts. The nChw16c case uses c = 10,
// which is not a multiple of the block size, so the channel dim is padded.
// The last case uses the plain ncdhw layout despite the suite name.
INSTANTIATE_TEST_SUITE_P(TestResampleForwardBlockedLinear,
        resampling_test_float,
        ::testing::Values(
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear,
                        memory::format_tag::nChw8c,
                        EXPAND_SIZES_2D(32, 16, 14, 6, 28, 3, 2, 0.5f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear,
                        memory::format_tag::nChw16c,
                        EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_linear, memory::format_tag::ncdhw,
                        EXPAND_SIZES_3D(
                                1, 1, 5, 10, 15, 10, 5, 7, 2.f, 0.5f, 0.5f)}));
434 | |
// Nearest-neighbour resampling on plain layouts (ncw, nchw, nhwc, ndhwc).
// These cases are skipped on the CUDA backend (see SetUp()).
INSTANTIATE_TEST_SUITE_P(TestResampleForwardPlainNN, resampling_test_float,
        ::testing::Values(
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(10, 16, 5, 10, 2.f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(13, 10, 7, 13, 1.99f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest, memory::format_tag::ncw,
                        EXPAND_SIZES_1D(10, 16, 7, 13, 1.9f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest, memory::format_tag::nchw,
                        EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest, memory::format_tag::nhwc,
                        EXPAND_SIZES_2D(64, 32, 5, 5, 2, 3, 0.5f, 0.6f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest,
                        memory::format_tag::ndhwc,
                        EXPAND_SIZES_3D(
                                5, 5, 5, 10, 15, 10, 5, 7, 2.f, 0.5f, 0.5f)}));
457 | |
// Nearest-neighbour resampling on channel-blocked layouts. The nChw16c case
// (c = 10) and the nCdhw16c case (c = 5) have channel counts that are not
// multiples of the block size, which exercises padded channels.
INSTANTIATE_TEST_SUITE_P(TestResampleForwardBlockedNN, resampling_test_float,
        ::testing::Values(
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest,
                        memory::format_tag::nChw8c,
                        EXPAND_SIZES_2D(32, 16, 14, 6, 28, 3, 2, 0.5f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest,
                        memory::format_tag::nChw16c,
                        EXPAND_SIZES_2D(32, 10, 14, 7, 29, 5, 2.1f, 0.72f)},
                resampling_test_params_t {prop_kind::forward,
                        algorithm::resampling_nearest,
                        memory::format_tag::nCdhw16c,
                        EXPAND_SIZES_3D(
                                5, 5, 5, 10, 15, 10, 5, 7, 2.f, 0.5f, 0.5f)}));
473 | } // namespace dnnl |
474 | |