1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
24 | using tag = memory::format_tag; |
25 | |
26 | /* iface tests */ |
27 | |
28 | class iface_sum_test_t : public ::testing::Test { |
29 | protected: |
30 | engine eng; |
31 | stream strm; |
32 | |
33 | void SetUp() override { |
34 | eng = get_test_engine(); |
35 | strm = make_stream(eng); |
36 | } |
37 | }; |
38 | |
39 | TEST_F(iface_sum_test_t, SumTestDstDataTypeCompliance) { |
40 | using dt = memory::data_type; |
41 | |
42 | const dt src_dt = dt::s8; |
43 | |
44 | memory::dims shape = {10, 10, 10, 10}; |
45 | auto src_md = memory::desc(shape, src_dt, tag::abcd); |
46 | |
47 | for_(tag dst_tag : {tag::any, tag::abcd, tag::acdb}) |
48 | for (dt dst_dt : {dt::undef, dt::s8, dt::s32, dt::f32}) { |
49 | sum::primitive_desc sum_pd; |
50 | SKIP_FOR_LOOP_CUDA(dst_dt == dt::s32, "Unsupported data_type" ); |
51 | if (dst_dt != dt::undef) { |
52 | memory::desc dst_md(shape, dst_dt, dst_tag); |
53 | sum_pd = sum::primitive_desc( |
54 | eng, dst_md, {2., 2.}, {src_md, src_md}); |
55 | } else { |
56 | sum_pd = sum::primitive_desc(eng, {2., 2.}, {src_md, src_md}); |
57 | } |
58 | |
59 | dt expect_dst_dt = dst_dt == dt::undef ? src_dt : dst_dt; |
60 | ASSERT_EQ(sum_pd.dst_desc().get_data_type(), expect_dst_dt); |
61 | } |
62 | } |
63 | |
64 | /* correctness tests */ |
65 | |
// Parameters of one sum correctness test case (see sum_test_t).
struct sum_test_params {
    // Format tag of every source; its size is the number of sources.
    std::vector<tag> srcs_format;
    // Destination format tag, used to build the dst descriptor when
    // is_output_omitted is false (also checked for CUDA support in SetUp).
    tag dst_format;
    // Dims used for every source and for an explicit destination.
    memory::dims dims;
    // Scale factor of each source, passed to the primitive and to the
    // reference computation.
    std::vector<float> scale;
    // If true, the primitive descriptor is created without a dst descriptor.
    bool is_output_omitted;
    // Forwarded to catch_expected_failures() together with expected_status.
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
75 | |
76 | template <typename src_data_t, typename acc_t, typename dst_data_t = src_data_t> |
77 | class sum_test_t : public ::testing::TestWithParam<sum_test_params> { |
78 | private: |
79 | memory::data_type src_data_type; |
80 | memory::data_type dst_data_type; |
81 | |
82 | void check_data(const std::vector<memory> &srcs, |
83 | const std::vector<float> &scale, const memory &dst) { |
84 | auto dst_data = map_memory<const dst_data_t>(dst); |
85 | const auto &dst_d = dst.get_desc(); |
86 | const auto dst_dims = dst_d.get_dims(); |
87 | const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get()); |
88 | |
89 | std::vector<mapped_ptr_t<const src_data_t>> mapped_srcs; |
90 | mapped_srcs.reserve(srcs.size()); |
91 | for (auto &src : srcs) |
92 | mapped_srcs.emplace_back(map_memory<const src_data_t>(src)); |
93 | |
94 | dnnl::impl::parallel_nd(dst_dims[0], dst_dims[1], dst_dims[2], |
95 | dst_dims[3], |
96 | [&](memory::dim n, memory::dim c, memory::dim h, |
97 | memory::dim w) { |
98 | if (is_current_test_failed()) return; |
99 | |
100 | acc_t src_sum = 0.0; |
101 | for (size_t num = 0; num < srcs.size(); num++) { |
102 | auto &src_data = mapped_srcs[num]; |
103 | const auto &src_d = srcs[num].get_desc(); |
104 | const auto src_dims = src_d.get_dims(); |
105 | const dnnl::impl::memory_desc_wrapper src_mdw( |
106 | src_d.get()); |
107 | |
108 | auto src_idx = w + src_dims[3] * h |
109 | + src_dims[2] * src_dims[3] * c |
110 | + src_dims[1] * src_dims[2] * src_dims[3] * n; |
111 | if (num == 0) { |
112 | src_sum = acc_t(scale[num]) |
113 | * src_data[src_mdw.off_l(src_idx, false)]; |
114 | } else { |
115 | src_sum += acc_t(scale[num]) |
116 | * src_data[src_mdw.off_l(src_idx, false)]; |
117 | } |
118 | |
119 | src_sum = (std::max)( |
120 | (std::min)(src_sum, |
121 | (std::numeric_limits<acc_t>::max)()), |
122 | std::numeric_limits<acc_t>::lowest()); |
123 | } |
124 | |
125 | auto dst_idx = w + dst_dims[3] * h |
126 | + dst_dims[2] * dst_dims[3] * c |
127 | + dst_dims[1] * dst_dims[2] * dst_dims[3] * n; |
128 | |
129 | acc_t dst_val = dst_data[dst_mdw.off_l(dst_idx, false)]; |
130 | ASSERT_EQ(src_sum, dst_val); |
131 | }); |
132 | } |
133 | |
134 | protected: |
135 | bool cuda_supported_format_tag(memory::format_tag tag) { |
136 | return impl::utils::one_of(tag, dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd, |
137 | dnnl_abcde, dnnl_abcdef, dnnl_abdec, dnnl_acb, dnnl_acbde, |
138 | dnnl_acbdef, dnnl_acdb, dnnl_acdeb, dnnl_ba, dnnl_bac, |
139 | dnnl_bacd, dnnl_bca, dnnl_bcda, dnnl_bcdea, dnnl_cba, dnnl_cdba, |
140 | dnnl_cdeba, dnnl_decab, dnnl_defcab, dnnl_aBc4b, dnnl_aBcd4b, |
141 | dnnl_aBcde4b); |
142 | } |
143 | void SetUp() override { |
144 | src_data_type = data_traits<src_data_t>::data_type; |
145 | dst_data_type = data_traits<dst_data_t>::data_type; |
146 | sum_test_params p |
147 | = ::testing::TestWithParam<sum_test_params>::GetParam(); |
148 | SKIP_IF(get_test_engine_kind() == engine::kind::gpu |
149 | && src_data_type == memory::data_type::bf16, |
150 | "GPU does not support bfloat16 data type." ); |
151 | SKIP_IF(unsupported_data_type(src_data_type), |
152 | "Engine does not support this data type." ); |
153 | SKIP_IF(unsupported_data_type(dst_data_type), |
154 | "Engine does not support this data type." ); |
155 | |
156 | SKIP_IF_CUDA(!cuda_supported_format_tag(p.dst_format), |
157 | "Unsupported format tag" ); |
158 | for (size_t i = 0; i < p.srcs_format.size(); i++) { |
159 | SKIP_IF_CUDA(!cuda_supported_format_tag(p.srcs_format[i]), |
160 | "Unsupported format tag" ); |
161 | } |
162 | catch_expected_failures( |
163 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
164 | } |
165 | |
166 | void Test() { |
167 | sum_test_params p |
168 | = ::testing::TestWithParam<sum_test_params>::GetParam(); |
169 | |
170 | const auto num_srcs = p.srcs_format.size(); |
171 | |
172 | auto eng = get_test_engine(); |
173 | auto strm = make_stream(eng); |
174 | |
175 | std::vector<memory::desc> srcs_md; |
176 | std::vector<memory> srcs; |
177 | |
178 | for (size_t i = 0; i < num_srcs; i++) { |
179 | auto desc = memory::desc(p.dims, src_data_type, p.srcs_format[i]); |
180 | srcs_md.push_back(desc); |
181 | } |
182 | |
183 | memory dst; |
184 | sum::primitive_desc sum_pd; |
185 | |
186 | if (p.is_output_omitted) { |
187 | ASSERT_NO_THROW( |
188 | sum_pd = sum::primitive_desc(eng, p.scale, srcs_md)); |
189 | } else { |
190 | auto dst_desc = memory::desc(p.dims, dst_data_type, p.dst_format); |
191 | sum_pd = sum::primitive_desc(eng, dst_desc, p.scale, srcs_md); |
192 | |
193 | ASSERT_EQ(sum_pd.dst_desc().get_ndims(), dst_desc.get_ndims()); |
194 | } |
195 | dst = test::make_memory(sum_pd.dst_desc(), eng); |
196 | // test construction from a C pd |
197 | sum_pd = sum::primitive_desc(sum_pd.get()); |
198 | for (size_t i = 0; i < num_srcs; i++) { |
199 | if (p.srcs_format[i] != memory::format_tag::any) { |
200 | ASSERT_TRUE(srcs_md[(int)i] == sum_pd.src_desc((int)i)); |
201 | } |
202 | auto src_memory = test::make_memory(sum_pd.src_desc((int)i), eng); |
203 | const size_t sz |
204 | = src_memory.get_desc().get_size() / sizeof(src_data_t); |
205 | fill_data<src_data_t>(sz, src_memory); |
206 | |
207 | // Keep few mantissa digits for fp types to avoid round-off errors |
208 | // With proper scalars the computations give exact results |
209 | if (!std::is_integral<src_data_t>::value) { |
210 | using uint_type = typename data_traits<src_data_t>::uint_type; |
211 | int mant_digits |
212 | = dnnl::impl::nstl::numeric_limits<src_data_t>::digits; |
213 | int want_mant_digits = 3; |
214 | auto src_ptr = map_memory<src_data_t>(src_memory); |
215 | for (size_t i = 0; i < sz; i++) { |
216 | uint_type mask = (uint_type)-1 |
217 | << (mant_digits - want_mant_digits); |
218 | *((uint_type *)&src_ptr[i]) &= mask; |
219 | } |
220 | } |
221 | srcs.push_back(src_memory); |
222 | } |
223 | |
224 | ASSERT_TRUE(sum_pd.query_md(query::exec_arg_md, DNNL_ARG_DST) |
225 | == sum_pd.dst_desc()); |
226 | for (int i = 0; i < (int)srcs.size(); i++) |
227 | ASSERT_TRUE(sum_pd.query_md( |
228 | query::exec_arg_md, DNNL_ARG_MULTIPLE_SRC + i) |
229 | == sum_pd.src_desc(i)); |
230 | |
231 | { |
232 | auto dst_data = map_memory<dst_data_t>(dst); |
233 | const size_t sz = dst.get_desc().get_size() / sizeof(dst_data_t); |
234 | // overwriting dst to prevent false positives for test cases. |
235 | dnnl::impl::parallel_nd( |
236 | (ptrdiff_t)sz, [&](ptrdiff_t i) { dst_data[i] = -32; }); |
237 | } |
238 | EXPECT_ANY_THROW(sum(sum_pd, {})); |
239 | sum c(sum_pd); |
240 | std::unordered_map<int, memory> args = {{DNNL_ARG_DST, dst}}; |
241 | for (int i = 0; i < (int)num_srcs; i++) { |
242 | args.insert({DNNL_ARG_MULTIPLE_SRC + i, srcs[i]}); |
243 | } |
244 | c.execute(strm, args); |
245 | strm.wait(); |
246 | |
247 | check_data(srcs, p.scale, dst); |
248 | } |
249 | }; |
250 | |
251 | static auto simple_test_cases = [](bool omit_output) { |
252 | return ::testing::Values( |
253 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {0, 7, 4, 4}, |
254 | {1.0f, 1.0f}, omit_output}, |
255 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 0, 4, 4}, |
256 | {1.0f, 1.0f}, omit_output}, |
257 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 8, 0, 4}, |
258 | {1.0f, 1.0f}, omit_output}, |
259 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {-1, 8, 4, 4}, |
260 | {1.0f, 1.0f}, omit_output, true, dnnl_invalid_arguments}, |
261 | |
262 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, |
263 | {1, 1024, 38, 50}, {1.0f, 1.0f}, omit_output}, |
264 | sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 8, 2, 2}, |
265 | {1.0f, 1.0f}, omit_output}, |
266 | sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nChw8c, |
267 | {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output}, |
268 | sum_test_params {{tag::nchw, tag::nchw}, tag::nChw8c, {2, 16, 2, 2}, |
269 | {1.0f, 1.0f}, omit_output}, |
270 | sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nchw, |
271 | {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output}, |
272 | sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 8, 2, 2}, |
273 | {2.0f, 3.0f}, omit_output}, |
274 | sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nChw8c, |
275 | {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output}, |
276 | sum_test_params {{tag::nchw, tag::nchw}, tag::nChw8c, {2, 16, 2, 2}, |
277 | {2.0f, 3.0f}, omit_output}, |
278 | sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nchw, |
279 | {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output}, |
280 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {5, 8, 3, 3}, |
281 | {2.0f, 3.0f}, omit_output}, |
282 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, |
283 | {32, 32, 13, 14}, {2.0f, 3.0f}, omit_output}, |
284 | sum_test_params {{tag::nChw16c, tag::nChw8c}, tag::nChw16c, |
285 | {2, 16, 3, 3}, {2.0f, 3.0f}, omit_output}); |
286 | }; |
287 | |
288 | static auto simple_test_cases_bf16 = [](bool omit_output) { |
289 | return ::testing::Values( |
290 | sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c, |
291 | {1, 16, 1, 1}, {2.0f, 3.0f}, omit_output}, |
292 | sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {1, 16, 1, 1}, |
293 | {2.0f, 3.0f}, omit_output}, |
294 | sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 16, 13, 7}, |
295 | {2.0f, 3.0f}, omit_output}, |
296 | sum_test_params {{tag::nchw, tag::nchw, tag::nchw, tag::nchw}, |
297 | tag::nchw, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f}, |
298 | omit_output}, |
299 | sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw, |
300 | {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output}, |
301 | sum_test_params { |
302 | {tag::nchw, tag::nchw, tag::nchw, tag::nchw, tag::nchw}, |
303 | tag::nchw, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, |
304 | omit_output}, |
305 | sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw, |
306 | {2, 37, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output}, |
307 | sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw, |
308 | {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output}, |
309 | sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c, |
310 | {2, 16, 13, 7}, {2.0f, 3.0f}, omit_output}, |
311 | sum_test_params {{tag::nChw16c, tag::nChw16c, tag::nChw16c}, |
312 | tag::nChw16c, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, |
313 | omit_output}, |
314 | sum_test_params {{tag::nChw16c, tag::nChw16c, tag::nChw16c, |
315 | tag::nChw16c, tag::nChw16c}, |
316 | tag::nChw16c, {2, 16, 13, 7}, |
317 | {2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, omit_output}, |
318 | sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c, |
319 | {2, 128, 23, 15}, {2.5f, 0.125f}, omit_output}); |
320 | }; |
321 | |
322 | static auto special_test_cases = []() { |
323 | return ::testing::Values( |
324 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 8, 4, 4}, |
325 | {1.0f}, false, true, dnnl_invalid_arguments}, |
326 | sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {2, 8, 4, 4}, |
327 | {0.1f}, false, true, dnnl_invalid_arguments}, |
328 | sum_test_params {{tag::any, tag::nchw}, tag::nchw, {1, 16, 1, 1}, |
329 | {2.0f, 3.0f}, false, true, dnnl_invalid_arguments}, |
330 | sum_test_params {{tag::nchw, tag::any}, tag::nchw, {1, 16, 1, 1}, |
331 | {2.0f, 3.0f}, false, true, dnnl_invalid_arguments}); |
332 | }; |
333 | |
334 | /* corner cases */ |
335 | #define CASE_CC(itag0, itag1, otag, dims_, ef, st) \ |
336 | sum_test_params { \ |
337 | {tag::itag0, tag::itag1}, tag::otag, memory::dims dims_, {1.0f, 1.0f}, \ |
338 | 0, ef, st \ |
339 | } |
340 | static auto corner_test_cases = []() { |
341 | return ::testing::Values( |
342 | CASE_CC(nchw, nChw8c, nchw, ({0, 7, 4, 4}), false, dnnl_success), |
343 | CASE_CC(nchw, nChw8c, nchw, ({1, 0, 4, 4}), false, dnnl_success), |
344 | CASE_CC(nchw, nChw8c, nchw, ({1, 8, 0, 4}), false, dnnl_success), |
345 | CASE_CC(nchw, nChw8c, nchw, ({-1, 8, 4, 4}), true, |
346 | dnnl_invalid_arguments)); |
347 | }; |
348 | #undef CASE_CC |
349 | |
// INST_TEST_CASE(test, omit_output): registers fixture `test` on both CPU and
// GPU. Macros expand at the point of use, so defining this before its
// CPU_/GPU_ helpers is fine.
#define INST_TEST_CASE(test, omit_output) \
    CPU_INST_TEST_CASE(test, omit_output) \
    GPU_INST_TEST_CASE(test, omit_output)

// CPU registration: declares the (empty-bodied) TestsSum test, then
// instantiates it with the simple cases (TestSum, forwarding omit_output) and
// with the special expected-failure cases (TestSumEF).
#define CPU_INST_TEST_CASE(test, omit_output) \
    CPU_TEST_P(test, TestsSum) {} \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());

// GPU counterpart of CPU_INST_TEST_CASE.
#define GPU_INST_TEST_CASE(test, omit_output) \
    GPU_TEST_P(test, TestsSum) {} \
    GPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    GPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());

// bf16 registration: like CPU_INST_TEST_CASE plus the bf16-specific cases
// (TestSumBf16). Only CPU_* macros are used, so no GPU instantiation happens.
#define INST_TEST_CASE_BF16(test, omit_output) \
    CPU_TEST_P(test, TestsSum) {} \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSumBf16, test, simple_test_cases_bf16(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());
373 | |
374 | using sum_test_float_omit_output = sum_test_t<float, float>; |
375 | using sum_test_u8_omit_output = sum_test_t<uint8_t, int32_t>; |
376 | using sum_test_s8_omit_output = sum_test_t<int8_t, int32_t>; |
377 | using sum_test_s32_omit_output = sum_test_t<int32_t, float>; |
378 | using sum_test_f16_omit_output = sum_test_t<float16_t, float>; |
379 | using sum_test_bf16bf16_omit_output = sum_test_t<bfloat16_t, float>; |
380 | using sum_test_bf16f32_omit_output = sum_test_t<bfloat16_t, float, float>; |
381 | |
382 | using sum_test_float = sum_test_t<float, float>; |
383 | using sum_test_u8 = sum_test_t<uint8_t, int32_t>; |
384 | using sum_test_s8 = sum_test_t<int8_t, int32_t>; |
385 | using sum_test_s32 = sum_test_t<int32_t, float>; |
386 | using sum_test_f16 = sum_test_t<float16_t, float>; |
387 | using sum_test_bf16bf16 = sum_test_t<bfloat16_t, float>; |
388 | using sum_test_bf16f32 = sum_test_t<bfloat16_t, float, float>; |
389 | |
// f32 fixture for the corner cases. It is registered with the plain TEST_P /
// INSTANTIATE_TEST_SUITE_P macros rather than the CPU_/GPU_ variants. The test
// body is empty because SetUp() performs the whole check.
using sum_cc_f32 = sum_test_t<float, float>;

TEST_P(sum_cc_f32, TestSumCornerCases) {}
INSTANTIATE_TEST_SUITE_P(TestSumCornerCases, sum_cc_f32, corner_test_cases());
394 | |
// Cases where the primitive creates the destination descriptor itself.
INST_TEST_CASE(sum_test_float_omit_output, 1)
INST_TEST_CASE(sum_test_u8_omit_output, 1)
INST_TEST_CASE(sum_test_s8_omit_output, 1)
INST_TEST_CASE(sum_test_s32_omit_output, 1)
INST_TEST_CASE_BF16(sum_test_bf16bf16_omit_output, 1)
// Intentionally not registered: with the output omitted, the automatically
// created dst descriptor has the bf16 data type, which does not match this
// fixture's f32 destination type.
// INST_TEST_CASE(sum_test_bf16f32_omit_output, 1)
INST_TEST_CASE(sum_test_f16_omit_output, 1)

// Cases with an explicitly provided destination descriptor.
INST_TEST_CASE(sum_test_float, 0)
INST_TEST_CASE(sum_test_u8, 0)
INST_TEST_CASE(sum_test_s8, 0)
INST_TEST_CASE(sum_test_s32, 0)
INST_TEST_CASE_BF16(sum_test_bf16bf16, 0)
INST_TEST_CASE_BF16(sum_test_bf16f32, 0)
INST_TEST_CASE(sum_test_f16, 0)
411 | |
// Drop every registration helper defined in this file (previously
// INST_TEST_CASE and INST_TEST_CASE_BF16 leaked past this point).
#undef INST_TEST_CASE
#undef INST_TEST_CASE_BF16
#undef CPU_INST_TEST_CASE
#undef GPU_INST_TEST_CASE
414 | } // namespace dnnl |
415 | |