1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include <cstdint>

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"
21
// Extent of a dimension that src1 broadcasts along (size 1).
constexpr std::int64_t BCAST = 1;
// Extent of every non-broadcast dimension; src0 and dst use it for all dims.
constexpr std::int64_t NO_BCAST = 8;

// Emits one `case` of the ndims -> plain memory::format_tag switch.
#define CASE(ndims, tag) \
    case ndims: return memory::format_tag::tag;
27
28namespace dnnl {
29
30memory::format_tag plain_format_tag(size_t ndims) {
31 assert(ndims <= 12);
32 switch (ndims) {
33 CASE(1, a)
34 CASE(2, ab)
35 CASE(3, abc)
36 CASE(4, abcd)
37 CASE(5, abcde)
38 CASE(6, abcdef)
39 CASE(7, abcdefg)
40 CASE(8, abcdefgh)
41 CASE(9, abcdefghi)
42 CASE(10, abcdefghij)
43 CASE(11, abcdefghijk)
44 CASE(12, abcdefghijkl)
45 default: return memory::format_tag::any;
46 }
47}
48
49struct binary_bcast_test_t
50 : public ::testing::TestWithParam<
51 std::tuple<engine::kind, memory::dims, bool>> {};
52
53HANDLE_EXCEPTIONS_FOR_TEST_P(
54 binary_bcast_test_t, TestBinaryOptimizedBroadcast) {
55 auto engine_kind = std::get<0>(GetParam());
56 SKIP_IF(engine_kind != get_test_engine_kind(),
57 "Test prepared for a different engine kind");
58 SKIP_IF(!IMPLICATION(engine_kind == engine::kind::cpu, DNNL_X64),
59 "Binary impl_info_str should be same only on x64 CPU");
60 engine e {engine_kind, 0};
61
62 const auto &src1_bcast_dims = std::get<1>(GetParam());
63 const size_t ndims = src1_bcast_dims.size();
64
65 constexpr auto defualt_dt = memory::data_type::f32;
66 const auto default_format = plain_format_tag(ndims);
67 memory::dims default_dims;
68 for (size_t d = 0; d < ndims; d++)
69 default_dims.push_back(NO_BCAST);
70
71 std::string impl_info_no_bcast, impl_info_bcast;
72
73 auto src0_md = memory::desc(default_dims, defualt_dt, default_format, true);
74 auto src1_md = memory::desc(default_dims, defualt_dt, default_format, true);
75 auto dst_md = memory::desc(default_dims, defualt_dt, default_format, true);
76
77 auto binary_pd = binary::primitive_desc(
78 e, algorithm::binary_add, src0_md, src1_md, dst_md);
79 ASSERT_NO_THROW(impl_info_no_bcast = binary_pd.impl_info_str(););
80
81 memory::desc src1_bcast_md(
82 src1_bcast_dims, defualt_dt, default_format, true);
83
84 binary_pd = binary::primitive_desc(
85 e, algorithm::binary_add, src0_md, src1_bcast_md, dst_md);
86
87 ASSERT_NO_THROW(impl_info_bcast = binary_pd.impl_info_str(););
88
89 const auto expect_jit = std::get<2>(GetParam());
90 if (expect_jit)
91 ASSERT_EQ(impl_info_no_bcast, impl_info_bcast);
92 else
93 ASSERT_NE(impl_info_no_bcast, impl_info_bcast);
94}
95
// CPU cases whose src1 broadcast shape (BCAST = size-1 dim, NO_BCAST = full
// extent) must select the same implementation as the non-broadcast binary_add.
// The third tuple element `true` makes the test ASSERT_EQ the two
// impl_info_str() values. The suite name suggests these are the broadcast
// patterns the optimized kernel handles; verify against the implementation
// before editing the list.
// NOTE: gtest names instances by index, so appending new cases keeps existing
// instance names stable, while inserting or reordering does not.
INSTANTIATE_TEST_SUITE_P(CPUOptimizedDims, binary_bcast_test_t,
        ::testing::Values(
                // 5d cases
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, NO_BCAST, BCAST, BCAST, BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST, BCAST, BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {
                                NO_BCAST, BCAST, NO_BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {
                                NO_BCAST, BCAST, BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, BCAST, BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {
                                BCAST, NO_BCAST, NO_BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {
                                BCAST, BCAST, NO_BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, BCAST, BCAST}, true),
                // 4d cases
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, NO_BCAST, BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, NO_BCAST, NO_BCAST},
                        true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, NO_BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST, BCAST}, true),
                // 3d cases
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, NO_BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, NO_BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, BCAST}, true),
                // 2d cases
                std::make_tuple(
                        engine::kind::cpu, memory::dims {BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST}, true),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST}, true),
                // 1d case
                std::make_tuple(
                        engine::kind::cpu, memory::dims {BCAST}, true)));
174
// CPU cases whose src1 broadcast shape must select a different implementation
// than the non-broadcast binary_add. The third tuple element `false` makes the
// test ASSERT_NE the two impl_info_str() values. This is a hand-picked sample,
// not an exhaustive list, and mixes 5d, 4d and 3d shapes.
// NOTE(review): if the optimized kernel gains support for one of these
// patterns, the case belongs in CPUOptimizedDims instead — confirm whenever
// the binary implementation changes.
INSTANTIATE_TEST_SUITE_P(CPUNotOptimizedDims, binary_bcast_test_t,
        ::testing::Values(
                // selected unoptimized cases
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, NO_BCAST, BCAST, NO_BCAST},
                        false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST, NO_BCAST}, false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, NO_BCAST, BCAST, BCAST, NO_BCAST},
                        false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, BCAST, BCAST, BCAST},
                        false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {BCAST, BCAST, NO_BCAST, BCAST}, false),
                std::make_tuple(engine::kind::cpu,
                        memory::dims {NO_BCAST, BCAST, BCAST}, false)));
193
194} // namespace dnnl
195