1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24using tag = memory::format_tag;
25
26/* iface tests */
27
28class iface_sum_test_t : public ::testing::Test {
29protected:
30 engine eng;
31 stream strm;
32
33 void SetUp() override {
34 eng = get_test_engine();
35 strm = make_stream(eng);
36 }
37};
38
39TEST_F(iface_sum_test_t, SumTestDstDataTypeCompliance) {
40 using dt = memory::data_type;
41
42 const dt src_dt = dt::s8;
43
44 memory::dims shape = {10, 10, 10, 10};
45 auto src_md = memory::desc(shape, src_dt, tag::abcd);
46
47 for_(tag dst_tag : {tag::any, tag::abcd, tag::acdb})
48 for (dt dst_dt : {dt::undef, dt::s8, dt::s32, dt::f32}) {
49 sum::primitive_desc sum_pd;
50 SKIP_FOR_LOOP_CUDA(dst_dt == dt::s32, "Unsupported data_type");
51 if (dst_dt != dt::undef) {
52 memory::desc dst_md(shape, dst_dt, dst_tag);
53 sum_pd = sum::primitive_desc(
54 eng, dst_md, {2., 2.}, {src_md, src_md});
55 } else {
56 sum_pd = sum::primitive_desc(eng, {2., 2.}, {src_md, src_md});
57 }
58
59 dt expect_dst_dt = dst_dt == dt::undef ? src_dt : dst_dt;
60 ASSERT_EQ(sum_pd.dst_desc().get_data_type(), expect_dst_dt);
61 }
62}
63
64/* correctness tests */
65
// Parameters of one sum correctness case. Instances are built with aggregate
// initialisation in the case tables below, so the field order is part of the
// contract; trailing fields left out of an initializer are value-initialised.
struct sum_test_params {
    // Format tag of each source; its size is the number of sources.
    std::vector<tag> srcs_format;
    // Format tag of dst; only used when is_output_omitted is false.
    tag dst_format;
    // Logical dims shared by every source and by dst. The reference check
    // indexes exactly four dims (n, c, h, w).
    memory::dims dims;
    // Per-source scale factor passed to the sum primitive descriptor.
    std::vector<float> scale;
    // When true the primitive descriptor is created without a dst descriptor.
    bool is_output_omitted;
    // Forwarded to catch_expected_failures(): whether the case must fail and
    // with which status.
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
75
76template <typename src_data_t, typename acc_t, typename dst_data_t = src_data_t>
77class sum_test_t : public ::testing::TestWithParam<sum_test_params> {
78private:
79 memory::data_type src_data_type;
80 memory::data_type dst_data_type;
81
82 void check_data(const std::vector<memory> &srcs,
83 const std::vector<float> &scale, const memory &dst) {
84 auto dst_data = map_memory<const dst_data_t>(dst);
85 const auto &dst_d = dst.get_desc();
86 const auto dst_dims = dst_d.get_dims();
87 const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get());
88
89 std::vector<mapped_ptr_t<const src_data_t>> mapped_srcs;
90 mapped_srcs.reserve(srcs.size());
91 for (auto &src : srcs)
92 mapped_srcs.emplace_back(map_memory<const src_data_t>(src));
93
94 dnnl::impl::parallel_nd(dst_dims[0], dst_dims[1], dst_dims[2],
95 dst_dims[3],
96 [&](memory::dim n, memory::dim c, memory::dim h,
97 memory::dim w) {
98 if (is_current_test_failed()) return;
99
100 acc_t src_sum = 0.0;
101 for (size_t num = 0; num < srcs.size(); num++) {
102 auto &src_data = mapped_srcs[num];
103 const auto &src_d = srcs[num].get_desc();
104 const auto src_dims = src_d.get_dims();
105 const dnnl::impl::memory_desc_wrapper src_mdw(
106 src_d.get());
107
108 auto src_idx = w + src_dims[3] * h
109 + src_dims[2] * src_dims[3] * c
110 + src_dims[1] * src_dims[2] * src_dims[3] * n;
111 if (num == 0) {
112 src_sum = acc_t(scale[num])
113 * src_data[src_mdw.off_l(src_idx, false)];
114 } else {
115 src_sum += acc_t(scale[num])
116 * src_data[src_mdw.off_l(src_idx, false)];
117 }
118
119 src_sum = (std::max)(
120 (std::min)(src_sum,
121 (std::numeric_limits<acc_t>::max)()),
122 std::numeric_limits<acc_t>::lowest());
123 }
124
125 auto dst_idx = w + dst_dims[3] * h
126 + dst_dims[2] * dst_dims[3] * c
127 + dst_dims[1] * dst_dims[2] * dst_dims[3] * n;
128
129 acc_t dst_val = dst_data[dst_mdw.off_l(dst_idx, false)];
130 ASSERT_EQ(src_sum, dst_val);
131 });
132 }
133
134protected:
135 bool cuda_supported_format_tag(memory::format_tag tag) {
136 return impl::utils::one_of(tag, dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd,
137 dnnl_abcde, dnnl_abcdef, dnnl_abdec, dnnl_acb, dnnl_acbde,
138 dnnl_acbdef, dnnl_acdb, dnnl_acdeb, dnnl_ba, dnnl_bac,
139 dnnl_bacd, dnnl_bca, dnnl_bcda, dnnl_bcdea, dnnl_cba, dnnl_cdba,
140 dnnl_cdeba, dnnl_decab, dnnl_defcab, dnnl_aBc4b, dnnl_aBcd4b,
141 dnnl_aBcde4b);
142 }
143 void SetUp() override {
144 src_data_type = data_traits<src_data_t>::data_type;
145 dst_data_type = data_traits<dst_data_t>::data_type;
146 sum_test_params p
147 = ::testing::TestWithParam<sum_test_params>::GetParam();
148 SKIP_IF(get_test_engine_kind() == engine::kind::gpu
149 && src_data_type == memory::data_type::bf16,
150 "GPU does not support bfloat16 data type.");
151 SKIP_IF(unsupported_data_type(src_data_type),
152 "Engine does not support this data type.");
153 SKIP_IF(unsupported_data_type(dst_data_type),
154 "Engine does not support this data type.");
155
156 SKIP_IF_CUDA(!cuda_supported_format_tag(p.dst_format),
157 "Unsupported format tag");
158 for (size_t i = 0; i < p.srcs_format.size(); i++) {
159 SKIP_IF_CUDA(!cuda_supported_format_tag(p.srcs_format[i]),
160 "Unsupported format tag");
161 }
162 catch_expected_failures(
163 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
164 }
165
166 void Test() {
167 sum_test_params p
168 = ::testing::TestWithParam<sum_test_params>::GetParam();
169
170 const auto num_srcs = p.srcs_format.size();
171
172 auto eng = get_test_engine();
173 auto strm = make_stream(eng);
174
175 std::vector<memory::desc> srcs_md;
176 std::vector<memory> srcs;
177
178 for (size_t i = 0; i < num_srcs; i++) {
179 auto desc = memory::desc(p.dims, src_data_type, p.srcs_format[i]);
180 srcs_md.push_back(desc);
181 }
182
183 memory dst;
184 sum::primitive_desc sum_pd;
185
186 if (p.is_output_omitted) {
187 ASSERT_NO_THROW(
188 sum_pd = sum::primitive_desc(eng, p.scale, srcs_md));
189 } else {
190 auto dst_desc = memory::desc(p.dims, dst_data_type, p.dst_format);
191 sum_pd = sum::primitive_desc(eng, dst_desc, p.scale, srcs_md);
192
193 ASSERT_EQ(sum_pd.dst_desc().get_ndims(), dst_desc.get_ndims());
194 }
195 dst = test::make_memory(sum_pd.dst_desc(), eng);
196 // test construction from a C pd
197 sum_pd = sum::primitive_desc(sum_pd.get());
198 for (size_t i = 0; i < num_srcs; i++) {
199 if (p.srcs_format[i] != memory::format_tag::any) {
200 ASSERT_TRUE(srcs_md[(int)i] == sum_pd.src_desc((int)i));
201 }
202 auto src_memory = test::make_memory(sum_pd.src_desc((int)i), eng);
203 const size_t sz
204 = src_memory.get_desc().get_size() / sizeof(src_data_t);
205 fill_data<src_data_t>(sz, src_memory);
206
207 // Keep few mantissa digits for fp types to avoid round-off errors
208 // With proper scalars the computations give exact results
209 if (!std::is_integral<src_data_t>::value) {
210 using uint_type = typename data_traits<src_data_t>::uint_type;
211 int mant_digits
212 = dnnl::impl::nstl::numeric_limits<src_data_t>::digits;
213 int want_mant_digits = 3;
214 auto src_ptr = map_memory<src_data_t>(src_memory);
215 for (size_t i = 0; i < sz; i++) {
216 uint_type mask = (uint_type)-1
217 << (mant_digits - want_mant_digits);
218 *((uint_type *)&src_ptr[i]) &= mask;
219 }
220 }
221 srcs.push_back(src_memory);
222 }
223
224 ASSERT_TRUE(sum_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
225 == sum_pd.dst_desc());
226 for (int i = 0; i < (int)srcs.size(); i++)
227 ASSERT_TRUE(sum_pd.query_md(
228 query::exec_arg_md, DNNL_ARG_MULTIPLE_SRC + i)
229 == sum_pd.src_desc(i));
230
231 {
232 auto dst_data = map_memory<dst_data_t>(dst);
233 const size_t sz = dst.get_desc().get_size() / sizeof(dst_data_t);
234 // overwriting dst to prevent false positives for test cases.
235 dnnl::impl::parallel_nd(
236 (ptrdiff_t)sz, [&](ptrdiff_t i) { dst_data[i] = -32; });
237 }
238 EXPECT_ANY_THROW(sum(sum_pd, {}));
239 sum c(sum_pd);
240 std::unordered_map<int, memory> args = {{DNNL_ARG_DST, dst}};
241 for (int i = 0; i < (int)num_srcs; i++) {
242 args.insert({DNNL_ARG_MULTIPLE_SRC + i, srcs[i]});
243 }
244 c.execute(strm, args);
245 strm.wait();
246
247 check_data(srcs, p.scale, dst);
248 }
249};
250
// Generic sum cases shared by every data type. `omit_output` is copied into
// sum_test_params::is_output_omitted of each case. The order of the entries
// determines the gtest parameter index of each instance.
static auto simple_test_cases = [](bool omit_output) {
    return ::testing::Values(
            // Degenerate dims: a zero-sized dimension is expected to succeed,
            // a negative one to fail with dnnl_invalid_arguments (the same
            // dims as corner_test_cases, here also run with omitted dst).
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {0, 7, 4, 4},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 0, 4, 4},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 8, 0, 4},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {-1, 8, 4, 4},
                    {1.0f, 1.0f}, omit_output, true, dnnl_invalid_arguments},

            // Unit scales over mixed plain/blocked layouts.
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw,
                    {1, 1024, 38, 50}, {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 8, 2, 2},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nChw8c,
                    {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nChw8c, {2, 16, 2, 2},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nchw,
                    {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output},
            // Non-unit scales {2, 3} over the same layout combinations and a
            // few more shapes.
            sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 8, 2, 2},
                    {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nChw8c,
                    {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nChw8c, {2, 16, 2, 2},
                    {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nchw,
                    {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {5, 8, 3, 3},
                    {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw,
                    {32, 32, 13, 14}, {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nChw16c, tag::nChw8c}, tag::nChw16c,
                    {2, 16, 3, 3}, {2.0f, 3.0f}, omit_output});
};
287
// Additional cases used only by the bf16 fixtures (see INST_TEST_CASE_BF16).
// They vary the number of sources (2 to 5), use nchw and nChw16c layouts, and
// include a channel count that is not a multiple of 16 (37).
static auto simple_test_cases_bf16 = [](bool omit_output) {
    return ::testing::Values(
            sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c,
                    {1, 16, 1, 1}, {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {1, 16, 1, 1},
                    {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 16, 13, 7},
                    {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw, tag::nchw, tag::nchw},
                    tag::nchw, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f},
                    omit_output},
            sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw,
                    {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output},
            sum_test_params {
                    {tag::nchw, tag::nchw, tag::nchw, tag::nchw, tag::nchw},
                    tag::nchw, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
                    omit_output},
            sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw,
                    {2, 37, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output},
            // NOTE(review): identical to the fifth case above.
            sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw,
                    {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output},
            sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c,
                    {2, 16, 13, 7}, {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nChw16c, tag::nChw16c, tag::nChw16c},
                    tag::nChw16c, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f},
                    omit_output},
            sum_test_params {{tag::nChw16c, tag::nChw16c, tag::nChw16c,
                                     tag::nChw16c, tag::nChw16c},
                    tag::nChw16c, {2, 16, 13, 7},
                    {2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, omit_output},
            // Fractional scales that are exact binary fractions.
            sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c,
                    {2, 128, 23, 15}, {2.5f, 0.125f}, omit_output});
};
321
322static auto special_test_cases = []() {
323 return ::testing::Values(
324 sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 8, 4, 4},
325 {1.0f}, false, true, dnnl_invalid_arguments},
326 sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {2, 8, 4, 4},
327 {0.1f}, false, true, dnnl_invalid_arguments},
328 sum_test_params {{tag::any, tag::nchw}, tag::nchw, {1, 16, 1, 1},
329 {2.0f, 3.0f}, false, true, dnnl_invalid_arguments},
330 sum_test_params {{tag::nchw, tag::any}, tag::nchw, {1, 16, 1, 1},
331 {2.0f, 3.0f}, false, true, dnnl_invalid_arguments});
332};
333
334/* corner cases */
335#define CASE_CC(itag0, itag1, otag, dims_, ef, st) \
336 sum_test_params { \
337 {tag::itag0, tag::itag1}, tag::otag, memory::dims dims_, {1.0f, 1.0f}, \
338 0, ef, st \
339 }
340static auto corner_test_cases = []() {
341 return ::testing::Values(
342 CASE_CC(nchw, nChw8c, nchw, ({0, 7, 4, 4}), false, dnnl_success),
343 CASE_CC(nchw, nChw8c, nchw, ({1, 0, 4, 4}), false, dnnl_success),
344 CASE_CC(nchw, nChw8c, nchw, ({1, 8, 0, 4}), false, dnnl_success),
345 CASE_CC(nchw, nChw8c, nchw, ({-1, 8, 4, 4}), true,
346 dnnl_invalid_arguments));
347};
348#undef CASE_CC
349
// Helper macros that declare the TestsSum parameterized test for a fixture
// and register its parameter suites. The test bodies are empty on purpose:
// the work happens in sum_test_t::SetUp(), which calls Test().
// Comments must stay outside the #define lines: a `//` comment ending in a
// backslash would swallow the next continuation line.

// Uses the CPU_-prefixed gtest wrappers: TestSum runs the simple cases, and
// TestSumEF runs the expected-failure cases.
#define CPU_INST_TEST_CASE(test, omit_output) \
    CPU_TEST_P(test, TestsSum) {} \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());

// Same as CPU_INST_TEST_CASE, plus the bf16-specific TestSumBf16 suite.
// There is no GPU counterpart for the bf16 fixtures.
#define INST_TEST_CASE_BF16(test, omit_output) \
    CPU_TEST_P(test, TestsSum) {} \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSumBf16, test, simple_test_cases_bf16(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());

// GPU_-prefixed counterpart of CPU_INST_TEST_CASE.
#define GPU_INST_TEST_CASE(test, omit_output) \
    GPU_TEST_P(test, TestsSum) {} \
    GPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    GPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());

// Registers the fixture through both the CPU and the GPU wrappers.
#define INST_TEST_CASE(test, omit_output) \
    CPU_INST_TEST_CASE(test, omit_output) \
    GPU_INST_TEST_CASE(test, omit_output)
373
// Fixture aliases: sum_test_t<src_data_t, acc_t[, dst_data_t]>. The
// *_omit_output aliases name the same types as the plain ones below; the
// separate names give each instantiation its own gtest suite name.
using sum_test_float_omit_output = sum_test_t<float, float>;
using sum_test_u8_omit_output = sum_test_t<uint8_t, int32_t>;
using sum_test_s8_omit_output = sum_test_t<int8_t, int32_t>;
using sum_test_s32_omit_output = sum_test_t<int32_t, float>;
using sum_test_f16_omit_output = sum_test_t<float16_t, float>;
using sum_test_bf16bf16_omit_output = sum_test_t<bfloat16_t, float>;
// Currently unused: its instantiation is disabled below.
using sum_test_bf16f32_omit_output = sum_test_t<bfloat16_t, float, float>;

using sum_test_float = sum_test_t<float, float>;
using sum_test_u8 = sum_test_t<uint8_t, int32_t>;
using sum_test_s8 = sum_test_t<int8_t, int32_t>;
using sum_test_s32 = sum_test_t<int32_t, float>;
using sum_test_f16 = sum_test_t<float16_t, float>;
using sum_test_bf16bf16 = sum_test_t<bfloat16_t, float>;
// bf16 sources with an f32 destination.
using sum_test_bf16f32 = sum_test_t<bfloat16_t, float, float>;

// f32 corner cases, registered with the plain gtest macros.
using sum_cc_f32 = sum_test_t<float, float>;

TEST_P(sum_cc_f32, TestSumCornerCases) {}
INSTANTIATE_TEST_SUITE_P(TestSumCornerCases, sum_cc_f32, corner_test_cases());

// Instantiations where the primitive descriptor is created without a dst
// descriptor (is_output_omitted = 1).
INST_TEST_CASE(sum_test_float_omit_output, 1)
INST_TEST_CASE(sum_test_u8_omit_output, 1)
INST_TEST_CASE(sum_test_s8_omit_output, 1)
INST_TEST_CASE(sum_test_s32_omit_output, 1)
INST_TEST_CASE_BF16(sum_test_bf16bf16_omit_output, 1)
// When dst is omitted, the automatically created dst descriptor has the bf16
// data type, so a bf16 -> f32 variant is not valid here:
// INST_TEST_CASE(sum_test_bf16f32_omit_output, 1)
INST_TEST_CASE(sum_test_f16_omit_output, 1)

// Instantiations with an explicit dst descriptor (is_output_omitted = 0).
INST_TEST_CASE(sum_test_float, 0)
INST_TEST_CASE(sum_test_u8, 0)
INST_TEST_CASE(sum_test_s8, 0)
INST_TEST_CASE(sum_test_s32, 0)
INST_TEST_CASE_BF16(sum_test_bf16bf16, 0)
INST_TEST_CASE_BF16(sum_test_bf16f32, 0)
INST_TEST_CASE(sum_test_f16, 0)
411
// Drop every helper macro defined for the instantiations above. Previously
// only the CPU/GPU variants were undefined, and INST_TEST_CASE_BF16 and
// INST_TEST_CASE leaked past their point of use.
#undef CPU_INST_TEST_CASE
#undef GPU_INST_TEST_CASE
#undef INST_TEST_CASE_BF16
#undef INST_TEST_CASE
414} // namespace dnnl
415