1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
// Parameters of one reduction test case.
//
// Cases are written as brace-initialized aggregates, so the member order below
// is part of the contract with every case list in this file. Cases that are
// expected to succeed omit the trailing `expect_to_fail` / `expected_status`
// pair, which is then value-initialized (false / zero).
struct reduction_test_params_t {
    memory::format_tag src_format; // layout tag of the source tensor
    memory::format_tag dst_format; // layout tag of the destination tensor
    algorithm aalgorithm; // algorithm passed to the reduction primitive
    float p; // `p` argument of the primitive (used by the norm algorithms)
    float eps; // `eps` argument of the primitive
    memory::dims src_dims; // source tensor dimensions
    // Destination dimensions; in the cases below a dimension that is 1 where
    // the source dimension is larger marks that dimension as reduced.
    memory::dims dst_dims;
    bool expect_to_fail; // true when creation/execution must fail
    dnnl_status_t expected_status; // status checked when expect_to_fail is set
};
35
36template <typename src_data_t, typename dst_data_t = src_data_t>
37class reduction_test_t
38 : public ::testing::TestWithParam<reduction_test_params_t> {
39private:
40 reduction_test_params_t p;
41 memory::data_type src_dt, dst_dt;
42
43protected:
44 void SetUp() override {
45 src_dt = data_traits<src_data_t>::data_type;
46 dst_dt = data_traits<dst_data_t>::data_type;
47
48 p = ::testing::TestWithParam<reduction_test_params_t>::GetParam();
49
50 SKIP_IF(unsupported_data_type(src_dt),
51 "Engine does not support this data type.");
52 SKIP_IF(get_test_engine().get_kind() != engine::kind::cpu,
53 "Engine does not support this primitive.");
54 SKIP_IF_CUDA(p.aalgorithm != algorithm::reduction_max
55 && p.aalgorithm != algorithm::reduction_min
56 && p.aalgorithm != algorithm::reduction_sum
57 && p.aalgorithm != algorithm::reduction_mul
58 && p.aalgorithm != algorithm::reduction_mean
59 && p.aalgorithm != algorithm::reduction_norm_lp_max
60 && p.aalgorithm
61 != algorithm::reduction_norm_lp_power_p_max
62 && p.eps != 0.0f,
63 "Unsupported algorithm type for CUDA");
64
65 catch_expected_failures(
66 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
67 }
68
69 void Test() {
70 // reduction specific types and values
71 using pd_t = reduction::primitive_desc;
72 allows_attr_t allowed_attributes {false}; // doesn't support anything
73 allowed_attributes.po_sum = true;
74 allowed_attributes.po_eltwise = true;
75 allowed_attributes.po_binary = true;
76
77 auto eng = get_test_engine();
78 auto strm = make_stream(eng);
79
80 auto desc_src = memory::desc(p.src_dims, src_dt, p.src_format);
81 auto desc_dst = memory::desc(p.dst_dims, dst_dt, p.dst_format);
82
83 // default pd ctor
84 auto pd = pd_t();
85 // regular pd ctor
86 pd = pd_t(eng, p.aalgorithm, desc_src, desc_dst, p.p, p.eps);
87 // test all pd ctors
88 test_fwd_pd_constructors<pd_t>(pd, allowed_attributes, p.aalgorithm,
89 desc_src, desc_dst, p.p, p.eps);
90
91 EXPECT_ANY_THROW(reduction(pd, {}));
92 // default primitive ctor
93 auto prim = reduction();
94 // regular primitive ctor
95 prim = reduction(pd);
96
97 const auto src_desc = pd.src_desc();
98 const auto dst_desc = pd.dst_desc();
99
100 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
101 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);
102
103 ASSERT_EQ(pd.get_algorithm(), p.aalgorithm);
104 ASSERT_EQ(pd.get_p(), p.p);
105 ASSERT_EQ(pd.get_epsilon(), p.eps);
106
107 const auto test_engine = pd.get_engine();
108
109 auto mem_src = memory(src_desc, test_engine);
110 auto mem_dst = memory(dst_desc, test_engine);
111
112 fill_data<src_data_t>(
113 src_desc.get_size() / sizeof(src_data_t), mem_src);
114
115 prim.execute(strm, {{DNNL_ARG_SRC, mem_src}, {DNNL_ARG_DST, mem_dst}});
116 strm.wait();
117 }
118};
119
120using tag = memory::format_tag;
121
122static auto expected_failures = []() {
123 return ::testing::Values(
124 // The same src and dst dims
125 reduction_test_params_t {tag::nchw, tag::nchw,
126 algorithm::reduction_sum, 0.0f, 0.0f, {1, 1, 1, 4},
127 {1, 1, 1, 4}, true, dnnl_invalid_arguments},
128 // not supported alg_kind
129 reduction_test_params_t {tag::nchw, tag::nchw,
130 algorithm::eltwise_relu, 0.0f, 0.0f, {1, 1, 1, 4},
131 {1, 1, 1, 4}, true, dnnl_invalid_arguments},
132 // negative dim
133 reduction_test_params_t {tag::nchw, tag::nchw,
134 algorithm::reduction_sum, 0.0f, 0.0f, {-1, 1, 1, 4},
135 {-1, 1, 1, 1}, true, dnnl_invalid_arguments},
136 // not supported p
137 reduction_test_params_t {tag::nchw, tag::nchw,
138 algorithm::reduction_norm_lp_max, 0.5f, 0.0f, {1, 8, 4, 4},
139 {1, 8, 4, 4}, true, dnnl_invalid_arguments},
140 // invalid tag
141 reduction_test_params_t {tag::any, tag::nchw,
142 algorithm::reduction_sum, 0.0f, 0.0f, {1, 1, 1, 4},
143 {1, 1, 1, 1}, true, dnnl_invalid_arguments});
144};
145
146static auto zero_dim = []() {
147 return ::testing::Values(reduction_test_params_t {tag::nchw, tag::nchw,
148 algorithm::reduction_sum, 0.0f, 0.0f, {0, 1, 1, 4}, {0, 1, 1, 1}});
149};
150
151static auto simple_cases = []() {
152 return ::testing::Values(reduction_test_params_t {tag::nchw, tag::nchw,
153 algorithm::reduction_sum, 0.0f, 0.0f,
154 {1, 1, 1, 4}, {1, 1, 1, 1}},
155 reduction_test_params_t {tag::nchw, tag::nchw,
156 algorithm::reduction_max, 0.0f, 0.0f, {1, 1, 4, 4},
157 {1, 1, 1, 4}},
158 reduction_test_params_t {tag::nChw16c, tag::nChw16c,
159 algorithm::reduction_min, 0.0f, 0.0f, {4, 4, 4, 4},
160 {1, 4, 4, 4}},
161 reduction_test_params_t {tag::nChw16c, tag::nchw,
162 algorithm::reduction_sum, 0.0f, 0.0f, {4, 4, 4, 4},
163 {1, 4, 4, 1}},
164 reduction_test_params_t {tag::nChw16c, tag::any,
165 algorithm::reduction_min, 0.0f, 0.0f, {4, 4, 4, 4},
166 {1, 1, 1, 1}});
167};
168
169static auto f32_cases = []() {
170 return ::testing::Values(reduction_test_params_t {tag::nchw, tag::nchw,
171 algorithm::reduction_norm_lp_max, 1.0f,
172 0.0f, {1, 1, 1, 4}, {1, 1, 1, 1}},
173 reduction_test_params_t {tag::nchw, tag::nchw,
174 algorithm::reduction_norm_lp_power_p_max, 2.0f, 0.0f,
175 {1, 1, 1, 4}, {1, 1, 1, 1}},
176 reduction_test_params_t {tag::nchw, tag::nchw,
177 algorithm::reduction_mean, 0.0f, 0.0f, {1, 4, 4, 4},
178 {1, 1, 4, 4}});
179};
180
// Registers the reduction test for fixture `test` and instantiates the suites
// shared by every data type. The TEST_P body is empty because the fixture
// does its work from SetUp().
#define INST_TEST_CASE(test) \
    TEST_P(test, TestsReduction) {} \
    INSTANTIATE_TEST_SUITE_P(TestReductionEF, test, expected_failures()); \
    INSTANTIATE_TEST_SUITE_P(TestReductionZero, test, zero_dim()); \
    INSTANTIATE_TEST_SUITE_P(TestReductionSimple, test, simple_cases());

// Everything INST_TEST_CASE does, plus the f32-only norm/mean suite.
#define INST_TEST_CASE_F32(test) \
    INST_TEST_CASE(test) \
    INSTANTIATE_TEST_SUITE_P(TestReductionNorm, test, f32_cases());
193
194using reduction_test_f32 = reduction_test_t<float>;
195using reduction_test_bf16 = reduction_test_t<bfloat16_t>;
196using reduction_test_f16 = reduction_test_t<float16_t>;
197using reduction_test_s8 = reduction_test_t<int8_t>;
198using reduction_test_u8 = reduction_test_t<uint8_t>;
199
200INST_TEST_CASE_F32(reduction_test_f32)
201INST_TEST_CASE(reduction_test_bf16)
202INST_TEST_CASE(reduction_test_f16)
203INST_TEST_CASE(reduction_test_s8)
204INST_TEST_CASE(reduction_test_u8)
205
206} // namespace dnnl
207