1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
// Short aliases for the memory format tag and data type enumerations that are
// used throughout the test parameters below.
using tag = memory::format_tag;
using dt = memory::data_type;
26 | |
27 | struct eltwise_test_params_t { |
28 | dt src_dt; |
29 | dt dst_dt; |
30 | dt diff_src_dt; |
31 | tag src_tag; |
32 | tag dst_tag; |
33 | tag diff_src_tag; |
34 | memory::dims dims; |
35 | float alpha, beta; |
36 | bool expect_to_fail; |
37 | dnnl_status_t expected_status; |
38 | }; |
39 | |
40 | bool cuda_check_format_tag(tag atag) { |
41 | // Blocking is not supported by cuDNN |
42 | return !impl::utils::one_of( |
43 | atag, tag::aBcd8b, tag::aBcd16b, tag::aBcde8b, tag::aBcde16b); |
44 | } |
45 | |
46 | template <typename... Rest> |
47 | bool cuda_check_format_tag(tag first_tag, Rest... rest_tags) { |
48 | const bool ok = cuda_check_format_tag(first_tag); |
49 | if (!ok) return ok; |
50 | return cuda_check_format_tag(rest_tags...); |
51 | } |
52 | |
53 | bool hip_check_format_tag(tag atag) { |
54 | // HIP has the same limitations for `tag` as CUDA. |
55 | return cuda_check_format_tag(atag); |
56 | } |
57 | |
58 | template <typename... Rest> |
59 | bool hip_check_format_tag(tag first_tag, Rest... rest_tags) { |
60 | const bool ok = hip_check_format_tag(first_tag); |
61 | if (!ok) return ok; |
62 | return hip_check_format_tag(rest_tags...); |
63 | } |
64 | |
65 | class eltwise_test_t : public ::testing::TestWithParam<eltwise_test_params_t> { |
66 | private: |
67 | eltwise_test_params_t p; |
68 | memory src, dst; |
69 | std::shared_ptr<eltwise_forward::primitive_desc> pd_fwd_hint; |
70 | |
71 | protected: |
72 | void SetUp() override { |
73 | p = ::testing::TestWithParam<eltwise_test_params_t>::GetParam(); |
74 | |
75 | SKIP_IF(unsupported_data_type(p.src_dt, p.dst_dt), |
76 | "Engine does not support this data type." ); |
77 | |
78 | SKIP_IF_CUDA( |
79 | p.dst_dt == dt::s8, "Unsupported int8 destination data type" ); |
80 | SKIP_IF_HIP(p.src_dt == dt::s8, "Unsupported int8 source data type" ); |
81 | |
82 | SKIP_IF_CUDA(!cuda_check_format_tag(p.src_tag, p.dst_tag), |
83 | "Unsupported format tag" ); |
84 | SKIP_IF_HIP(!hip_check_format_tag(p.src_tag, p.dst_tag), |
85 | "Unsupported format tag" ); |
86 | SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef |
87 | && p.dst_dt != dt::undef, |
88 | "Unsupported different data types for source and " |
89 | "destination" ); |
90 | SKIP_IF_HIP(p.src_dt != p.dst_dt && p.src_dt != dt::undef |
91 | && p.dst_dt != dt::undef, |
92 | "Unsupported different data types for source and " |
93 | "destination" ); |
94 | SKIP_IF_CUDA(p.src_tag != p.dst_tag && p.src_tag != tag::any |
95 | && p.dst_tag != tag::any, |
96 | "Unsupported different memory formats for source and " |
97 | "destination" ); |
98 | SKIP_IF_HIP(p.src_tag != p.dst_tag && p.src_tag != tag::any |
99 | && p.dst_tag != tag::any, |
100 | "Unsupported different memory formats for source and " |
101 | "destination" ); |
102 | |
103 | catch_expected_failures( |
104 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
105 | } |
106 | |
107 | void Forward(prop_kind pk, algorithm aalgorithm) { |
108 | // eltwise specific types and values |
109 | using pd_t = eltwise_forward::primitive_desc; |
110 | |
111 | auto eng = get_test_engine(); |
112 | auto strm = make_stream(eng); |
113 | |
114 | auto aa = allows_attr_t {false}; |
115 | aa.po_sum = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); |
116 | aa.po_eltwise = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); |
117 | aa.po_binary = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); |
118 | |
119 | auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
120 | auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
121 | |
122 | // default pd ctor |
123 | auto pd = pd_t(); |
124 | // regular pd ctor |
125 | pd = pd_t(eng, pk, aalgorithm, src_md, dst_md, p.alpha, p.beta); |
126 | // test all pd ctors |
127 | test_fwd_pd_constructors<pd_t>( |
128 | pd, aa, pk, aalgorithm, src_md, dst_md, p.alpha, p.beta); |
129 | pd_fwd_hint = std::make_shared<pd_t>(pd); |
130 | |
131 | EXPECT_ANY_THROW(eltwise_forward(pd, {})); |
132 | // default primitive ctor |
133 | auto eltwise = eltwise_forward(); |
134 | // regular primitive ctor |
135 | eltwise = eltwise_forward(pd); |
136 | |
137 | // check primitive kind is eltwise |
138 | ASSERT_TRUE(eltwise.get_kind() == primitive::kind::eltwise); |
139 | // query for descs from pd |
140 | const auto src_desc = pd.src_desc(); |
141 | const auto dst_desc = pd.dst_desc(); |
142 | // query for src_desc via exec arg |
143 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
144 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
145 | // query for dst_desc via exec arg |
146 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
147 | if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); } |
148 | |
149 | // query primitive parameters |
150 | ASSERT_EQ(pd.get_prop_kind(), pk); |
151 | ASSERT_EQ(pd.get_algorithm(), aalgorithm); |
152 | ASSERT_EQ(pd.get_alpha(), p.alpha); |
153 | ASSERT_EQ(pd.get_beta(), p.beta); |
154 | |
155 | // check primitive returns zero_md for all rest md |
156 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
157 | ASSERT_TRUE(pd.diff_src_desc().is_zero()); |
158 | ASSERT_TRUE(pd.diff_dst_desc().is_zero()); |
159 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
160 | |
161 | src = test::make_memory(src_desc, eng); |
162 | dst = test::make_memory(dst_desc, eng); |
163 | |
164 | fill_data(p.src_dt, src, 1, 1); |
165 | // test out-place mode |
166 | eltwise.execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); |
167 | strm.wait(); |
168 | |
169 | // test in-place mode on forward |
170 | if (p.src_tag == p.dst_tag && p.src_dt == p.dst_dt) { |
171 | // TODO: add a copy of memory and result comparison with previous |
172 | // dst output. |
173 | eltwise.execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, src}}); |
174 | strm.wait(); |
175 | } |
176 | } |
177 | |
178 | void Backward(algorithm aalgorithm) { |
179 | // eltwise specific types and values |
180 | using pd_t = eltwise_backward::primitive_desc; |
181 | using hint_pd_t = eltwise_forward::primitive_desc; |
182 | allows_attr_t aa {false}; // doesn't support anything |
183 | |
184 | auto eng = get_test_engine(); |
185 | auto strm = make_stream(eng); |
186 | auto diff_src_md = memory::desc(p.dims, p.diff_src_dt, p.diff_src_tag); |
187 | auto diff_dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
188 | auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
189 | |
190 | // default pd ctor |
191 | auto pd = pd_t(); |
192 | // regular pd ctor |
193 | pd = pd_t(eng, aalgorithm, diff_src_md, diff_dst_md, src_md, p.alpha, |
194 | p.beta, *pd_fwd_hint); |
195 | // test all pd ctors |
196 | test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa, |
197 | aalgorithm, diff_src_md, diff_dst_md, src_md, p.alpha, p.beta); |
198 | |
199 | EXPECT_ANY_THROW(eltwise_backward(pd, {})); |
200 | // default primitive ctor |
201 | auto eltwise = eltwise_backward(); |
202 | // regular primitive ctor |
203 | eltwise = eltwise_backward(pd); |
204 | |
205 | // check primitive kind is eltwise |
206 | ASSERT_TRUE(eltwise.get_kind() == primitive::kind::eltwise); |
207 | |
208 | // query for descs from pd |
209 | const auto diff_src_desc = pd.diff_src_desc(); |
210 | const auto diff_dst_desc = pd.diff_dst_desc(); |
211 | const auto src_desc = pd.src_desc(); |
212 | const auto dst_desc = pd.dst_desc(); |
213 | // query for diff_src_desc via exec arg |
214 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
215 | == diff_src_desc); |
216 | if (p.diff_src_tag != tag::any) { |
217 | ASSERT_TRUE(diff_src_md == diff_src_desc); |
218 | } |
219 | // query for diff_dst_desc via exec arg |
220 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST) |
221 | == diff_dst_desc); |
222 | if (p.dst_tag != tag::any) { |
223 | ASSERT_TRUE(diff_dst_md == diff_dst_desc); |
224 | } |
225 | // query for src_desc via exec arg |
226 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
227 | ASSERT_TRUE((p.src_tag != tag::any && src_md == src_desc) |
228 | || (p.dst_tag != tag::any |
229 | && pd_fwd_hint.get()->dst_desc() == dst_desc)); |
230 | |
231 | // query primitive parameters |
232 | ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward_data); |
233 | ASSERT_EQ(pd.get_algorithm(), aalgorithm); |
234 | ASSERT_EQ(pd.get_alpha(), p.alpha); |
235 | ASSERT_EQ(pd.get_beta(), p.beta); |
236 | |
237 | // check primitive returns zero_md for all rest md |
238 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
239 | ASSERT_TRUE(pd.dst_desc().is_zero() || pd.src_desc().is_zero()); |
240 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
241 | |
242 | auto diff_src = test::make_memory(diff_src_desc, eng); |
243 | auto diff_dst = test::make_memory(diff_dst_desc, eng); |
244 | |
245 | fill_data(p.diff_src_dt, diff_dst, 2, 2); |
246 | |
247 | // test out-place mode |
248 | eltwise.execute(strm, |
249 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, |
250 | {DNNL_ARG_DIFF_DST, diff_dst}, |
251 | {DNNL_ARG_DIFF_SRC, diff_src}}); |
252 | strm.wait(); |
253 | |
254 | // test in-place mode |
255 | if (p.dst_tag == p.diff_src_tag && p.dst_dt == p.diff_src_dt) { |
256 | eltwise.execute(strm, |
257 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}, |
258 | {DNNL_ARG_DIFF_DST, diff_dst}, |
259 | {DNNL_ARG_DIFF_SRC, diff_dst}}); |
260 | strm.wait(); |
261 | } |
262 | } |
263 | |
264 | void Test() { |
265 | const bool is_int8 = p.src_dt == dt::s8 || p.src_dt == dt::u8; |
266 | std::vector<prop_kind> pks = {is_int8 ? prop_kind::forward_inference |
267 | : prop_kind::forward_training}; |
268 | |
269 | std::vector<algorithm> algs_all = {algorithm::eltwise_relu, |
270 | algorithm::eltwise_tanh, algorithm::eltwise_elu, |
271 | algorithm::eltwise_square, algorithm::eltwise_abs, |
272 | algorithm::eltwise_sqrt, algorithm::eltwise_swish, |
273 | algorithm::eltwise_linear, algorithm::eltwise_soft_relu, |
274 | algorithm::eltwise_mish, algorithm::eltwise_logistic, |
275 | algorithm::eltwise_exp, algorithm::eltwise_gelu_tanh, |
276 | algorithm::eltwise_gelu_erf, algorithm::eltwise_log, |
277 | algorithm::eltwise_clip, algorithm::eltwise_clip_v2, |
278 | algorithm::eltwise_pow, algorithm::eltwise_round, |
279 | algorithm::eltwise_hardswish, algorithm::eltwise_hardsigmoid, |
280 | algorithm::eltwise_relu_use_dst_for_bwd, |
281 | algorithm::eltwise_tanh_use_dst_for_bwd, |
282 | algorithm::eltwise_elu_use_dst_for_bwd, |
283 | algorithm::eltwise_sqrt_use_dst_for_bwd, |
284 | algorithm::eltwise_logistic_use_dst_for_bwd, |
285 | algorithm::eltwise_exp_use_dst_for_bwd, |
286 | algorithm::eltwise_clip_v2_use_dst_for_bwd}; |
287 | // TODO: generalize this function. |
288 | if (p.src_dt != dt::f32) { |
289 | auto it = algs_all.begin(); |
290 | while (true) { |
291 | if (*it == algorithm::eltwise_round) { |
292 | algs_all.erase(it); |
293 | break; |
294 | } |
295 | it++; |
296 | if (it == algs_all.end()) break; |
297 | } |
298 | } |
299 | |
300 | std::vector<algorithm> algs_int8 |
301 | = {algorithm::eltwise_relu, algorithm::eltwise_linear}; |
302 | const auto &algs = is_int8 ? algs_int8 : algs_all; |
303 | |
304 | for_(auto pk : pks) |
305 | for (auto alg : algs) { |
306 | SKIP_FOR_LOOP_CUDA(is_fwd(pk) |
307 | && !impl::utils::one_of(alg, |
308 | algorithm::eltwise_relu, |
309 | algorithm::eltwise_tanh, |
310 | algorithm::eltwise_elu, |
311 | algorithm::eltwise_logistic), |
312 | "Unsupported algorithm type for CUDA" ); |
313 | SKIP_FOR_LOOP_CUDA(alg == algorithm::eltwise_relu && p.alpha != 0.f, |
314 | "Unsupported combination of algorithm type and alpha " |
315 | "parameter for CUDA" ); |
316 | |
317 | SKIP_FOR_LOOP_HIP( |
318 | !impl::utils::one_of(alg, algorithm::eltwise_relu, |
319 | algorithm::eltwise_tanh, algorithm::eltwise_elu, |
320 | algorithm::eltwise_logistic, |
321 | algorithm::eltwise_soft_relu, |
322 | algorithm::eltwise_abs), |
323 | "Unsupported algorithm type for HIP" ); |
324 | |
325 | Forward(pk, alg); |
326 | |
327 | bool to_continue = pk != prop_kind::forward_training |
328 | || p.diff_src_dt == dt::undef |
329 | || alg == algorithm::eltwise_round; |
330 | if (to_continue) continue; |
331 | |
332 | SKIP_FOR_LOOP_CUDA( |
333 | !impl::utils::one_of(alg, algorithm::eltwise_relu), |
334 | "Unsupported algorithm type for CUDA" ); |
335 | |
336 | SKIP_IF(unsupported_data_type(p.diff_src_dt), |
337 | "Engine does not support this data type." ); |
338 | SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_src_tag), |
339 | "Unsupported format tag" ); |
340 | SKIP_IF_HIP(!hip_check_format_tag(p.diff_src_tag), |
341 | "Unsupported format tag" ); |
342 | |
343 | SKIP_IF_CUDA(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef |
344 | && p.diff_src_dt != dt::undef, |
345 | "Unsupported different data types for diff_source and " |
346 | "diff_destination" ); |
347 | SKIP_IF_HIP(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef |
348 | && p.diff_src_dt != dt::undef, |
349 | "Unsupported different data types for diff_source and " |
350 | "diff_destination" ); |
351 | |
352 | SKIP_IF_CUDA(p.src_tag != p.diff_src_tag && p.src_tag != tag::any |
353 | && p.diff_src_tag != tag::any, |
354 | "Unsupported different memory formats for diff_source " |
355 | "and " |
356 | "diff_destination" ); |
357 | SKIP_IF_HIP(p.src_tag != p.diff_src_tag && p.src_tag != tag::any |
358 | && p.diff_src_tag != tag::any, |
359 | "Unsupported different memory formats for diff_source " |
360 | "and diff_destination" ); |
361 | |
362 | Backward(alg); |
363 | } |
364 | } |
365 | |
366 | bool is_fwd(prop_kind pk) const { |
367 | return pk == prop_kind::forward_training |
368 | || pk == prop_kind::forward_inference; |
369 | } |
370 | }; |
371 | |
// Short alias that keeps the test case tables below readable.
using tp = eltwise_test_params_t;
373 | |
// The body is intentionally empty: the fixture's SetUp() runs the whole test.
TEST_P(eltwise_test_t, TestsEltwise) {}
375 | |
// Expected-failure cases: each entry sets `expect_to_fail` to true and names
// the status the failure is expected to report.
INSTANTIATE_TEST_SUITE_P(Test_Eltwise_EF, eltwise_test_t,
        ::testing::Values(
                // Negative dims
                tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, -2, 128, 256}, 1.f, 2.f, true,
                        dnnl_invalid_arguments},
                // Tag for src on forward is not specified
                tp {dt::f32, dt::f32, dt::undef, tag::any, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_invalid_arguments},
                // Tag for src on backward is not specified
                tp {dt::f32, dt::f32, dt::f32, tag::any, tag::nchw, tag::nchw,
                        {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_invalid_arguments},
                // Data type for src is not specified
                tp {dt::undef, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_invalid_arguments},
                // Different data types are not supported (no diff_src given)
                tp {dt::f32, dt::bf16, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_unimplemented},
                // Different data types are not supported (diff_src given)
                tp {dt::f32, dt::bf16, dt::f32, tag::nchw, tag::nchw, tag::nchw,
                        {2, 2, 128, 256}, 1.f, 2.f, true, dnnl_unimplemented},
                // Different memory formats are not supported (no diff_src
                // given)
                tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nhwc,
                        tag::undef, {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_unimplemented},
                // Different memory formats are not supported (diff_src given)
                tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nhwc, tag::nchw,
                        {2, 2, 128, 256}, 1.f, 2.f, true, dnnl_unimplemented}));
408 | |
409 | static auto all_cases = [](dt src_dt, dt dst_dt, dt diff_src_dt) { |
410 | return ::testing::Values(tp {src_dt, dst_dt, diff_src_dt, tag::nwc, |
411 | tag::nwc, tag::nwc, {2, 16, 10}, 0.f, 0.f}, |
412 | tp {src_dt, dst_dt, diff_src_dt, tag::ncw, tag::ncw, tag::ncw, |
413 | {2, 64, 27}, 1.f, 2.f}, |
414 | tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc, |
415 | {2, 16, 10, 8}, 0.f, 0.9f}, |
416 | tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw, |
417 | {2, 64, 27, 27}, 1.f, 2.f}, |
418 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c, |
419 | tag::nChw8c, {2, 16, 16, 8}, 0.1f, 0.9f}, |
420 | tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c, |
421 | tag::nChw16c, {2, 16, 4, 4}, 0.f, 0.f}, |
422 | tp {src_dt, dst_dt, diff_src_dt, tag::ncdhw, tag::ncdhw, tag::ncdhw, |
423 | {2, 64, 7, 7, 7}, 1.f, 1.f}, |
424 | tp {src_dt, dst_dt, diff_src_dt, tag::ncdhw, tag::ncdhw, tag::ncdhw, |
425 | {10, 10, 10, 10, 10}, 0.f, 0.f}, |
426 | tp {src_dt, dst_dt, diff_src_dt, tag::nCdhw16c, tag::nCdhw16c, |
427 | tag::nCdhw16c, {4, 15, 2, 2, 2}, 0.1f, 0.2f}); |
428 | }; |
429 | |
// Expands three bare data type names (e.g. f32, bf16, undef) into the
// corresponding comma-separated `memory::data_type` enumerators.
#define EXPAND_DTS(src, dst, diff_src) \
    memory::data_type::src, memory::data_type::dst, memory::data_type::diff_src
432 | |
// Instantiates eltwise_test_t under `name` with the cases produced by calling
// `suite` with the remaining arguments. The expansion already ends with a
// semicolon.
#define INST_TEST_CASE(name, suite, ...) \
    INSTANTIATE_TEST_SUITE_P(name, eltwise_test_t, suite(__VA_ARGS__));
435 | |
// Same as INST_TEST_CASE, but goes through CPU_INSTANTIATE_TEST_SUITE_P.
// NOTE(review): that macro is defined outside this file; presumably it
// restricts the instantiation to CPU engines - confirm.
#define CPU_INST_TEST_CASE(name, suite, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(name, eltwise_test_t, suite(__VA_ARGS__));
438 | |
// Same as INST_TEST_CASE, but goes through GPU_INSTANTIATE_TEST_SUITE_P.
// NOTE(review): that macro is defined outside this file; presumably it
// restricts the instantiation to GPU engines - confirm.
#define GPU_INST_TEST_CASE(name, suite, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P(name, eltwise_test_t, suite(__VA_ARGS__));
441 | |
// Positive cases. f32 and bf16 pass a diff_src data type, so the fixture may
// also run the backward pass; f16 and u8 pass `undef` for diff_src and are
// therefore forward-only.
INST_TEST_CASE(EltwiseSimpleF32, all_cases, EXPAND_DTS(f32, f32, f32));
INST_TEST_CASE(EltwiseSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16, bf16));
INST_TEST_CASE(EltwiseSimpleF16, all_cases, EXPAND_DTS(f16, f16, undef));
INST_TEST_CASE(EltwiseSimpleU8, all_cases, EXPAND_DTS(u8, u8, undef));
446 | } // namespace dnnl |
447 | |