1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3* Copyright 2021 Alanna Tempest
4*
5* Licensed under the Apache License, Version 2.0 (the "License");
6* you may not use this file except in compliance with the License.
7* You may obtain a copy of the License at
8*
9* http://www.apache.org/licenses/LICENSE-2.0
10*
11* Unless required by applicable law or agreed to in writing, software
12* distributed under the License is distributed on an "AS IS" BASIS,
13* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14* See the License for the specific language governing permissions and
15* limitations under the License.
16*******************************************************************************/
17
18#include "dnnl_test_common.hpp"
19#include "gtest/gtest.h"
20
21#include "oneapi/dnnl/dnnl.hpp"
22
23namespace dnnl {
24
25TEST(TestsbinaryStride, StrideZero) {
26 engine eng(engine::kind::cpu, 0);
27 auto strm = make_stream(eng);
28
29 auto dims = memory::dims {2, 3};
30 std::vector<float> lhs = {0, 0, 0, 0, 0, 0}; // [[0 0 0][0 0 0]]
31 std::vector<float> rhs = {1, 2, 3, 4, 5, 6}; // [[1 2 3][4 5 6]]
32 std::vector<float> res = {0, 0, 0, 0, 0, 0};
33 std::vector<float> expected_result = {1, 1, 1, 4, 4, 4};
34
35 const auto dst_md = memory::desc(
36 dims, memory::data_type::f32, memory::format_tag::ab);
37
38 const auto src0_md
39 = memory::desc(dims, memory::data_type::f32, memory::dims {3, 1});
40 const auto src1_md
41 = memory::desc(dims, memory::data_type::f32, memory::dims {3, 0});
42 const auto src0_mem = memory(src0_md, eng, lhs.data());
43 const auto src1_mem = memory(src1_md, eng, rhs.data());
44 const auto dst_mem = memory(dst_md, eng, res.data());
45
46 const auto pd = binary::primitive_desc(
47 eng, algorithm::binary_add, src0_md, src1_md, dst_md);
48
49 const auto prim = binary(pd);
50 prim.execute(strm,
51 {{DNNL_ARG_SRC_0, src0_mem}, {DNNL_ARG_SRC_1, src1_mem},
52 {DNNL_ARG_DST, dst_mem}});
53 strm.wait();
54
55 auto correct_result = [&]() {
56 size_t nelems = res.size();
57 for (size_t i = 0; i < nelems; i++)
58 if (res.at(i) != expected_result.at(i)) return false;
59 return true;
60 };
61
62 ASSERT_TRUE(correct_result());
63}
64
65} // namespace dnnl
66