1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24using tag = memory::format_tag;
25using data_type = memory::data_type;
26
// Parameters describing one binary-primitive test case.
struct binary_test_params_t {
    std::vector<tag> srcs_format; // format tags for src0 and src1
    tag dst_format; // format tag for dst
    algorithm aalgorithm; // binary algorithm kind (add, mul, div, ...)
    memory::dims dims; // dimensions of src0 and dst (src1 may be broadcast)
    bool expect_to_fail; // true when primitive creation must fail
    dnnl_status_t expected_status; // status expected when creation fails
};
35
// Parameterized test fixture for the binary primitive. The template
// parameters select the C++ element types (and, via data_traits, the
// memory data types) of src0, src1 and dst; by default all three match
// src0_data_t.
template <typename src0_data_t, typename src1_data_t = src0_data_t,
        typename dst_data_t = src0_data_t>
class binary_test_t : public ::testing::TestWithParam<binary_test_params_t> {
private:
    binary_test_params_t p; // current test-case parameters
    data_type src0_dt, src1_dt, dst_dt; // data types derived from templates

protected:
    // Resolves data types, skips configurations the engine cannot run,
    // then executes Test() wrapped so that expected failures are checked
    // against p.expected_status.
    void SetUp() override {
        src0_dt = data_traits<src0_data_t>::data_type;
        src1_dt = data_traits<src1_data_t>::data_type;
        dst_dt = data_traits<dst_data_t>::data_type;

        p = ::testing::TestWithParam<binary_test_params_t>::GetParam();

        SKIP_IF(unsupported_data_type(src0_dt),
                "Engine does not support this data type.");

        SKIP_IF(unsupported_data_type(src1_dt),
                "Engine does not support this data type.");

        SKIP_IF(unsupported_data_type(dst_dt),
                "Engine does not support this data type.");

        SKIP_IF_CUDA(
                !cuda_check_data_types_combination(src0_dt, src1_dt, dst_dt),
                "Engine does not support this data type combination.");
        SKIP_IF_HIP(!hip_check_data_types_combination(src0_dt, src1_dt, dst_dt),
                "Engine does not support this data type combination.");

        for (auto tag : p.srcs_format) {
            // `tag` is referenced only inside the SKIP_IF_CUDA/HIP macros,
            // which expand to nothing in non-GPU builds.
            MAYBE_UNUSED(tag);
            SKIP_IF_CUDA(!cuda_check_format_tag(tag),
                    "Unsupported source format tag");
            SKIP_IF_HIP(!hip_check_format_tag(tag),
                    "Unsupported source format tag");
        }
        SKIP_IF_CUDA(!cuda_check_format_tag(p.dst_format),
                "Unsupported destination format tag");

        SKIP_IF_HIP(!hip_check_format_tag(p.dst_format),
                "Unsupported destination format tag");
        catch_expected_failures(
                [=]() { Test(); }, p.expect_to_fail, p.expected_status);
    }

    // CUDA backend restriction: both inputs must share one data type, and
    // either src0 matches dst or one of src0/dst is f32.
    bool cuda_check_data_types_combination(
            data_type src0_dt, data_type src1_dt, data_type dst_dt) {
        bool correct_input_dt = src0_dt == data_type::f32 || src0_dt == dst_dt
                || dst_dt == data_type::f32;
        bool inputs_same_dt = src0_dt == src1_dt;

        return inputs_same_dt && correct_input_dt;
    }
    // HIP backend restriction: all three tensors share one data type, and
    // that type is f32, f16 or s32.
    bool hip_check_data_types_combination(
            data_type src0_dt, data_type src1_dt, data_type dst_dt) {
        bool correct_input_dt = src0_dt == dst_dt
                && (dst_dt == data_type::f32 || dst_dt == data_type::f16
                        || dst_dt == data_type::s32);
        bool inputs_same_dt = src0_dt == src1_dt;

        return inputs_same_dt && correct_input_dt;
    }

    // Plain 4D layouts accepted by the CUDA backend.
    bool cuda_check_format_tag(tag atag) {
        return atag == tag::abcd || atag == tag::acdb;
    }
    // Plain 4D layout accepted by the HIP backend.
    bool hip_check_format_tag(tag atag) { return atag == tag::abcd; }

    // Creates and runs the binary primitive for the full src1 shape and
    // for several broadcast variants obtained by collapsing src1 dims to 1.
    // Also exercises pd/primitive constructors and descriptor queries.
    void Test() {
        auto eng = get_test_engine();
        auto strm = make_stream(eng);

        // binary specific types and values
        using pd_t = binary::primitive_desc;
        allows_attr_t aa {false};
        aa.scales = true;
        // NVIDIA/AMD backends do not support these post-ops for binary.
        aa.po_sum = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
        aa.po_eltwise = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
        aa.po_binary = !is_nvidia_gpu(eng) && !is_amd_gpu(eng);
        std::vector<memory::desc> srcs_md;
        std::vector<memory> srcs;

        // Each i_case selects a broadcast pattern for the second source.
        for (int i_case = 0;; ++i_case) {
            memory::dims dims_B = p.dims;
            if (i_case == 0) {
                // full shape: no broadcast
            } else if (i_case == 1) {
                dims_B[0] = 1;
            } else if (i_case == 2) {
                dims_B[1] = 1;
                dims_B[2] = 1;
            } else if (i_case == 3) {
                dims_B[0] = 1;
                dims_B[2] = 1;
                dims_B[3] = 1;
            } else if (i_case == 4) {
                // scalar broadcast: every dim collapsed to 1
                dims_B[0] = 1;
                dims_B[1] = 1;
                dims_B[2] = 1;
                dims_B[3] = 1;
            } else {
                break;
            }

            auto desc_A = memory::desc(p.dims, src0_dt, p.srcs_format[0]);
            // TODO: try to fit "reshape" logic here.
            auto desc_B = memory::desc(dims_B, src1_dt, p.srcs_format[1]);
            auto desc_C = memory::desc(p.dims, dst_dt, p.dst_format);

            const dnnl::impl::memory_desc_wrapper mdw_desc_A(desc_A.get());
            const bool has_zero_dim = mdw_desc_A.has_zero_dim();

            // default pd ctor
            auto pd = pd_t();
            // regular pd ctor
            pd = pd_t(eng, p.aalgorithm, desc_A, desc_B, desc_C);
            // test all pd ctors
            if (!has_zero_dim)
                test_fwd_pd_constructors<pd_t>(
                        pd, aa, p.aalgorithm, desc_A, desc_B, desc_C);
            // test non-md query interfaces
            ASSERT_EQ(pd.get_algorithm(), p.aalgorithm);

            // constructing a primitive from an empty cache blob must throw
            EXPECT_ANY_THROW(binary(pd, {}));
            // default primitive ctor
            auto prim = binary();
            // regular primitive ctor
            prim = binary(pd);

            // query for descs from pd
            const auto src0_desc = pd.src_desc(0);
            const auto src1_desc = pd.src_desc(1);
            const auto dst_desc = pd.dst_desc();
            const auto workspace_desc = pd.workspace_desc();

            // exec-arg queries must agree with the dedicated desc queries
            ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_0)
                    == src0_desc);
            ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_1)
                    == src1_desc);
            ASSERT_TRUE(
                    pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);

            // check primitive returns zero_md for all rest md
            ASSERT_TRUE(pd.weights_desc().is_zero());
            ASSERT_TRUE(pd.diff_src_desc().is_zero());
            ASSERT_TRUE(pd.diff_dst_desc().is_zero());
            ASSERT_TRUE(pd.diff_weights_desc().is_zero());

            const auto test_engine = pd.get_engine();

            auto mem_A = test::make_memory(src0_desc, test_engine);
            auto mem_B = test::make_memory(src1_desc, test_engine);
            auto mem_C = test::make_memory(dst_desc, test_engine);
            auto mem_ws = test::make_memory(workspace_desc, test_engine);

            fill_data<src0_data_t>(
                    src0_desc.get_size() / sizeof(src0_data_t), mem_A);
            fill_data<src1_data_t>(
                    src1_desc.get_size() / sizeof(src1_data_t), mem_B);
            // Remove zeroes in src1 to avoid division by zero
            remove_zeroes<src1_data_t>(mem_B);

            prim.execute(strm,
                    {{DNNL_ARG_SRC_0, mem_A}, {DNNL_ARG_SRC_1, mem_B},
                            {DNNL_ARG_DST, mem_C},
                            {DNNL_ARG_WORKSPACE, mem_ws}});
            strm.wait();
        }
    }
};
206
// Fixture parameterized by {tensor dims, binary post-op dims, format tag};
// used to verify attribute (post-op) handling below.
struct binary_attr_test_t
    : public ::testing::TestWithParam<
              std::tuple<memory::dims, memory::dims, memory::format_tag>> {};
210
211HANDLE_EXCEPTIONS_FOR_TEST_P(
212 binary_attr_test_t, TestBinaryShouldCallSameImplementationWithPostops) {
213 auto engine_kind = get_test_engine_kind();
214 SKIP_IF(!DNNL_X64 || engine_kind != engine::kind::cpu,
215 "Binary impl_info_str should be same only on x64 CPU");
216 engine e {engine_kind, 0};
217
218 std::vector<memory::data_type> test_dts {
219 memory::data_type::f32, memory::data_type::s8};
220
221 if (!unsupported_data_type(memory::data_type::bf16))
222 test_dts.emplace_back(memory::data_type::bf16);
223
224 if (!unsupported_data_type(memory::data_type::f16))
225 test_dts.emplace_back(memory::data_type::f16);
226
227 for (auto dt : test_dts) {
228 const auto &binary_tensor_dims = std::get<0>(GetParam());
229 const auto format_tag = std::get<2>(GetParam());
230
231 const memory::desc src_0_md {binary_tensor_dims, dt, format_tag};
232 const memory::desc src_1_md {binary_tensor_dims, dt, format_tag};
233 const memory::desc dst_md {binary_tensor_dims, dt, format_tag};
234
235 std::string impl_info_no_postops;
236
237 auto pd = binary::primitive_desc(
238 e, algorithm::binary_mul, src_0_md, src_1_md, dst_md);
239 ASSERT_NO_THROW(impl_info_no_postops = pd.impl_info_str(););
240
241 dnnl::primitive_attr attr;
242 const float alpha = 1.f;
243 const float beta = 1.f;
244 dnnl::post_ops ops;
245
246 ops.append_sum(1.0);
247
248 ops.append_eltwise(algorithm::eltwise_relu, alpha, beta);
249
250 const auto &binary_po_tensor_dims = std::get<1>(GetParam());
251 memory::desc src1_po_md(
252 binary_po_tensor_dims, data_type::f32, format_tag);
253 ops.append_binary(algorithm::binary_add, src1_po_md);
254
255 attr.set_post_ops(ops);
256
257 std::string impl_info_with_postops;
258
259 pd = binary::primitive_desc(
260 e, algorithm::binary_mul, src_0_md, src_1_md, dst_md, attr);
261 ASSERT_NO_THROW(impl_info_with_postops = pd.impl_info_str(););
262 ASSERT_EQ(impl_info_no_postops, impl_info_with_postops);
263 }
264}
265
INSTANTIATE_TEST_SUITE_P(BinaryTensorDims, binary_attr_test_t,
        ::testing::Values(
                // tuple = {src0/src1/dst dims, binary post-op dims, format tag}
                std::make_tuple(memory::dims {1, 1024}, memory::dims {1, 1024},
                        memory::format_tag::ab),
                std::make_tuple(memory::dims {1, 1024, 1},
                        memory::dims {1, 1024, 1}, memory::format_tag::abc),
                std::make_tuple(memory::dims {1, 1024, 17},
                        memory::dims {1, 1024, 1}, memory::format_tag::abc),
                std::make_tuple(memory::dims {10, 1024, 17, 17},
                        memory::dims {1, 1024, 1, 1},
                        memory::format_tag::abcd)));
278
// Cases where primitive creation must fail with dnnl_invalid_arguments.
static auto expected_failures = []() {
    return ::testing::Values(
            // tag::any is not allowed for binary sources
            binary_test_params_t {{tag::any, tag::nchw}, tag::nchw,
                    algorithm::binary_add, {8, 7, 6, 5}, true,
                    dnnl_invalid_arguments},
            // not supported alg_kind
            binary_test_params_t {{tag::nchw, tag::nchw}, tag::nchw,
                    algorithm::eltwise_relu, {1, 8, 4, 4}, true,
                    dnnl_invalid_arguments},
            // negative dim
            binary_test_params_t {{tag::nchw, tag::nchw}, tag::nchw,
                    algorithm::binary_div, {-1, 8, 4, 4}, true,
                    dnnl_invalid_arguments});
};
294
// Cases with a zero dimension in the shape: creation succeeds and execution
// is a no-op (each case zeroes a different dim across various layouts).
static auto zero_dim = []() {
    return ::testing::Values(
            binary_test_params_t {{tag::nchw, tag::nchw}, tag::nchw,
                    algorithm::binary_add, {0, 7, 6, 5}},
            binary_test_params_t {{tag::nChw8c, tag::nhwc}, tag::nChw8c,
                    algorithm::binary_mul, {5, 0, 7, 6}},
            binary_test_params_t {{tag::nChw16c, tag::nchw}, tag::nChw16c,
                    algorithm::binary_div, {8, 15, 0, 5}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc,
                    algorithm::binary_mul, {5, 16, 7, 0}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc,
                    algorithm::binary_sub, {4, 0, 7, 5}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc,
                    algorithm::binary_ge, {4, 16, 7, 0}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc,
                    algorithm::binary_gt, {4, 16, 7, 0}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc,
                    algorithm::binary_le, {4, 16, 7, 0}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc,
                    algorithm::binary_lt, {4, 16, 7, 0}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc,
                    algorithm::binary_eq, {4, 16, 7, 0}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc,
                    algorithm::binary_ne, {4, 16, 7, 0}});
};
320
// Regular positive cases: every supported algorithm across plain, blocked
// and mixed layouts (tag::any for dst lets the implementation pick one).
static auto simple_cases = []() {
    return ::testing::Values(
            binary_test_params_t {{tag::nchw, tag::nchw}, tag::nchw,
                    algorithm::binary_add, {8, 7, 6, 5}},
            binary_test_params_t {{tag::nhwc, tag::nhwc}, tag::nhwc,
                    algorithm::binary_mul, {5, 8, 7, 6}},
            binary_test_params_t {{tag::nChw8c, tag::nchw}, tag::nChw8c,
                    algorithm::binary_max, {8, 15, 6, 5}},
            binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::any,
                    algorithm::binary_min, {5, 16, 7, 6}},
            binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any,
                    algorithm::binary_div, {5, 16, 8, 7}},
            binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any,
                    algorithm::binary_sub, {5, 16, 8, 7}},
            binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any,
                    algorithm::binary_ge, {5, 16, 8, 7}},
            binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any,
                    algorithm::binary_gt, {5, 16, 8, 7}},
            binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any,
                    algorithm::binary_le, {5, 16, 8, 7}},
            binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any,
                    algorithm::binary_lt, {5, 16, 8, 7}},
            binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any,
                    algorithm::binary_eq, {5, 16, 8, 7}},
            binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any,
                    algorithm::binary_ne, {5, 16, 8, 7}});
};
348
// Declares the TEST_P entry point and instantiates it with all three
// parameter sets for a given fixture specialization. (No comments inside
// the macro body: line splicing happens before comment removal, so a `//`
// would swallow the trailing backslash.)
#define INST_TEST_CASE(test) \
    TEST_P(test, Testsbinary) {} \
    INSTANTIATE_TEST_SUITE_P(TestbinaryEF, test, expected_failures()); \
    INSTANTIATE_TEST_SUITE_P(TestbinaryZero, test, zero_dim()); \
    INSTANTIATE_TEST_SUITE_P(TestbinarySimple, test, simple_cases());
354
// Fixture specializations; alias names encode the <src0[,src1[,dst]]> data
// types (a single type means all three tensors share it).
using binary_test_f32 = binary_test_t<float>;
using binary_test_bf16 = binary_test_t<bfloat16_t>;
using binary_test_f16 = binary_test_t<float16_t>;
using binary_test_s8 = binary_test_t<int8_t>;
using binary_test_u8 = binary_test_t<uint8_t>;
using binary_test_s8u8s8 = binary_test_t<int8_t, uint8_t, int8_t>;
using binary_test_u8s8u8 = binary_test_t<uint8_t, int8_t, uint8_t>;
using binary_test_u8s8s8 = binary_test_t<uint8_t, int8_t, int8_t>;
using binary_test_s8u8u8 = binary_test_t<int8_t, uint8_t, uint8_t>;
using binary_test_s8f32u8 = binary_test_t<int8_t, float, uint8_t>;
using binary_test_s8f32s8 = binary_test_t<int8_t, float, int8_t>;
using binary_test_f32u8s8 = binary_test_t<float, uint8_t, int8_t>;
using binary_test_f32f32u8 = binary_test_t<float, float, uint8_t>;

// Instantiate the full suite (expected failures, zero-dim, simple cases)
// for every data-type combination above.
INST_TEST_CASE(binary_test_f32)
INST_TEST_CASE(binary_test_bf16)
INST_TEST_CASE(binary_test_f16)
INST_TEST_CASE(binary_test_s8)
INST_TEST_CASE(binary_test_u8)
INST_TEST_CASE(binary_test_s8u8s8)
INST_TEST_CASE(binary_test_u8s8u8)
INST_TEST_CASE(binary_test_u8s8s8)
INST_TEST_CASE(binary_test_s8u8u8)
INST_TEST_CASE(binary_test_s8f32u8)
INST_TEST_CASE(binary_test_s8f32s8)
INST_TEST_CASE(binary_test_f32u8s8)
INST_TEST_CASE(binary_test_f32f32u8)
382
383} // namespace dnnl
384