1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24using tag = memory::format_tag;
25using dt = memory::data_type;
26
27struct eltwise_test_params_t {
28 dt src_dt;
29 dt dst_dt;
30 dt diff_src_dt;
31 tag src_tag;
32 tag dst_tag;
33 tag diff_src_tag;
34 memory::dims dims;
35 float alpha, beta;
36 bool expect_to_fail;
37 dnnl_status_t expected_status;
38};
39
40bool cuda_check_format_tag(tag atag) {
41 // Blocking is not supported by cuDNN
42 return !impl::utils::one_of(
43 atag, tag::aBcd8b, tag::aBcd16b, tag::aBcde8b, tag::aBcde16b);
44}
45
46template <typename... Rest>
47bool cuda_check_format_tag(tag first_tag, Rest... rest_tags) {
48 const bool ok = cuda_check_format_tag(first_tag);
49 if (!ok) return ok;
50 return cuda_check_format_tag(rest_tags...);
51}
52
53bool hip_check_format_tag(tag atag) {
54 // HIP has the same limitations for `tag` as CUDA.
55 return cuda_check_format_tag(atag);
56}
57
58template <typename... Rest>
59bool hip_check_format_tag(tag first_tag, Rest... rest_tags) {
60 const bool ok = hip_check_format_tag(first_tag);
61 if (!ok) return ok;
62 return hip_check_format_tag(rest_tags...);
63}
64
65class eltwise_test_t : public ::testing::TestWithParam<eltwise_test_params_t> {
66private:
67 eltwise_test_params_t p;
68 memory src, dst;
69 std::shared_ptr<eltwise_forward::primitive_desc> pd_fwd_hint;
70
71protected:
72 void SetUp() override {
73 p = ::testing::TestWithParam<eltwise_test_params_t>::GetParam();
74
75 SKIP_IF(unsupported_data_type(p.src_dt, p.dst_dt),
76 "Engine does not support this data type.");
77
78 SKIP_IF_CUDA(
79 p.dst_dt == dt::s8, "Unsupported int8 destination data type");
80 SKIP_IF_HIP(p.src_dt == dt::s8, "Unsupported int8 source data type");
81
82 SKIP_IF_CUDA(!cuda_check_format_tag(p.src_tag, p.dst_tag),
83 "Unsupported format tag");
84 SKIP_IF_HIP(!hip_check_format_tag(p.src_tag, p.dst_tag),
85 "Unsupported format tag");
86 SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef
87 && p.dst_dt != dt::undef,
88 "Unsupported different data types for source and "
89 "destination");
90 SKIP_IF_HIP(p.src_dt != p.dst_dt && p.src_dt != dt::undef
91 && p.dst_dt != dt::undef,
92 "Unsupported different data types for source and "
93 "destination");
94 SKIP_IF_CUDA(p.src_tag != p.dst_tag && p.src_tag != tag::any
95 && p.dst_tag != tag::any,
96 "Unsupported different memory formats for source and "
97 "destination");
98 SKIP_IF_HIP(p.src_tag != p.dst_tag && p.src_tag != tag::any
99 && p.dst_tag != tag::any,
100 "Unsupported different memory formats for source and "
101 "destination");
102
103 catch_expected_failures(
104 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
105 }
106
107 void Forward(prop_kind pk, algorithm aalgorithm) {
108 // eltwise specific types and values
109 using pd_t = eltwise_forward::primitive_desc;
110
111 auto eng = get_test_engine();
112 auto strm = make_stream(eng);
113
114 auto aa = allows_attr_t {false};
115 aa.po_sum = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
116 aa.po_eltwise = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
117 aa.po_binary = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
118
119 auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
120 auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
121
122 // default pd ctor
123 auto pd = pd_t();
124 // regular pd ctor
125 pd = pd_t(eng, pk, aalgorithm, src_md, dst_md, p.alpha, p.beta);
126 // test all pd ctors
127 test_fwd_pd_constructors<pd_t>(
128 pd, aa, pk, aalgorithm, src_md, dst_md, p.alpha, p.beta);
129 pd_fwd_hint = std::make_shared<pd_t>(pd);
130
131 EXPECT_ANY_THROW(eltwise_forward(pd, {}));
132 // default primitive ctor
133 auto eltwise = eltwise_forward();
134 // regular primitive ctor
135 eltwise = eltwise_forward(pd);
136
137 // check primitive kind is eltwise
138 ASSERT_TRUE(eltwise.get_kind() == primitive::kind::eltwise);
139 // query for descs from pd
140 const auto src_desc = pd.src_desc();
141 const auto dst_desc = pd.dst_desc();
142 // query for src_desc via exec arg
143 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
144 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
145 // query for dst_desc via exec arg
146 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);
147 if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); }
148
149 // query primitive parameters
150 ASSERT_EQ(pd.get_prop_kind(), pk);
151 ASSERT_EQ(pd.get_algorithm(), aalgorithm);
152 ASSERT_EQ(pd.get_alpha(), p.alpha);
153 ASSERT_EQ(pd.get_beta(), p.beta);
154
155 // check primitive returns zero_md for all rest md
156 ASSERT_TRUE(pd.weights_desc().is_zero());
157 ASSERT_TRUE(pd.diff_src_desc().is_zero());
158 ASSERT_TRUE(pd.diff_dst_desc().is_zero());
159 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
160
161 src = test::make_memory(src_desc, eng);
162 dst = test::make_memory(dst_desc, eng);
163
164 fill_data(p.src_dt, src, 1, 1);
165 // test out-place mode
166 eltwise.execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
167 strm.wait();
168
169 // test in-place mode on forward
170 if (p.src_tag == p.dst_tag && p.src_dt == p.dst_dt) {
171 // TODO: add a copy of memory and result comparison with previous
172 // dst output.
173 eltwise.execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, src}});
174 strm.wait();
175 }
176 }
177
178 void Backward(algorithm aalgorithm) {
179 // eltwise specific types and values
180 using pd_t = eltwise_backward::primitive_desc;
181 using hint_pd_t = eltwise_forward::primitive_desc;
182 allows_attr_t aa {false}; // doesn't support anything
183
184 auto eng = get_test_engine();
185 auto strm = make_stream(eng);
186 auto diff_src_md = memory::desc(p.dims, p.diff_src_dt, p.diff_src_tag);
187 auto diff_dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
188 auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
189
190 // default pd ctor
191 auto pd = pd_t();
192 // regular pd ctor
193 pd = pd_t(eng, aalgorithm, diff_src_md, diff_dst_md, src_md, p.alpha,
194 p.beta, *pd_fwd_hint);
195 // test all pd ctors
196 test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa,
197 aalgorithm, diff_src_md, diff_dst_md, src_md, p.alpha, p.beta);
198
199 EXPECT_ANY_THROW(eltwise_backward(pd, {}));
200 // default primitive ctor
201 auto eltwise = eltwise_backward();
202 // regular primitive ctor
203 eltwise = eltwise_backward(pd);
204
205 // check primitive kind is eltwise
206 ASSERT_TRUE(eltwise.get_kind() == primitive::kind::eltwise);
207
208 // query for descs from pd
209 const auto diff_src_desc = pd.diff_src_desc();
210 const auto diff_dst_desc = pd.diff_dst_desc();
211 const auto src_desc = pd.src_desc();
212 const auto dst_desc = pd.dst_desc();
213 // query for diff_src_desc via exec arg
214 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
215 == diff_src_desc);
216 if (p.diff_src_tag != tag::any) {
217 ASSERT_TRUE(diff_src_md == diff_src_desc);
218 }
219 // query for diff_dst_desc via exec arg
220 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
221 == diff_dst_desc);
222 if (p.dst_tag != tag::any) {
223 ASSERT_TRUE(diff_dst_md == diff_dst_desc);
224 }
225 // query for src_desc via exec arg
226 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
227 ASSERT_TRUE((p.src_tag != tag::any && src_md == src_desc)
228 || (p.dst_tag != tag::any
229 && pd_fwd_hint.get()->dst_desc() == dst_desc));
230
231 // query primitive parameters
232 ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward_data);
233 ASSERT_EQ(pd.get_algorithm(), aalgorithm);
234 ASSERT_EQ(pd.get_alpha(), p.alpha);
235 ASSERT_EQ(pd.get_beta(), p.beta);
236
237 // check primitive returns zero_md for all rest md
238 ASSERT_TRUE(pd.weights_desc().is_zero());
239 ASSERT_TRUE(pd.dst_desc().is_zero() || pd.src_desc().is_zero());
240 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
241
242 auto diff_src = test::make_memory(diff_src_desc, eng);
243 auto diff_dst = test::make_memory(diff_dst_desc, eng);
244
245 fill_data(p.diff_src_dt, diff_dst, 2, 2);
246
247 // test out-place mode
248 eltwise.execute(strm,
249 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
250 {DNNL_ARG_DIFF_DST, diff_dst},
251 {DNNL_ARG_DIFF_SRC, diff_src}});
252 strm.wait();
253
254 // test in-place mode
255 if (p.dst_tag == p.diff_src_tag && p.dst_dt == p.diff_src_dt) {
256 eltwise.execute(strm,
257 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
258 {DNNL_ARG_DIFF_DST, diff_dst},
259 {DNNL_ARG_DIFF_SRC, diff_dst}});
260 strm.wait();
261 }
262 }
263
264 void Test() {
265 const bool is_int8 = p.src_dt == dt::s8 || p.src_dt == dt::u8;
266 std::vector<prop_kind> pks = {is_int8 ? prop_kind::forward_inference
267 : prop_kind::forward_training};
268
269 std::vector<algorithm> algs_all = {algorithm::eltwise_relu,
270 algorithm::eltwise_tanh, algorithm::eltwise_elu,
271 algorithm::eltwise_square, algorithm::eltwise_abs,
272 algorithm::eltwise_sqrt, algorithm::eltwise_swish,
273 algorithm::eltwise_linear, algorithm::eltwise_soft_relu,
274 algorithm::eltwise_mish, algorithm::eltwise_logistic,
275 algorithm::eltwise_exp, algorithm::eltwise_gelu_tanh,
276 algorithm::eltwise_gelu_erf, algorithm::eltwise_log,
277 algorithm::eltwise_clip, algorithm::eltwise_clip_v2,
278 algorithm::eltwise_pow, algorithm::eltwise_round,
279 algorithm::eltwise_hardswish, algorithm::eltwise_hardsigmoid,
280 algorithm::eltwise_relu_use_dst_for_bwd,
281 algorithm::eltwise_tanh_use_dst_for_bwd,
282 algorithm::eltwise_elu_use_dst_for_bwd,
283 algorithm::eltwise_sqrt_use_dst_for_bwd,
284 algorithm::eltwise_logistic_use_dst_for_bwd,
285 algorithm::eltwise_exp_use_dst_for_bwd,
286 algorithm::eltwise_clip_v2_use_dst_for_bwd};
287 // TODO: generalize this function.
288 if (p.src_dt != dt::f32) {
289 auto it = algs_all.begin();
290 while (true) {
291 if (*it == algorithm::eltwise_round) {
292 algs_all.erase(it);
293 break;
294 }
295 it++;
296 if (it == algs_all.end()) break;
297 }
298 }
299
300 std::vector<algorithm> algs_int8
301 = {algorithm::eltwise_relu, algorithm::eltwise_linear};
302 const auto &algs = is_int8 ? algs_int8 : algs_all;
303
304 for_(auto pk : pks)
305 for (auto alg : algs) {
306 SKIP_FOR_LOOP_CUDA(is_fwd(pk)
307 && !impl::utils::one_of(alg,
308 algorithm::eltwise_relu,
309 algorithm::eltwise_tanh,
310 algorithm::eltwise_elu,
311 algorithm::eltwise_logistic),
312 "Unsupported algorithm type for CUDA");
313 SKIP_FOR_LOOP_CUDA(alg == algorithm::eltwise_relu && p.alpha != 0.f,
314 "Unsupported combination of algorithm type and alpha "
315 "parameter for CUDA");
316
317 SKIP_FOR_LOOP_HIP(
318 !impl::utils::one_of(alg, algorithm::eltwise_relu,
319 algorithm::eltwise_tanh, algorithm::eltwise_elu,
320 algorithm::eltwise_logistic,
321 algorithm::eltwise_soft_relu,
322 algorithm::eltwise_abs),
323 "Unsupported algorithm type for HIP");
324
325 Forward(pk, alg);
326
327 bool to_continue = pk != prop_kind::forward_training
328 || p.diff_src_dt == dt::undef
329 || alg == algorithm::eltwise_round;
330 if (to_continue) continue;
331
332 SKIP_FOR_LOOP_CUDA(
333 !impl::utils::one_of(alg, algorithm::eltwise_relu),
334 "Unsupported algorithm type for CUDA");
335
336 SKIP_IF(unsupported_data_type(p.diff_src_dt),
337 "Engine does not support this data type.");
338 SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_src_tag),
339 "Unsupported format tag");
340 SKIP_IF_HIP(!hip_check_format_tag(p.diff_src_tag),
341 "Unsupported format tag");
342
343 SKIP_IF_CUDA(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef
344 && p.diff_src_dt != dt::undef,
345 "Unsupported different data types for diff_source and "
346 "diff_destination");
347 SKIP_IF_HIP(p.src_dt != p.diff_src_dt && p.src_dt != dt::undef
348 && p.diff_src_dt != dt::undef,
349 "Unsupported different data types for diff_source and "
350 "diff_destination");
351
352 SKIP_IF_CUDA(p.src_tag != p.diff_src_tag && p.src_tag != tag::any
353 && p.diff_src_tag != tag::any,
354 "Unsupported different memory formats for diff_source "
355 "and "
356 "diff_destination");
357 SKIP_IF_HIP(p.src_tag != p.diff_src_tag && p.src_tag != tag::any
358 && p.diff_src_tag != tag::any,
359 "Unsupported different memory formats for diff_source "
360 "and diff_destination");
361
362 Backward(alg);
363 }
364 }
365
366 bool is_fwd(prop_kind pk) const {
367 return pk == prop_kind::forward_training
368 || pk == prop_kind::forward_inference;
369 }
370};
371
// Shorthand for building the parameter tables below.
using tp = eltwise_test_params_t;

// The test body is intentionally empty: eltwise_test_t::SetUp() runs all
// forward/backward checks for the current parameter set.
TEST_P(eltwise_test_t, TestsEltwise) {}
375
// Error-path cases: every entry sets expect_to_fail = true, and SetUp()
// hands expected_status to catch_expected_failures().
// The order of entries determines the gtest parameter indices in test names.
INSTANTIATE_TEST_SUITE_P(Test_Eltwise_EF, eltwise_test_t,
        ::testing::Values(
                // Negative dims
                tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, -2, 128, 256}, 1.f, 2.f, true,
                        dnnl_invalid_arguments},
                // Tag for src on forward is not specified
                tp {dt::f32, dt::f32, dt::undef, tag::any, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_invalid_arguments},
                // Tag for src on backward is not specified
                // NOTE(review): src_tag is shared with the forward pass, so
                // this presumably fails already at forward pd creation (as
                // in the previous case) and never reaches the backward pd.
                tp {dt::f32, dt::f32, dt::f32, tag::any, tag::nchw, tag::nchw,
                        {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_invalid_arguments},
                // Data type for src is not specified
                tp {dt::undef, dt::f32, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_invalid_arguments},
                // Different data types are not supported
                tp {dt::f32, dt::bf16, dt::undef, tag::nchw, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_unimplemented},
                // Different data types are not supported (diff_src set too)
                tp {dt::f32, dt::bf16, dt::f32, tag::nchw, tag::nchw, tag::nchw,
                        {2, 2, 128, 256}, 1.f, 2.f, true, dnnl_unimplemented},
                // Different memory formats are not supported
                tp {dt::f32, dt::f32, dt::undef, tag::nchw, tag::nhwc,
                        tag::undef, {2, 2, 128, 256}, 1.f, 2.f, true,
                        dnnl_unimplemented},
                // Different memory formats are not supported (diff_src set
                // too)
                tp {dt::f32, dt::f32, dt::f32, tag::nchw, tag::nhwc, tag::nchw,
                        {2, 2, 128, 256}, 1.f, 2.f, true, dnnl_unimplemented}));
408
409static auto all_cases = [](dt src_dt, dt dst_dt, dt diff_src_dt) {
410 return ::testing::Values(tp {src_dt, dst_dt, diff_src_dt, tag::nwc,
411 tag::nwc, tag::nwc, {2, 16, 10}, 0.f, 0.f},
412 tp {src_dt, dst_dt, diff_src_dt, tag::ncw, tag::ncw, tag::ncw,
413 {2, 64, 27}, 1.f, 2.f},
414 tp {src_dt, dst_dt, diff_src_dt, tag::nhwc, tag::nhwc, tag::nhwc,
415 {2, 16, 10, 8}, 0.f, 0.9f},
416 tp {src_dt, dst_dt, diff_src_dt, tag::nchw, tag::nchw, tag::nchw,
417 {2, 64, 27, 27}, 1.f, 2.f},
418 tp {src_dt, dst_dt, diff_src_dt, tag::nChw8c, tag::nChw8c,
419 tag::nChw8c, {2, 16, 16, 8}, 0.1f, 0.9f},
420 tp {src_dt, dst_dt, diff_src_dt, tag::nChw16c, tag::nChw16c,
421 tag::nChw16c, {2, 16, 4, 4}, 0.f, 0.f},
422 tp {src_dt, dst_dt, diff_src_dt, tag::ncdhw, tag::ncdhw, tag::ncdhw,
423 {2, 64, 7, 7, 7}, 1.f, 1.f},
424 tp {src_dt, dst_dt, diff_src_dt, tag::ncdhw, tag::ncdhw, tag::ncdhw,
425 {10, 10, 10, 10, 10}, 0.f, 0.f},
426 tp {src_dt, dst_dt, diff_src_dt, tag::nCdhw16c, tag::nCdhw16c,
427 tag::nCdhw16c, {4, 15, 2, 2, 2}, 0.1f, 0.2f});
428};
429
430#define EXPAND_DTS(src, dst, diff_src) \
431 memory::data_type::src, memory::data_type::dst, memory::data_type::diff_src
432
433#define INST_TEST_CASE(name, suite, ...) \
434 INSTANTIATE_TEST_SUITE_P(name, eltwise_test_t, suite(__VA_ARGS__));
435
436#define CPU_INST_TEST_CASE(name, suite, ...) \
437 CPU_INSTANTIATE_TEST_SUITE_P(name, eltwise_test_t, suite(__VA_ARGS__));
438
439#define GPU_INST_TEST_CASE(name, suite, ...) \
440 GPU_INSTANTIATE_TEST_SUITE_P(name, eltwise_test_t, suite(__VA_ARGS__));
441
442INST_TEST_CASE(EltwiseSimpleF32, all_cases, EXPAND_DTS(f32, f32, f32));
443INST_TEST_CASE(EltwiseSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16, bf16));
444INST_TEST_CASE(EltwiseSimpleF16, all_cases, EXPAND_DTS(f16, f16, undef));
445INST_TEST_CASE(EltwiseSimpleU8, all_cases, EXPAND_DTS(u8, u8, undef));
446} // namespace dnnl
447