1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
// Short local aliases for the oneDNN memory enums used throughout this test.
using tag = memory::format_tag;
using data_type = memory::data_type;
26 | |
27 | struct binary_test_params_t { |
28 | std::vector<tag> srcs_format; |
29 | tag dst_format; |
30 | algorithm aalgorithm; |
31 | memory::dims dims; |
32 | bool expect_to_fail; |
33 | dnnl_status_t expected_status; |
34 | }; |
35 | |
36 | template <typename src0_data_t, typename src1_data_t = src0_data_t, |
37 | typename dst_data_t = src0_data_t> |
38 | class binary_test_t : public ::testing::TestWithParam<binary_test_params_t> { |
39 | private: |
40 | binary_test_params_t p; |
41 | data_type src0_dt, src1_dt, dst_dt; |
42 | |
43 | protected: |
44 | void SetUp() override { |
45 | src0_dt = data_traits<src0_data_t>::data_type; |
46 | src1_dt = data_traits<src1_data_t>::data_type; |
47 | dst_dt = data_traits<dst_data_t>::data_type; |
48 | |
49 | p = ::testing::TestWithParam<binary_test_params_t>::GetParam(); |
50 | |
51 | SKIP_IF(unsupported_data_type(src0_dt), |
52 | "Engine does not support this data type." ); |
53 | |
54 | SKIP_IF(unsupported_data_type(src1_dt), |
55 | "Engine does not support this data type." ); |
56 | |
57 | SKIP_IF(unsupported_data_type(dst_dt), |
58 | "Engine does not support this data type." ); |
59 | |
60 | SKIP_IF_CUDA( |
61 | !cuda_check_data_types_combination(src0_dt, src1_dt, dst_dt), |
62 | "Engine does not support this data type combination." ); |
63 | SKIP_IF_HIP(!hip_check_data_types_combination(src0_dt, src1_dt, dst_dt), |
64 | "Engine does not support this data type combination." ); |
65 | |
66 | for (auto tag : p.srcs_format) { |
67 | MAYBE_UNUSED(tag); |
68 | SKIP_IF_CUDA(!cuda_check_format_tag(tag), |
69 | "Unsupported source format tag" ); |
70 | SKIP_IF_HIP(!hip_check_format_tag(tag), |
71 | "Unsupported source format tag" ); |
72 | } |
73 | SKIP_IF_CUDA(!cuda_check_format_tag(p.dst_format), |
74 | "Unsupported destination format tag" ); |
75 | |
76 | SKIP_IF_HIP(!hip_check_format_tag(p.dst_format), |
77 | "Unsupported destination format tag" ); |
78 | catch_expected_failures( |
79 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
80 | } |
81 | |
82 | bool cuda_check_data_types_combination( |
83 | data_type src0_dt, data_type src1_dt, data_type dst_dt) { |
84 | bool correct_input_dt = src0_dt == data_type::f32 || src0_dt == dst_dt |
85 | || dst_dt == data_type::f32; |
86 | bool inputs_same_dt = src0_dt == src1_dt; |
87 | |
88 | return inputs_same_dt && correct_input_dt; |
89 | } |
90 | bool hip_check_data_types_combination( |
91 | data_type src0_dt, data_type src1_dt, data_type dst_dt) { |
92 | bool correct_input_dt = src0_dt == dst_dt |
93 | && (dst_dt == data_type::f32 || dst_dt == data_type::f16 |
94 | || dst_dt == data_type::s32); |
95 | bool inputs_same_dt = src0_dt == src1_dt; |
96 | |
97 | return inputs_same_dt && correct_input_dt; |
98 | } |
99 | |
100 | bool cuda_check_format_tag(tag atag) { |
101 | return atag == tag::abcd || atag == tag::acdb; |
102 | } |
103 | bool hip_check_format_tag(tag atag) { return atag == tag::abcd; } |
104 | |
105 | void Test() { |
106 | auto eng = get_test_engine(); |
107 | auto strm = make_stream(eng); |
108 | |
109 | // binary specific types and values |
110 | using pd_t = binary::primitive_desc; |
111 | allows_attr_t aa {false}; |
112 | aa.scales = true; |
113 | aa.po_sum = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); |
114 | aa.po_eltwise = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); |
115 | aa.po_binary = !is_nvidia_gpu(eng) && !is_amd_gpu(eng); |
116 | std::vector<memory::desc> srcs_md; |
117 | std::vector<memory> srcs; |
118 | |
119 | for (int i_case = 0;; ++i_case) { |
120 | memory::dims dims_B = p.dims; |
121 | if (i_case == 0) { |
122 | } else if (i_case == 1) { |
123 | dims_B[0] = 1; |
124 | } else if (i_case == 2) { |
125 | dims_B[1] = 1; |
126 | dims_B[2] = 1; |
127 | } else if (i_case == 3) { |
128 | dims_B[0] = 1; |
129 | dims_B[2] = 1; |
130 | dims_B[3] = 1; |
131 | } else if (i_case == 4) { |
132 | dims_B[0] = 1; |
133 | dims_B[1] = 1; |
134 | dims_B[2] = 1; |
135 | dims_B[3] = 1; |
136 | } else { |
137 | break; |
138 | } |
139 | |
140 | auto desc_A = memory::desc(p.dims, src0_dt, p.srcs_format[0]); |
141 | // TODO: try to fit "reshape" logic here. |
142 | auto desc_B = memory::desc(dims_B, src1_dt, p.srcs_format[1]); |
143 | auto desc_C = memory::desc(p.dims, dst_dt, p.dst_format); |
144 | |
145 | const dnnl::impl::memory_desc_wrapper mdw_desc_A(desc_A.get()); |
146 | const bool has_zero_dim = mdw_desc_A.has_zero_dim(); |
147 | |
148 | // default pd ctor |
149 | auto pd = pd_t(); |
150 | // regular pd ctor |
151 | pd = pd_t(eng, p.aalgorithm, desc_A, desc_B, desc_C); |
152 | // test all pd ctors |
153 | if (!has_zero_dim) |
154 | test_fwd_pd_constructors<pd_t>( |
155 | pd, aa, p.aalgorithm, desc_A, desc_B, desc_C); |
156 | // test non-md query interfaces |
157 | ASSERT_EQ(pd.get_algorithm(), p.aalgorithm); |
158 | |
159 | EXPECT_ANY_THROW(binary(pd, {})); |
160 | // default primitive ctor |
161 | auto prim = binary(); |
162 | // regular primitive ctor |
163 | prim = binary(pd); |
164 | |
165 | // query for descs from pd |
166 | const auto src0_desc = pd.src_desc(0); |
167 | const auto src1_desc = pd.src_desc(1); |
168 | const auto dst_desc = pd.dst_desc(); |
169 | const auto workspace_desc = pd.workspace_desc(); |
170 | |
171 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_0) |
172 | == src0_desc); |
173 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC_1) |
174 | == src1_desc); |
175 | ASSERT_TRUE( |
176 | pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
177 | |
178 | // check primitive returns zero_md for all rest md |
179 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
180 | ASSERT_TRUE(pd.diff_src_desc().is_zero()); |
181 | ASSERT_TRUE(pd.diff_dst_desc().is_zero()); |
182 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
183 | |
184 | const auto test_engine = pd.get_engine(); |
185 | |
186 | auto mem_A = test::make_memory(src0_desc, test_engine); |
187 | auto mem_B = test::make_memory(src1_desc, test_engine); |
188 | auto mem_C = test::make_memory(dst_desc, test_engine); |
189 | auto mem_ws = test::make_memory(workspace_desc, test_engine); |
190 | |
191 | fill_data<src0_data_t>( |
192 | src0_desc.get_size() / sizeof(src0_data_t), mem_A); |
193 | fill_data<src1_data_t>( |
194 | src1_desc.get_size() / sizeof(src1_data_t), mem_B); |
195 | // Remove zeroes in src1 to avoid division by zero |
196 | remove_zeroes<src1_data_t>(mem_B); |
197 | |
198 | prim.execute(strm, |
199 | {{DNNL_ARG_SRC_0, mem_A}, {DNNL_ARG_SRC_1, mem_B}, |
200 | {DNNL_ARG_DST, mem_C}, |
201 | {DNNL_ARG_WORKSPACE, mem_ws}}); |
202 | strm.wait(); |
203 | } |
204 | } |
205 | }; |
206 | |
207 | struct binary_attr_test_t |
208 | : public ::testing::TestWithParam< |
209 | std::tuple<memory::dims, memory::dims, memory::format_tag>> {}; |
210 | |
211 | HANDLE_EXCEPTIONS_FOR_TEST_P( |
212 | binary_attr_test_t, TestBinaryShouldCallSameImplementationWithPostops) { |
213 | auto engine_kind = get_test_engine_kind(); |
214 | SKIP_IF(!DNNL_X64 || engine_kind != engine::kind::cpu, |
215 | "Binary impl_info_str should be same only on x64 CPU" ); |
216 | engine e {engine_kind, 0}; |
217 | |
218 | std::vector<memory::data_type> test_dts { |
219 | memory::data_type::f32, memory::data_type::s8}; |
220 | |
221 | if (!unsupported_data_type(memory::data_type::bf16)) |
222 | test_dts.emplace_back(memory::data_type::bf16); |
223 | |
224 | if (!unsupported_data_type(memory::data_type::f16)) |
225 | test_dts.emplace_back(memory::data_type::f16); |
226 | |
227 | for (auto dt : test_dts) { |
228 | const auto &binary_tensor_dims = std::get<0>(GetParam()); |
229 | const auto format_tag = std::get<2>(GetParam()); |
230 | |
231 | const memory::desc src_0_md {binary_tensor_dims, dt, format_tag}; |
232 | const memory::desc src_1_md {binary_tensor_dims, dt, format_tag}; |
233 | const memory::desc dst_md {binary_tensor_dims, dt, format_tag}; |
234 | |
235 | std::string impl_info_no_postops; |
236 | |
237 | auto pd = binary::primitive_desc( |
238 | e, algorithm::binary_mul, src_0_md, src_1_md, dst_md); |
239 | ASSERT_NO_THROW(impl_info_no_postops = pd.impl_info_str();); |
240 | |
241 | dnnl::primitive_attr attr; |
242 | const float alpha = 1.f; |
243 | const float beta = 1.f; |
244 | dnnl::post_ops ops; |
245 | |
246 | ops.append_sum(1.0); |
247 | |
248 | ops.append_eltwise(algorithm::eltwise_relu, alpha, beta); |
249 | |
250 | const auto &binary_po_tensor_dims = std::get<1>(GetParam()); |
251 | memory::desc src1_po_md( |
252 | binary_po_tensor_dims, data_type::f32, format_tag); |
253 | ops.append_binary(algorithm::binary_add, src1_po_md); |
254 | |
255 | attr.set_post_ops(ops); |
256 | |
257 | std::string impl_info_with_postops; |
258 | |
259 | pd = binary::primitive_desc( |
260 | e, algorithm::binary_mul, src_0_md, src_1_md, dst_md, attr); |
261 | ASSERT_NO_THROW(impl_info_with_postops = pd.impl_info_str();); |
262 | ASSERT_EQ(impl_info_no_postops, impl_info_with_postops); |
263 | } |
264 | } |
265 | |
// Shapes for binary_attr_test_t. Each tuple is
//   {dims shared by src0/src1/dst, binary post-op src1 dims, format tag};
// the post-op dims either match the main dims or broadcast over them.
INSTANTIATE_TEST_SUITE_P(BinaryTensorDims, binary_attr_test_t,
        ::testing::Values(
                // {src0/src1/dst dims, binary post-op dims, format tag}
                std::make_tuple(memory::dims {1, 1024}, memory::dims {1, 1024},
                        memory::format_tag::ab),
                std::make_tuple(memory::dims {1, 1024, 1},
                        memory::dims {1, 1024, 1}, memory::format_tag::abc),
                std::make_tuple(memory::dims {1, 1024, 17},
                        memory::dims {1, 1024, 1}, memory::format_tag::abc),
                std::make_tuple(memory::dims {10, 1024, 17, 17},
                        memory::dims {1, 1024, 1, 1},
                        memory::format_tag::abcd)));
278 | |
279 | static auto expected_failures = []() { |
280 | return ::testing::Values( |
281 | // test tag::any support |
282 | binary_test_params_t {{tag::any, tag::nchw}, tag::nchw, |
283 | algorithm::binary_add, {8, 7, 6, 5}, true, |
284 | dnnl_invalid_arguments}, |
285 | // not supported alg_kind |
286 | binary_test_params_t {{tag::nchw, tag::nchw}, tag::nchw, |
287 | algorithm::eltwise_relu, {1, 8, 4, 4}, true, |
288 | dnnl_invalid_arguments}, |
289 | // negative dim |
290 | binary_test_params_t {{tag::nchw, tag::nchw}, tag::nchw, |
291 | algorithm::binary_div, {-1, 8, 4, 4}, true, |
292 | dnnl_invalid_arguments}); |
293 | }; |
294 | |
295 | static auto zero_dim = []() { |
296 | return ::testing::Values( |
297 | binary_test_params_t {{tag::nchw, tag::nchw}, tag::nchw, |
298 | algorithm::binary_add, {0, 7, 6, 5}}, |
299 | binary_test_params_t {{tag::nChw8c, tag::nhwc}, tag::nChw8c, |
300 | algorithm::binary_mul, {5, 0, 7, 6}}, |
301 | binary_test_params_t {{tag::nChw16c, tag::nchw}, tag::nChw16c, |
302 | algorithm::binary_div, {8, 15, 0, 5}}, |
303 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc, |
304 | algorithm::binary_mul, {5, 16, 7, 0}}, |
305 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc, |
306 | algorithm::binary_sub, {4, 0, 7, 5}}, |
307 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc, |
308 | algorithm::binary_ge, {4, 16, 7, 0}}, |
309 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc, |
310 | algorithm::binary_gt, {4, 16, 7, 0}}, |
311 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc, |
312 | algorithm::binary_le, {4, 16, 7, 0}}, |
313 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc, |
314 | algorithm::binary_lt, {4, 16, 7, 0}}, |
315 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc, |
316 | algorithm::binary_eq, {4, 16, 7, 0}}, |
317 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::nhwc, |
318 | algorithm::binary_ne, {4, 16, 7, 0}}); |
319 | }; |
320 | |
321 | static auto simple_cases = []() { |
322 | return ::testing::Values( |
323 | binary_test_params_t {{tag::nchw, tag::nchw}, tag::nchw, |
324 | algorithm::binary_add, {8, 7, 6, 5}}, |
325 | binary_test_params_t {{tag::nhwc, tag::nhwc}, tag::nhwc, |
326 | algorithm::binary_mul, {5, 8, 7, 6}}, |
327 | binary_test_params_t {{tag::nChw8c, tag::nchw}, tag::nChw8c, |
328 | algorithm::binary_max, {8, 15, 6, 5}}, |
329 | binary_test_params_t {{tag::nhwc, tag::nChw16c}, tag::any, |
330 | algorithm::binary_min, {5, 16, 7, 6}}, |
331 | binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any, |
332 | algorithm::binary_div, {5, 16, 8, 7}}, |
333 | binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any, |
334 | algorithm::binary_sub, {5, 16, 8, 7}}, |
335 | binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any, |
336 | algorithm::binary_ge, {5, 16, 8, 7}}, |
337 | binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any, |
338 | algorithm::binary_gt, {5, 16, 8, 7}}, |
339 | binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any, |
340 | algorithm::binary_le, {5, 16, 8, 7}}, |
341 | binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any, |
342 | algorithm::binary_lt, {5, 16, 8, 7}}, |
343 | binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any, |
344 | algorithm::binary_eq, {5, 16, 8, 7}}, |
345 | binary_test_params_t {{tag::nchw, tag::nChw16c}, tag::any, |
346 | algorithm::binary_ne, {5, 16, 8, 7}}); |
347 | }; |
348 | |
// Defines the parameterized test for `fixture` and instantiates it with the
// negative, zero-dim and simple parameter sets. The TEST_P body is empty
// because the fixture's SetUp() already runs the whole scenario.
#define INST_TEST_CASE(fixture) \
    TEST_P(fixture, Testsbinary) {} \
    INSTANTIATE_TEST_SUITE_P(TestbinaryEF, fixture, expected_failures()); \
    INSTANTIATE_TEST_SUITE_P(TestbinaryZero, fixture, zero_dim()); \
    INSTANTIATE_TEST_SUITE_P(TestbinarySimple, fixture, simple_cases());
354 | |
// Concrete fixtures. Names follow binary_test_<src0><src1><dst>; a single
// type suffix means src1 and dst default to the src0 type.
using binary_test_f32 = binary_test_t<float>;
using binary_test_bf16 = binary_test_t<bfloat16_t>;
using binary_test_f16 = binary_test_t<float16_t>;
using binary_test_s8 = binary_test_t<int8_t>;
using binary_test_u8 = binary_test_t<uint8_t>;
// Mixed data type combinations.
using binary_test_s8u8s8 = binary_test_t<int8_t, uint8_t, int8_t>;
using binary_test_u8s8u8 = binary_test_t<uint8_t, int8_t, uint8_t>;
using binary_test_u8s8s8 = binary_test_t<uint8_t, int8_t, int8_t>;
using binary_test_s8u8u8 = binary_test_t<int8_t, uint8_t, uint8_t>;
using binary_test_s8f32u8 = binary_test_t<int8_t, float, uint8_t>;
using binary_test_s8f32s8 = binary_test_t<int8_t, float, int8_t>;
using binary_test_f32u8s8 = binary_test_t<float, uint8_t, int8_t>;
using binary_test_f32f32u8 = binary_test_t<float, float, uint8_t>;
368 | |
// Register the EF / Zero / Simple parameter sets for every fixture above.
INST_TEST_CASE(binary_test_f32)
INST_TEST_CASE(binary_test_bf16)
INST_TEST_CASE(binary_test_f16)
INST_TEST_CASE(binary_test_s8)
INST_TEST_CASE(binary_test_u8)
INST_TEST_CASE(binary_test_s8u8s8)
INST_TEST_CASE(binary_test_u8s8u8)
INST_TEST_CASE(binary_test_u8s8s8)
INST_TEST_CASE(binary_test_s8u8u8)
INST_TEST_CASE(binary_test_s8f32u8)
INST_TEST_CASE(binary_test_s8f32s8)
INST_TEST_CASE(binary_test_f32u8s8)
INST_TEST_CASE(binary_test_f32f32u8)
382 | |
383 | } // namespace dnnl |
384 | |