1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24using tag = memory::format_tag;
25using dt = memory::data_type;
26
27struct prelu_test_params_t {
28 dt src_dt;
29 dt wei_dt;
30 dt dst_dt;
31 tag src_tag;
32 tag wei_tag;
33 tag dst_tag;
34 memory::dims src_dims;
35 memory::dims wei_dims;
36 bool expect_to_fail;
37 dnnl_status_t expected_status;
38};
39
40class prelu_test_t : public ::testing::TestWithParam<prelu_test_params_t> {
41private:
42 prelu_test_params_t p;
43 memory src, wei;
44 std::shared_ptr<prelu_forward::primitive_desc> pd_fwd_hint;
45
46protected:
47 void SetUp() override {
48 p = ::testing::TestWithParam<prelu_test_params_t>::GetParam();
49
50 SKIP_IF_CUDA(true, "Prelu primitive not supported by CUDA");
51
52 SKIP_IF(unsupported_data_type(p.src_dt, p.wei_dt, p.dst_dt),
53 "Engine does not support this data type.");
54
55 catch_expected_failures(
56 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
57 }
58
59 void Forward(prop_kind pk) {
60 // prelu specific types and values
61 using pd_t = prelu_forward::primitive_desc;
62
63 auto eng = get_test_engine();
64 auto strm = make_stream(eng);
65
66 auto aa = allows_attr_t {false};
67
68 auto src_md = memory::desc(p.src_dims, p.src_dt, p.src_tag);
69 auto wei_md = memory::desc(p.wei_dims, p.wei_dt, p.wei_tag);
70 auto dst_md = memory::desc(p.src_dims, p.dst_dt, p.dst_tag);
71
72 // default pd ctor
73 auto pd = pd_t();
74 // regular pd ctor
75 pd = pd_t(eng, pk, src_md, wei_md, dst_md);
76 // test all pd ctors
77 test_fwd_pd_constructors<pd_t>(pd, aa, pk, src_md, wei_md, dst_md);
78 pd_fwd_hint = std::make_shared<pd_t>(pd);
79
80 EXPECT_ANY_THROW(prelu_forward(pd, {}));
81 // default primitive ctor
82 auto prelu = prelu_forward();
83 // regular primitive ctor
84 prelu = prelu_forward(pd);
85
86 // check primitive kind is prelu
87 ASSERT_TRUE(prelu.get_kind() == primitive::kind::prelu);
88 // query for descs from pd
89 const auto src_desc = pd.src_desc();
90 const auto wei_desc = pd.weights_desc();
91 const auto dst_desc = pd.dst_desc();
92 // query for src_desc via exec arg
93 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
94 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
95 // query for weights_desc via exec arg
96 ASSERT_TRUE(
97 pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS) == wei_desc);
98 if (p.src_tag != tag::any) { ASSERT_TRUE(wei_md == wei_desc); }
99 // query for dst_desc via exec arg
100 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);
101 if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); }
102
103 // query primitive parameters
104 ASSERT_EQ(pd.get_prop_kind(), pk);
105
106 // check primitive returns zero_md for all rest md
107 ASSERT_TRUE(pd.diff_src_desc().is_zero());
108 ASSERT_TRUE(pd.diff_dst_desc().is_zero());
109 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
110
111 src = test::make_memory(src_desc, eng);
112 wei = test::make_memory(wei_desc, eng);
113 auto dst = test::make_memory(dst_desc, eng);
114
115 fill_data(p.src_dt, src, 1, 1);
116 fill_data(p.wei_dt, wei, 2, 2);
117 // test out-place mode
118 prelu.execute(strm,
119 {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei},
120 {DNNL_ARG_DST, dst}});
121 strm.wait();
122 }
123
124 void Backward() {
125 // prelu specific types and values
126 using pd_t = prelu_backward::primitive_desc;
127 using hint_pd_t = prelu_forward::primitive_desc;
128 allows_attr_t aa {false}; // doesn't support anything
129
130 auto eng = get_test_engine();
131 auto strm = make_stream(eng);
132
133 auto src_md = memory::desc(p.src_dims, p.src_dt, p.src_tag);
134 auto wei_md = memory::desc(p.wei_dims, p.wei_dt, p.wei_tag);
135 auto diff_src_md = memory::desc(p.src_dims, p.src_dt, p.src_tag);
136 auto diff_wei_md = memory::desc(p.wei_dims, p.wei_dt, p.wei_tag);
137 auto diff_dst_md = memory::desc(p.src_dims, p.dst_dt, p.dst_tag);
138
139 // default pd ctor
140 auto pd = pd_t();
141 // regular pd ctor
142 pd = pd_t(eng, src_md, wei_md, diff_src_md, diff_wei_md, diff_dst_md,
143 *pd_fwd_hint);
144 // test all pd ctors
145 test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa, src_md,
146 wei_md, diff_src_md, diff_wei_md, diff_dst_md);
147
148 EXPECT_ANY_THROW(prelu_backward(pd, {}));
149 // default primitive ctor
150 auto prelu = prelu_backward();
151 // regular primitive ctor
152 prelu = prelu_backward(pd);
153
154 // check primitive kind is prelu
155 ASSERT_TRUE(prelu.get_kind() == primitive::kind::prelu);
156
157 // query for descs from pd
158 const auto src_desc = pd.src_desc();
159 const auto wei_desc = pd.weights_desc();
160 const auto diff_src_desc = pd.diff_src_desc();
161 const auto diff_wei_desc = pd.diff_weights_desc();
162 const auto diff_dst_desc = pd.diff_dst_desc();
163 // query for src_desc via exec arg
164 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
165 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
166 // query for weights_desc via exec arg
167 ASSERT_TRUE(
168 pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS) == wei_desc);
169 if (p.src_tag != tag::any) { ASSERT_TRUE(wei_md == wei_desc); }
170 // query for diff_src_desc via exec arg
171 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
172 == diff_src_desc);
173 if (p.src_tag != tag::any) {
174 ASSERT_TRUE(diff_src_md == diff_src_desc);
175 }
176 // query for diff_wei_desc via exec arg
177 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS)
178 == diff_wei_desc);
179 if (p.src_tag != tag::any) {
180 ASSERT_TRUE(diff_wei_md == diff_wei_desc);
181 }
182 // query for diff_dst_desc via exec arg
183 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
184 == diff_dst_desc);
185 if (p.dst_tag != tag::any) {
186 ASSERT_TRUE(diff_dst_md == diff_dst_desc);
187 }
188
189 // query primitive parameters
190 ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward);
191
192 // check primitive returns zero_md for all rest md
193 ASSERT_TRUE(pd.dst_desc().is_zero());
194
195 auto diff_src = test::make_memory(diff_src_desc, eng);
196 auto diff_wei = test::make_memory(diff_wei_desc, eng);
197 auto diff_dst = test::make_memory(diff_dst_desc, eng);
198
199 fill_data(p.dst_dt, diff_dst, 2, 2);
200
201 // test out-place mode
202 prelu.execute(strm,
203 {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei},
204 {DNNL_ARG_DIFF_SRC, diff_src},
205 {DNNL_ARG_DIFF_WEIGHTS, diff_wei},
206 {DNNL_ARG_DIFF_DST, diff_dst}});
207 strm.wait();
208 }
209
210 void Test() {
211 const bool is_int8 = p.src_dt == dt::s8 || p.src_dt == dt::u8;
212 std::vector<prop_kind> pks = {is_int8 ? prop_kind::forward_inference
213 : prop_kind::forward_training};
214
215 for (auto pk : pks) {
216 Forward(pk);
217
218 bool to_continue = pk != prop_kind::forward_training;
219 if (to_continue) continue;
220
221 Backward();
222 }
223 }
224
225 bool is_fwd(prop_kind pk) const {
226 return pk == prop_kind::forward_training
227 || pk == prop_kind::forward_inference;
228 }
229};
230
231using tp = prelu_test_params_t;
232
233TEST_P(prelu_test_t, TestsPrelu) {}
234
235INSTANTIATE_TEST_SUITE_P(Test_Prelu_EF, prelu_test_t,
236 ::testing::Values(
237 // Negative dims
238 tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nchw, tag::nchw,
239 {2, -4, 128, 256}, {2, 4, 128, 256}, true,
240 dnnl_invalid_arguments},
241 // Negative dims
242 tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nchw, tag::nchw,
243 {2, 4, 128, 256}, {2, 4, -128, 256}, true,
244 dnnl_invalid_arguments},
245 // Incompatible dims
246 tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nchw, tag::nchw,
247 {2, 4, 128, 256}, {2, 4, 2, 2}, true,
248 dnnl_invalid_arguments},
249 // Tag for src on forward is not specified
250 tp {dt::f32, dt::f32, dt::f32, tag::any, tag::nchw, tag::nchw,
251 {2, 4, 128, 256}, {2, 4, 128, 256}, true,
252 dnnl_invalid_arguments},
253 // Data type for src is not specified
254 tp {dt::undef, dt::f32, dt::f32, tag::nchw, tag::nchw,
255 tag::nchw, {2, 4, 128, 256}, {2, 4, 128, 256}, true,
256 dnnl_invalid_arguments},
257 // Different data types are not supported
258 tp {dt::f32, dt::f32, dt::bf16, tag::nchw, tag::nchw, tag::nchw,
259 {2, 4, 128, 256}, {2, 4, 128, 256}, true,
260 dnnl_unimplemented},
261 // Different memory formats are not supported
262 tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nchw, tag::nhwc,
263 {2, 4, 128, 256}, {2, 4, 128, 256}, true,
264 dnnl_unimplemented}));
265
266static auto all_cases = [](dt src_dt, dt wei_dt, dt dst_dt) {
267 return ::testing::Values(tp {src_dt, wei_dt, dst_dt, tag::nwc, tag::nwc,
268 tag::nwc, {2, 16, 10}, {2, 16, 10}},
269 tp {src_dt, wei_dt, dst_dt, tag::ncw, tag::ncw, tag::ncw,
270 {2, 64, 27}, {1, 1, 1}},
271 tp {src_dt, wei_dt, dst_dt, tag::nhwc, tag::nhwc, tag::nhwc,
272 {2, 15, 10, 8}, {2, 15, 10, 8}},
273 tp {src_dt, wei_dt, dst_dt, tag::nchw, tag::nchw, tag::nchw,
274 {2, 64, 27, 27}, {1, 64, 1, 1}},
275 tp {src_dt, wei_dt, dst_dt, tag::nChw8c, tag::nChw8c, tag::nChw8c,
276 {2, 16, 16, 8}, {1, 16, 1, 1}},
277 tp {src_dt, wei_dt, dst_dt, tag::nChw16c, tag::nChw16c,
278 tag::nChw16c, {2, 16, 4, 4}, {1, 1, 1, 1}},
279 tp {src_dt, wei_dt, dst_dt, tag::ncdhw, tag::ncdhw, tag::ncdhw,
280 {2, 64, 7, 7, 7}, {1, 1, 1, 1, 1}},
281 tp {src_dt, wei_dt, dst_dt, tag::ndhwc, tag::ndhwc, tag::ndhwc,
282 {10, 10, 10, 10, 10}, {10, 10, 10, 10, 10}},
283 tp {src_dt, wei_dt, dst_dt, tag::nCdhw16c, tag::nCdhw16c,
284 tag::nCdhw16c, {4, 16, 2, 2, 2}, {1, 16, 1, 1, 1}});
285};
286
// Expands short enum names into the three data type arguments
// (src, weights, dst) expected by all_cases.
#define EXPAND_DTS(s, w, d) dt::s, dt::w, dt::d

// Instantiates prelu_test_t unconditionally with the cases produced by
// `gen(...)`.
#define INST_TEST_CASE(test_name, gen, ...) \
    INSTANTIATE_TEST_SUITE_P(test_name, prelu_test_t, gen(__VA_ARGS__));

// Same as INST_TEST_CASE, but via CPU_INSTANTIATE_TEST_SUITE_P
// (presumably restricted to CPU engines).
#define CPU_INST_TEST_CASE(test_name, gen, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(test_name, prelu_test_t, gen(__VA_ARGS__));

// Same as INST_TEST_CASE, but via GPU_INSTANTIATE_TEST_SUITE_P
// (presumably restricted to GPU engines).
#define GPU_INST_TEST_CASE(test_name, gen, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P(test_name, prelu_test_t, gen(__VA_ARGS__));
298
299INST_TEST_CASE(PreluSimpleF32, all_cases, EXPAND_DTS(f32, f32, f32));
300INST_TEST_CASE(PreluSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16, bf16));
301INST_TEST_CASE(PreluSimpleBF16F32, all_cases, EXPAND_DTS(bf16, f32, bf16));
302INST_TEST_CASE(PreluSimpleF16, all_cases, EXPAND_DTS(f16, f16, f16));
303INST_TEST_CASE(PreluSimpleU8, all_cases, EXPAND_DTS(u8, u8, u8));
304INST_TEST_CASE(PreluSimpleS8, all_cases, EXPAND_DTS(s8, s8, s8));
305
306} // namespace dnnl
307