1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
// Short aliases for the oneDNN memory enums used by the test parameters
// and the test-case tables below.
using tag = memory::format_tag;
using dt = memory::data_type;
26 | |
27 | struct prelu_test_params_t { |
28 | dt src_dt; |
29 | dt wei_dt; |
30 | dt dst_dt; |
31 | tag src_tag; |
32 | tag wei_tag; |
33 | tag dst_tag; |
34 | memory::dims src_dims; |
35 | memory::dims wei_dims; |
36 | bool expect_to_fail; |
37 | dnnl_status_t expected_status; |
38 | }; |
39 | |
40 | class prelu_test_t : public ::testing::TestWithParam<prelu_test_params_t> { |
41 | private: |
42 | prelu_test_params_t p; |
43 | memory src, wei; |
44 | std::shared_ptr<prelu_forward::primitive_desc> pd_fwd_hint; |
45 | |
46 | protected: |
47 | void SetUp() override { |
48 | p = ::testing::TestWithParam<prelu_test_params_t>::GetParam(); |
49 | |
50 | SKIP_IF_CUDA(true, "Prelu primitive not supported by CUDA" ); |
51 | |
52 | SKIP_IF(unsupported_data_type(p.src_dt, p.wei_dt, p.dst_dt), |
53 | "Engine does not support this data type." ); |
54 | |
55 | catch_expected_failures( |
56 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
57 | } |
58 | |
59 | void Forward(prop_kind pk) { |
60 | // prelu specific types and values |
61 | using pd_t = prelu_forward::primitive_desc; |
62 | |
63 | auto eng = get_test_engine(); |
64 | auto strm = make_stream(eng); |
65 | |
66 | auto aa = allows_attr_t {false}; |
67 | |
68 | auto src_md = memory::desc(p.src_dims, p.src_dt, p.src_tag); |
69 | auto wei_md = memory::desc(p.wei_dims, p.wei_dt, p.wei_tag); |
70 | auto dst_md = memory::desc(p.src_dims, p.dst_dt, p.dst_tag); |
71 | |
72 | // default pd ctor |
73 | auto pd = pd_t(); |
74 | // regular pd ctor |
75 | pd = pd_t(eng, pk, src_md, wei_md, dst_md); |
76 | // test all pd ctors |
77 | test_fwd_pd_constructors<pd_t>(pd, aa, pk, src_md, wei_md, dst_md); |
78 | pd_fwd_hint = std::make_shared<pd_t>(pd); |
79 | |
80 | EXPECT_ANY_THROW(prelu_forward(pd, {})); |
81 | // default primitive ctor |
82 | auto prelu = prelu_forward(); |
83 | // regular primitive ctor |
84 | prelu = prelu_forward(pd); |
85 | |
86 | // check primitive kind is prelu |
87 | ASSERT_TRUE(prelu.get_kind() == primitive::kind::prelu); |
88 | // query for descs from pd |
89 | const auto src_desc = pd.src_desc(); |
90 | const auto wei_desc = pd.weights_desc(); |
91 | const auto dst_desc = pd.dst_desc(); |
92 | // query for src_desc via exec arg |
93 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
94 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
95 | // query for weights_desc via exec arg |
96 | ASSERT_TRUE( |
97 | pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS) == wei_desc); |
98 | if (p.src_tag != tag::any) { ASSERT_TRUE(wei_md == wei_desc); } |
99 | // query for dst_desc via exec arg |
100 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
101 | if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); } |
102 | |
103 | // query primitive parameters |
104 | ASSERT_EQ(pd.get_prop_kind(), pk); |
105 | |
106 | // check primitive returns zero_md for all rest md |
107 | ASSERT_TRUE(pd.diff_src_desc().is_zero()); |
108 | ASSERT_TRUE(pd.diff_dst_desc().is_zero()); |
109 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
110 | |
111 | src = test::make_memory(src_desc, eng); |
112 | wei = test::make_memory(wei_desc, eng); |
113 | auto dst = test::make_memory(dst_desc, eng); |
114 | |
115 | fill_data(p.src_dt, src, 1, 1); |
116 | fill_data(p.wei_dt, wei, 2, 2); |
117 | // test out-place mode |
118 | prelu.execute(strm, |
119 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei}, |
120 | {DNNL_ARG_DST, dst}}); |
121 | strm.wait(); |
122 | } |
123 | |
124 | void Backward() { |
125 | // prelu specific types and values |
126 | using pd_t = prelu_backward::primitive_desc; |
127 | using hint_pd_t = prelu_forward::primitive_desc; |
128 | allows_attr_t aa {false}; // doesn't support anything |
129 | |
130 | auto eng = get_test_engine(); |
131 | auto strm = make_stream(eng); |
132 | |
133 | auto src_md = memory::desc(p.src_dims, p.src_dt, p.src_tag); |
134 | auto wei_md = memory::desc(p.wei_dims, p.wei_dt, p.wei_tag); |
135 | auto diff_src_md = memory::desc(p.src_dims, p.src_dt, p.src_tag); |
136 | auto diff_wei_md = memory::desc(p.wei_dims, p.wei_dt, p.wei_tag); |
137 | auto diff_dst_md = memory::desc(p.src_dims, p.dst_dt, p.dst_tag); |
138 | |
139 | // default pd ctor |
140 | auto pd = pd_t(); |
141 | // regular pd ctor |
142 | pd = pd_t(eng, src_md, wei_md, diff_src_md, diff_wei_md, diff_dst_md, |
143 | *pd_fwd_hint); |
144 | // test all pd ctors |
145 | test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa, src_md, |
146 | wei_md, diff_src_md, diff_wei_md, diff_dst_md); |
147 | |
148 | EXPECT_ANY_THROW(prelu_backward(pd, {})); |
149 | // default primitive ctor |
150 | auto prelu = prelu_backward(); |
151 | // regular primitive ctor |
152 | prelu = prelu_backward(pd); |
153 | |
154 | // check primitive kind is prelu |
155 | ASSERT_TRUE(prelu.get_kind() == primitive::kind::prelu); |
156 | |
157 | // query for descs from pd |
158 | const auto src_desc = pd.src_desc(); |
159 | const auto wei_desc = pd.weights_desc(); |
160 | const auto diff_src_desc = pd.diff_src_desc(); |
161 | const auto diff_wei_desc = pd.diff_weights_desc(); |
162 | const auto diff_dst_desc = pd.diff_dst_desc(); |
163 | // query for src_desc via exec arg |
164 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
165 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
166 | // query for weights_desc via exec arg |
167 | ASSERT_TRUE( |
168 | pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS) == wei_desc); |
169 | if (p.src_tag != tag::any) { ASSERT_TRUE(wei_md == wei_desc); } |
170 | // query for diff_src_desc via exec arg |
171 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
172 | == diff_src_desc); |
173 | if (p.src_tag != tag::any) { |
174 | ASSERT_TRUE(diff_src_md == diff_src_desc); |
175 | } |
176 | // query for diff_wei_desc via exec arg |
177 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS) |
178 | == diff_wei_desc); |
179 | if (p.src_tag != tag::any) { |
180 | ASSERT_TRUE(diff_wei_md == diff_wei_desc); |
181 | } |
182 | // query for diff_dst_desc via exec arg |
183 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST) |
184 | == diff_dst_desc); |
185 | if (p.dst_tag != tag::any) { |
186 | ASSERT_TRUE(diff_dst_md == diff_dst_desc); |
187 | } |
188 | |
189 | // query primitive parameters |
190 | ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward); |
191 | |
192 | // check primitive returns zero_md for all rest md |
193 | ASSERT_TRUE(pd.dst_desc().is_zero()); |
194 | |
195 | auto diff_src = test::make_memory(diff_src_desc, eng); |
196 | auto diff_wei = test::make_memory(diff_wei_desc, eng); |
197 | auto diff_dst = test::make_memory(diff_dst_desc, eng); |
198 | |
199 | fill_data(p.dst_dt, diff_dst, 2, 2); |
200 | |
201 | // test out-place mode |
202 | prelu.execute(strm, |
203 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei}, |
204 | {DNNL_ARG_DIFF_SRC, diff_src}, |
205 | {DNNL_ARG_DIFF_WEIGHTS, diff_wei}, |
206 | {DNNL_ARG_DIFF_DST, diff_dst}}); |
207 | strm.wait(); |
208 | } |
209 | |
210 | void Test() { |
211 | const bool is_int8 = p.src_dt == dt::s8 || p.src_dt == dt::u8; |
212 | std::vector<prop_kind> pks = {is_int8 ? prop_kind::forward_inference |
213 | : prop_kind::forward_training}; |
214 | |
215 | for (auto pk : pks) { |
216 | Forward(pk); |
217 | |
218 | bool to_continue = pk != prop_kind::forward_training; |
219 | if (to_continue) continue; |
220 | |
221 | Backward(); |
222 | } |
223 | } |
224 | |
225 | bool is_fwd(prop_kind pk) const { |
226 | return pk == prop_kind::forward_training |
227 | || pk == prop_kind::forward_inference; |
228 | } |
229 | }; |
230 | |
// Short alias used by the test-case tables below.
using tp = prelu_test_params_t;

// The test body is intentionally empty: prelu_test_t::SetUp() runs the whole
// test (including the expected-failure handling) for every parameter set.
TEST_P(prelu_test_t, TestsPrelu) {}
234 | |
// Expected-failure cases: every entry sets expect_to_fail to true and names
// the status the fixture expects to observe.
INSTANTIATE_TEST_SUITE_P(Test_Prelu_EF, prelu_test_t,
        ::testing::Values(
                // Negative dim in src dims
                tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nchw, tag::nchw,
                        {2, -4, 128, 256}, {2, 4, 128, 256}, true,
                        dnnl_invalid_arguments},
                // Negative dim in weights dims
                tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nchw, tag::nchw,
                        {2, 4, 128, 256}, {2, 4, -128, 256}, true,
                        dnnl_invalid_arguments},
                // Weights dims incompatible with src dims
                tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nchw, tag::nchw,
                        {2, 4, 128, 256}, {2, 4, 2, 2}, true,
                        dnnl_invalid_arguments},
                // Tag for src on forward is not specified
                tp {dt::f32, dt::f32, dt::f32, tag::any, tag::nchw, tag::nchw,
                        {2, 4, 128, 256}, {2, 4, 128, 256}, true,
                        dnnl_invalid_arguments},
                // Data type for src is not specified
                tp {dt::undef, dt::f32, dt::f32, tag::nchw, tag::nchw,
                        tag::nchw, {2, 4, 128, 256}, {2, 4, 128, 256}, true,
                        dnnl_invalid_arguments},
                // Different src and dst data types are not supported
                tp {dt::f32, dt::f32, dt::bf16, tag::nchw, tag::nchw, tag::nchw,
                        {2, 4, 128, 256}, {2, 4, 128, 256}, true,
                        dnnl_unimplemented},
                // Different src and dst memory formats are not supported
                tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nchw, tag::nhwc,
                        {2, 4, 128, 256}, {2, 4, 128, 256}, true,
                        dnnl_unimplemented}));
265 | |
266 | static auto all_cases = [](dt src_dt, dt wei_dt, dt dst_dt) { |
267 | return ::testing::Values(tp {src_dt, wei_dt, dst_dt, tag::nwc, tag::nwc, |
268 | tag::nwc, {2, 16, 10}, {2, 16, 10}}, |
269 | tp {src_dt, wei_dt, dst_dt, tag::ncw, tag::ncw, tag::ncw, |
270 | {2, 64, 27}, {1, 1, 1}}, |
271 | tp {src_dt, wei_dt, dst_dt, tag::nhwc, tag::nhwc, tag::nhwc, |
272 | {2, 15, 10, 8}, {2, 15, 10, 8}}, |
273 | tp {src_dt, wei_dt, dst_dt, tag::nchw, tag::nchw, tag::nchw, |
274 | {2, 64, 27, 27}, {1, 64, 1, 1}}, |
275 | tp {src_dt, wei_dt, dst_dt, tag::nChw8c, tag::nChw8c, tag::nChw8c, |
276 | {2, 16, 16, 8}, {1, 16, 1, 1}}, |
277 | tp {src_dt, wei_dt, dst_dt, tag::nChw16c, tag::nChw16c, |
278 | tag::nChw16c, {2, 16, 4, 4}, {1, 1, 1, 1}}, |
279 | tp {src_dt, wei_dt, dst_dt, tag::ncdhw, tag::ncdhw, tag::ncdhw, |
280 | {2, 64, 7, 7, 7}, {1, 1, 1, 1, 1}}, |
281 | tp {src_dt, wei_dt, dst_dt, tag::ndhwc, tag::ndhwc, tag::ndhwc, |
282 | {10, 10, 10, 10, 10}, {10, 10, 10, 10, 10}}, |
283 | tp {src_dt, wei_dt, dst_dt, tag::nCdhw16c, tag::nCdhw16c, |
284 | tag::nCdhw16c, {4, 16, 2, 2, 2}, {1, 16, 1, 1, 1}}); |
285 | }; |
286 | |
// Expands three data type names (e.g. f32, bf16, u8) into the corresponding
// fully qualified memory::data_type enumerators, separated by commas.
#define EXPAND_DTS(src, wei, dst) \
    memory::data_type::src, memory::data_type::wei, memory::data_type::dst

// Instantiates prelu_test_t under `name` with the values returned by
// `suite(...)`. The expansion already ends with a semicolon.
#define INST_TEST_CASE(name, suite, ...) \
    INSTANTIATE_TEST_SUITE_P(name, prelu_test_t, suite(__VA_ARGS__));

// Same as INST_TEST_CASE but forwards to CPU_INSTANTIATE_TEST_SUITE_P, which
// is defined outside this file (presumably CPU-engine only -- confirm).
#define CPU_INST_TEST_CASE(name, suite, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(name, prelu_test_t, suite(__VA_ARGS__));

// Same as INST_TEST_CASE but forwards to GPU_INSTANTIATE_TEST_SUITE_P, which
// is defined outside this file (presumably GPU-engine only -- confirm).
#define GPU_INST_TEST_CASE(name, suite, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P(name, prelu_test_t, suite(__VA_ARGS__));
298 | |
// Valid-case suites, one per (src, weights, dst) data type combination.
// The u8 / s8 suites exercise forward inference only, see prelu_test_t::Test().
INST_TEST_CASE(PreluSimpleF32, all_cases, EXPAND_DTS(f32, f32, f32));
INST_TEST_CASE(PreluSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16, bf16));
INST_TEST_CASE(PreluSimpleBF16F32, all_cases, EXPAND_DTS(bf16, f32, bf16));
INST_TEST_CASE(PreluSimpleF16, all_cases, EXPAND_DTS(f16, f16, f16));
INST_TEST_CASE(PreluSimpleU8, all_cases, EXPAND_DTS(u8, u8, u8));
INST_TEST_CASE(PreluSimpleS8, all_cases, EXPAND_DTS(s8, s8, s8));
305 | |
306 | } // namespace dnnl |
307 | |