1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
// One parameter set for the reduction API test.
// NOTE: this is used as a plain aggregate; the brace initializers further down
// rely on the field order below, and cases that omit the two trailing fields
// leave them value-initialized.
struct reduction_test_params_t {
    // Format tag used to build the source memory descriptor.
    memory::format_tag src_format;
    // Format tag used to build the destination memory descriptor.
    memory::format_tag dst_format;
    // Algorithm kind passed to the reduction primitive descriptor.
    algorithm aalgorithm;
    // `p` argument passed to the primitive descriptor and compared against
    // pd.get_p() afterwards.
    float p;
    // `eps` argument passed to the primitive descriptor and compared against
    // pd.get_epsilon() afterwards.
    float eps;
    // Dimensions of the source memory descriptor.
    memory::dims src_dims;
    // Dimensions of the destination memory descriptor.
    memory::dims dst_dims;
    // Forwarded to catch_expected_failures(): whether the test body is
    // expected to fail.
    bool expect_to_fail;
    // Forwarded to catch_expected_failures(): the status expected when
    // expect_to_fail is set.
    dnnl_status_t expected_status;
};
35 | |
36 | template <typename src_data_t, typename dst_data_t = src_data_t> |
37 | class reduction_test_t |
38 | : public ::testing::TestWithParam<reduction_test_params_t> { |
39 | private: |
40 | reduction_test_params_t p; |
41 | memory::data_type src_dt, dst_dt; |
42 | |
43 | protected: |
44 | void SetUp() override { |
45 | src_dt = data_traits<src_data_t>::data_type; |
46 | dst_dt = data_traits<dst_data_t>::data_type; |
47 | |
48 | p = ::testing::TestWithParam<reduction_test_params_t>::GetParam(); |
49 | |
50 | SKIP_IF(unsupported_data_type(src_dt), |
51 | "Engine does not support this data type." ); |
52 | SKIP_IF(get_test_engine().get_kind() != engine::kind::cpu, |
53 | "Engine does not support this primitive." ); |
54 | SKIP_IF_CUDA(p.aalgorithm != algorithm::reduction_max |
55 | && p.aalgorithm != algorithm::reduction_min |
56 | && p.aalgorithm != algorithm::reduction_sum |
57 | && p.aalgorithm != algorithm::reduction_mul |
58 | && p.aalgorithm != algorithm::reduction_mean |
59 | && p.aalgorithm != algorithm::reduction_norm_lp_max |
60 | && p.aalgorithm |
61 | != algorithm::reduction_norm_lp_power_p_max |
62 | && p.eps != 0.0f, |
63 | "Unsupported algorithm type for CUDA" ); |
64 | |
65 | catch_expected_failures( |
66 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
67 | } |
68 | |
69 | void Test() { |
70 | // reduction specific types and values |
71 | using pd_t = reduction::primitive_desc; |
72 | allows_attr_t allowed_attributes {false}; // doesn't support anything |
73 | allowed_attributes.po_sum = true; |
74 | allowed_attributes.po_eltwise = true; |
75 | allowed_attributes.po_binary = true; |
76 | |
77 | auto eng = get_test_engine(); |
78 | auto strm = make_stream(eng); |
79 | |
80 | auto desc_src = memory::desc(p.src_dims, src_dt, p.src_format); |
81 | auto desc_dst = memory::desc(p.dst_dims, dst_dt, p.dst_format); |
82 | |
83 | // default pd ctor |
84 | auto pd = pd_t(); |
85 | // regular pd ctor |
86 | pd = pd_t(eng, p.aalgorithm, desc_src, desc_dst, p.p, p.eps); |
87 | // test all pd ctors |
88 | test_fwd_pd_constructors<pd_t>(pd, allowed_attributes, p.aalgorithm, |
89 | desc_src, desc_dst, p.p, p.eps); |
90 | |
91 | EXPECT_ANY_THROW(reduction(pd, {})); |
92 | // default primitive ctor |
93 | auto prim = reduction(); |
94 | // regular primitive ctor |
95 | prim = reduction(pd); |
96 | |
97 | const auto src_desc = pd.src_desc(); |
98 | const auto dst_desc = pd.dst_desc(); |
99 | |
100 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
101 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
102 | |
103 | ASSERT_EQ(pd.get_algorithm(), p.aalgorithm); |
104 | ASSERT_EQ(pd.get_p(), p.p); |
105 | ASSERT_EQ(pd.get_epsilon(), p.eps); |
106 | |
107 | const auto test_engine = pd.get_engine(); |
108 | |
109 | auto mem_src = memory(src_desc, test_engine); |
110 | auto mem_dst = memory(dst_desc, test_engine); |
111 | |
112 | fill_data<src_data_t>( |
113 | src_desc.get_size() / sizeof(src_data_t), mem_src); |
114 | |
115 | prim.execute(strm, {{DNNL_ARG_SRC, mem_src}, {DNNL_ARG_DST, mem_dst}}); |
116 | strm.wait(); |
117 | } |
118 | }; |
119 | |
120 | using tag = memory::format_tag; |
121 | |
122 | static auto expected_failures = []() { |
123 | return ::testing::Values( |
124 | // The same src and dst dims |
125 | reduction_test_params_t {tag::nchw, tag::nchw, |
126 | algorithm::reduction_sum, 0.0f, 0.0f, {1, 1, 1, 4}, |
127 | {1, 1, 1, 4}, true, dnnl_invalid_arguments}, |
128 | // not supported alg_kind |
129 | reduction_test_params_t {tag::nchw, tag::nchw, |
130 | algorithm::eltwise_relu, 0.0f, 0.0f, {1, 1, 1, 4}, |
131 | {1, 1, 1, 4}, true, dnnl_invalid_arguments}, |
132 | // negative dim |
133 | reduction_test_params_t {tag::nchw, tag::nchw, |
134 | algorithm::reduction_sum, 0.0f, 0.0f, {-1, 1, 1, 4}, |
135 | {-1, 1, 1, 1}, true, dnnl_invalid_arguments}, |
136 | // not supported p |
137 | reduction_test_params_t {tag::nchw, tag::nchw, |
138 | algorithm::reduction_norm_lp_max, 0.5f, 0.0f, {1, 8, 4, 4}, |
139 | {1, 8, 4, 4}, true, dnnl_invalid_arguments}, |
140 | // invalid tag |
141 | reduction_test_params_t {tag::any, tag::nchw, |
142 | algorithm::reduction_sum, 0.0f, 0.0f, {1, 1, 1, 4}, |
143 | {1, 1, 1, 1}, true, dnnl_invalid_arguments}); |
144 | }; |
145 | |
146 | static auto zero_dim = []() { |
147 | return ::testing::Values(reduction_test_params_t {tag::nchw, tag::nchw, |
148 | algorithm::reduction_sum, 0.0f, 0.0f, {0, 1, 1, 4}, {0, 1, 1, 1}}); |
149 | }; |
150 | |
151 | static auto simple_cases = []() { |
152 | return ::testing::Values(reduction_test_params_t {tag::nchw, tag::nchw, |
153 | algorithm::reduction_sum, 0.0f, 0.0f, |
154 | {1, 1, 1, 4}, {1, 1, 1, 1}}, |
155 | reduction_test_params_t {tag::nchw, tag::nchw, |
156 | algorithm::reduction_max, 0.0f, 0.0f, {1, 1, 4, 4}, |
157 | {1, 1, 1, 4}}, |
158 | reduction_test_params_t {tag::nChw16c, tag::nChw16c, |
159 | algorithm::reduction_min, 0.0f, 0.0f, {4, 4, 4, 4}, |
160 | {1, 4, 4, 4}}, |
161 | reduction_test_params_t {tag::nChw16c, tag::nchw, |
162 | algorithm::reduction_sum, 0.0f, 0.0f, {4, 4, 4, 4}, |
163 | {1, 4, 4, 1}}, |
164 | reduction_test_params_t {tag::nChw16c, tag::any, |
165 | algorithm::reduction_min, 0.0f, 0.0f, {4, 4, 4, 4}, |
166 | {1, 1, 1, 1}}); |
167 | }; |
168 | |
169 | static auto f32_cases = []() { |
170 | return ::testing::Values(reduction_test_params_t {tag::nchw, tag::nchw, |
171 | algorithm::reduction_norm_lp_max, 1.0f, |
172 | 0.0f, {1, 1, 1, 4}, {1, 1, 1, 1}}, |
173 | reduction_test_params_t {tag::nchw, tag::nchw, |
174 | algorithm::reduction_norm_lp_power_p_max, 2.0f, 0.0f, |
175 | {1, 1, 1, 4}, {1, 1, 1, 1}}, |
176 | reduction_test_params_t {tag::nchw, tag::nchw, |
177 | algorithm::reduction_mean, 0.0f, 0.0f, {1, 4, 4, 4}, |
178 | {1, 1, 4, 4}}); |
179 | }; |
180 | |
// Declares the (empty-bodied) parameterized test for `fixture` and
// instantiates it with the expected-failure, zero-dim and simple parameter
// sets. The test body is empty because the fixture does its work in SetUp().
#define INST_TEST_CASE(fixture) \
    TEST_P(fixture, TestsReduction) {} \
    INSTANTIATE_TEST_SUITE_P(TestReductionEF, fixture, expected_failures()); \
    INSTANTIATE_TEST_SUITE_P(TestReductionZero, fixture, zero_dim()); \
    INSTANTIATE_TEST_SUITE_P(TestReductionSimple, fixture, simple_cases());
186 | |
// Same as INST_TEST_CASE, plus an instantiation with the f32-only parameter
// sets from f32_cases().
#define INST_TEST_CASE_F32(fixture) \
    INST_TEST_CASE(fixture) \
    INSTANTIATE_TEST_SUITE_P(TestReductionNorm, fixture, f32_cases());
193 | |
194 | using reduction_test_f32 = reduction_test_t<float>; |
195 | using reduction_test_bf16 = reduction_test_t<bfloat16_t>; |
196 | using reduction_test_f16 = reduction_test_t<float16_t>; |
197 | using reduction_test_s8 = reduction_test_t<int8_t>; |
198 | using reduction_test_u8 = reduction_test_t<uint8_t>; |
199 | |
200 | INST_TEST_CASE_F32(reduction_test_f32) |
201 | INST_TEST_CASE(reduction_test_bf16) |
202 | INST_TEST_CASE(reduction_test_f16) |
203 | INST_TEST_CASE(reduction_test_s8) |
204 | INST_TEST_CASE(reduction_test_u8) |
205 | |
206 | } // namespace dnnl |
207 | |