1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * Copyright 2021 Alanna Tempest |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #include "dnnl_test_common.hpp" |
19 | #include "gtest/gtest.h" |
20 | |
21 | #include "oneapi/dnnl/dnnl.hpp" |
22 | |
23 | namespace dnnl { |
24 | |
25 | TEST(TestsbinaryStride, StrideZero) { |
26 | engine eng(engine::kind::cpu, 0); |
27 | auto strm = make_stream(eng); |
28 | |
29 | auto dims = memory::dims {2, 3}; |
30 | std::vector<float> lhs = {0, 0, 0, 0, 0, 0}; // [[0 0 0][0 0 0]] |
31 | std::vector<float> rhs = {1, 2, 3, 4, 5, 6}; // [[1 2 3][4 5 6]] |
32 | std::vector<float> res = {0, 0, 0, 0, 0, 0}; |
33 | std::vector<float> expected_result = {1, 1, 1, 4, 4, 4}; |
34 | |
35 | const auto dst_md = memory::desc( |
36 | dims, memory::data_type::f32, memory::format_tag::ab); |
37 | |
38 | const auto src0_md |
39 | = memory::desc(dims, memory::data_type::f32, memory::dims {3, 1}); |
40 | const auto src1_md |
41 | = memory::desc(dims, memory::data_type::f32, memory::dims {3, 0}); |
42 | const auto src0_mem = memory(src0_md, eng, lhs.data()); |
43 | const auto src1_mem = memory(src1_md, eng, rhs.data()); |
44 | const auto dst_mem = memory(dst_md, eng, res.data()); |
45 | |
46 | const auto pd = binary::primitive_desc( |
47 | eng, algorithm::binary_add, src0_md, src1_md, dst_md); |
48 | |
49 | const auto prim = binary(pd); |
50 | prim.execute(strm, |
51 | {{DNNL_ARG_SRC_0, src0_mem}, {DNNL_ARG_SRC_1, src1_mem}, |
52 | {DNNL_ARG_DST, dst_mem}}); |
53 | strm.wait(); |
54 | |
55 | auto correct_result = [&]() { |
56 | size_t nelems = res.size(); |
57 | for (size_t i = 0; i < nelems; i++) |
58 | if (res.at(i) != expected_result.at(i)) return false; |
59 | return true; |
60 | }; |
61 | |
62 | ASSERT_TRUE(correct_result()); |
63 | } |
64 | |
65 | } // namespace dnnl |
66 | |