1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "src/common/broadcast_strategy.cpp"
19#include "gtest/gtest.h"
20
21#include "oneapi/dnnl/dnnl.hpp"
22
23#define CASE(ndims, tag) \
24 case ndims: return memory::format_tag::tag;
25
26namespace dnnl {
27
28memory::format_tag plain_format_tag(size_t ndims) {
29 assert(ndims <= 12);
30 switch (ndims) {
31 CASE(1, a)
32 CASE(2, ab)
33 CASE(3, abc)
34 CASE(4, abcd)
35 CASE(5, abcde)
36 CASE(6, abcdef)
37 CASE(7, abcdefg)
38 CASE(8, abcdefgh)
39 CASE(9, abcdefghi)
40 CASE(10, abcdefghij)
41 CASE(11, abcdefghijk)
42 CASE(12, abcdefghijkl)
43 default: return memory::format_tag::any;
44 }
45}
46
47#undef CASE
48
49struct bcast_strategy_test_t
50 : public ::testing::TestWithParam<std::tuple<memory::dims, memory::dims,
51 impl::broadcasting_strategy_t>> {};
52
53HANDLE_EXCEPTIONS_FOR_TEST_P(bcast_strategy_test_t, TestBroadcastStrategy) {
54 const auto &dst_dims = std::get<0>(GetParam());
55 const auto &rhs_arg_dims = std::get<1>(GetParam());
56 ASSERT_EQ(dst_dims.size(), rhs_arg_dims.size());
57
58 const size_t ndims = dst_dims.size();
59 constexpr auto defualt_dt = memory::data_type::f32;
60 const auto default_format = plain_format_tag(ndims);
61 auto rhs_md = memory::desc(rhs_arg_dims, defualt_dt, default_format, true);
62 auto dst_md = memory::desc(dst_dims, defualt_dt, default_format, true);
63 auto dst_mdw = impl::memory_desc_wrapper(dst_md.get());
64 const auto bcast_type
65 = impl::get_rhs_arg_broadcasting_strategy(*rhs_md.get(), dst_mdw);
66 const auto expected_bcast_type = std::get<2>(GetParam());
67 ASSERT_EQ(bcast_type, expected_bcast_type);
68}
69
70INSTANTIATE_TEST_SUITE_P(SupportedStrategies, bcast_strategy_test_t,
71 ::testing::Values(
72 // 5d cases
73 std::make_tuple(memory::dims {2, 2, 2, 2, 2},
74 memory::dims {1, 1, 1, 1, 1},
75 impl::broadcasting_strategy_t::scalar),
76 std::make_tuple(memory::dims {2, 2, 2, 2, 2},
77 memory::dims {1, 2, 1, 1, 1},
78 impl::broadcasting_strategy_t::per_oc_spatial),
79 std::make_tuple(memory::dims {1, 2, 2, 2, 2},
80 memory::dims {1, 2, 1, 1, 1},
81 impl::broadcasting_strategy_t::per_oc_spatial),
82 std::make_tuple(memory::dims {1, 2, 1, 2, 2},
83 memory::dims {1, 2, 1, 1, 1},
84 impl::broadcasting_strategy_t::per_oc_spatial),
85 std::make_tuple(memory::dims {1, 2, 1, 1, 2},
86 memory::dims {1, 2, 1, 1, 1},
87 impl::broadcasting_strategy_t::per_oc_spatial),
88 std::make_tuple(memory::dims {2, 2, 1, 1, 2},
89 memory::dims {1, 2, 1, 1, 1},
90 impl::broadcasting_strategy_t::per_oc_spatial),
91 std::make_tuple(memory::dims {2, 2, 1, 2, 2},
92 memory::dims {1, 2, 1, 1, 1},
93 impl::broadcasting_strategy_t::per_oc_spatial),
94 std::make_tuple(memory::dims {1, 2, 1, 1, 1},
95 memory::dims {1, 2, 1, 1, 1},
96 impl::broadcasting_strategy_t::no_broadcast),
97 std::make_tuple(memory::dims {2, 2, 2, 2, 2},
98 memory::dims {2, 2, 2, 2, 2},
99 impl::broadcasting_strategy_t::no_broadcast),
100 std::make_tuple(memory::dims {2, 2, 2, 2, 2},
101 memory::dims {1, 1, 1, 1, 2},
102 impl::broadcasting_strategy_t::per_w),
103 std::make_tuple(memory::dims {2, 2, 2, 1, 2},
104 memory::dims {1, 1, 1, 1, 2},
105 impl::broadcasting_strategy_t::per_w),
106 std::make_tuple(memory::dims {2, 2, 1, 1, 2},
107 memory::dims {1, 1, 1, 1, 2},
108 impl::broadcasting_strategy_t::per_w),
109 std::make_tuple(memory::dims {2, 1, 1, 1, 2},
110 memory::dims {1, 1, 1, 1, 2},
111 impl::broadcasting_strategy_t::per_w),
112 std::make_tuple(memory::dims {1, 2, 1, 1, 2},
113 memory::dims {1, 1, 1, 1, 2},
114 impl::broadcasting_strategy_t::per_w),
115 std::make_tuple(memory::dims {1, 1, 2, 1, 2},
116 memory::dims {1, 1, 1, 1, 2},
117 impl::broadcasting_strategy_t::per_w),
118 std::make_tuple(memory::dims {1, 1, 1, 2, 2},
119 memory::dims {1, 1, 1, 1, 2},
120 impl::broadcasting_strategy_t::per_w),
121 std::make_tuple(memory::dims {1, 1, 1, 1, 2},
122 memory::dims {1, 1, 1, 1, 2},
123 impl::broadcasting_strategy_t::no_broadcast),
124 std::make_tuple(memory::dims {2, 2, 2, 2, 2},
125 memory::dims {2, 1, 1, 1, 2},
126 impl::broadcasting_strategy_t::per_mb_w),
127 std::make_tuple(memory::dims {2, 2, 2, 1, 2},
128 memory::dims {2, 1, 1, 1, 2},
129 impl::broadcasting_strategy_t::per_mb_w),
130 std::make_tuple(memory::dims {2, 2, 1, 1, 2},
131 memory::dims {2, 1, 1, 1, 2},
132 impl::broadcasting_strategy_t::per_mb_w),
133 std::make_tuple(memory::dims {2, 1, 1, 1, 2},
134 memory::dims {2, 1, 1, 1, 2},
135 impl::broadcasting_strategy_t::no_broadcast),
136 std::make_tuple(memory::dims {2, 1, 2, 2, 2},
137 memory::dims {2, 1, 2, 2, 2},
138 impl::broadcasting_strategy_t::no_broadcast),
139 std::make_tuple(memory::dims {2, 2, 2, 2, 2},
140 memory::dims {2, 1, 2, 2, 2},
141 impl::broadcasting_strategy_t::per_mb_spatial),
142 std::make_tuple(memory::dims {1, 2, 2, 2, 2},
143 memory::dims {1, 1, 2, 2, 2},
144 impl::broadcasting_strategy_t::per_mb_spatial),
145 std::make_tuple(memory::dims {2, 2, 1, 2, 2},
146 memory::dims {2, 1, 1, 2, 2},
147 impl::broadcasting_strategy_t::per_mb_spatial),
148 std::make_tuple(memory::dims {2, 2, 2, 1, 2},
149 memory::dims {2, 1, 2, 1, 2},
150 impl::broadcasting_strategy_t::per_mb_spatial),
151 std::make_tuple(memory::dims {2, 2, 2, 2, 1},
152 memory::dims {2, 1, 2, 2, 1},
153 impl::broadcasting_strategy_t::per_mb_spatial),
154 std::make_tuple(memory::dims {2, 2, 1, 2, 1},
155 memory::dims {2, 1, 1, 2, 1},
156 impl::broadcasting_strategy_t::per_mb_spatial),
157 std::make_tuple(memory::dims {1, 2, 1, 2, 1},
158 memory::dims {1, 1, 1, 2, 1},
159 impl::broadcasting_strategy_t::per_mb_spatial),
160 // 4d cases
161 std::make_tuple(memory::dims {2, 2, 2, 2},
162 memory::dims {1, 1, 1, 1},
163 impl::broadcasting_strategy_t::scalar),
164 std::make_tuple(memory::dims {2, 2, 2, 2},
165 memory::dims {1, 2, 1, 1},
166 impl::broadcasting_strategy_t::per_oc_spatial),
167 std::make_tuple(memory::dims {1, 2, 2, 2},
168 memory::dims {1, 2, 1, 1},
169 impl::broadcasting_strategy_t::per_oc_spatial),
170 std::make_tuple(memory::dims {1, 2, 1, 2},
171 memory::dims {1, 2, 1, 1},
172 impl::broadcasting_strategy_t::per_oc_spatial),
173 std::make_tuple(memory::dims {2, 2, 1, 2},
174 memory::dims {1, 2, 1, 1},
175 impl::broadcasting_strategy_t::per_oc_spatial),
176 std::make_tuple(memory::dims {2, 2, 2, 1},
177 memory::dims {1, 2, 1, 1},
178 impl::broadcasting_strategy_t::per_oc_spatial),
179 std::make_tuple(memory::dims {1, 2, 1, 1},
180 memory::dims {1, 2, 1, 1},
181 impl::broadcasting_strategy_t::no_broadcast),
182 std::make_tuple(memory::dims {2, 2, 2, 2},
183 memory::dims {2, 2, 2, 2},
184 impl::broadcasting_strategy_t::no_broadcast),
185 std::make_tuple(memory::dims {2, 2, 2, 2},
186 memory::dims {1, 1, 1, 2},
187 impl::broadcasting_strategy_t::per_w),
188 std::make_tuple(memory::dims {2, 2, 1, 2},
189 memory::dims {1, 1, 1, 2},
190 impl::broadcasting_strategy_t::per_w),
191 std::make_tuple(memory::dims {2, 1, 1, 2},
192 memory::dims {1, 1, 1, 2},
193 impl::broadcasting_strategy_t::per_w),
194 std::make_tuple(memory::dims {1, 2, 1, 2},
195 memory::dims {1, 1, 1, 2},
196 impl::broadcasting_strategy_t::per_w),
197 std::make_tuple(memory::dims {1, 1, 2, 2},
198 memory::dims {1, 1, 1, 2},
199 impl::broadcasting_strategy_t::per_w),
200 std::make_tuple(memory::dims {1, 1, 1, 2},
201 memory::dims {1, 1, 1, 2},
202 impl::broadcasting_strategy_t::no_broadcast),
203 std::make_tuple(memory::dims {2, 2, 2, 2},
204 memory::dims {2, 1, 1, 2},
205 impl::broadcasting_strategy_t::per_mb_w),
206 std::make_tuple(memory::dims {2, 2, 1, 2},
207 memory::dims {2, 1, 1, 2},
208 impl::broadcasting_strategy_t::per_mb_w),
209 std::make_tuple(memory::dims {2, 1, 1, 2},
210 memory::dims {2, 1, 1, 2},
211 impl::broadcasting_strategy_t::no_broadcast),
212 std::make_tuple(memory::dims {2, 1, 2, 2},
213 memory::dims {2, 1, 2, 2},
214 impl::broadcasting_strategy_t::no_broadcast),
215 std::make_tuple(memory::dims {2, 2, 2, 2},
216 memory::dims {2, 1, 2, 2},
217 impl::broadcasting_strategy_t::per_mb_spatial),
218 std::make_tuple(memory::dims {1, 2, 2, 2},
219 memory::dims {1, 1, 2, 2},
220 impl::broadcasting_strategy_t::per_mb_spatial),
221 std::make_tuple(memory::dims {1, 2, 2, 1},
222 memory::dims {1, 1, 2, 1},
223 impl::broadcasting_strategy_t::per_mb_spatial),
224 std::make_tuple(memory::dims {2, 2, 2, 1},
225 memory::dims {2, 1, 2, 1},
226 impl::broadcasting_strategy_t::per_mb_spatial),
227 // 3d cases
228 std::make_tuple(memory::dims {2, 2, 2}, memory::dims {1, 1, 1},
229 impl::broadcasting_strategy_t::scalar),
230 std::make_tuple(memory::dims {2, 2, 2}, memory::dims {1, 2, 1},
231 impl::broadcasting_strategy_t::per_oc_spatial),
232 std::make_tuple(memory::dims {2, 2, 1}, memory::dims {1, 2, 1},
233 impl::broadcasting_strategy_t::per_oc_spatial),
234 std::make_tuple(memory::dims {1, 2, 2}, memory::dims {1, 2, 1},
235 impl::broadcasting_strategy_t::per_oc_spatial),
236 std::make_tuple(memory::dims {1, 2, 1}, memory::dims {1, 2, 1},
237 impl::broadcasting_strategy_t::no_broadcast),
238 std::make_tuple(memory::dims {2, 2, 2}, memory::dims {2, 2, 2},
239 impl::broadcasting_strategy_t::no_broadcast),
240 std::make_tuple(memory::dims {2, 2, 2}, memory::dims {1, 1, 2},
241 impl::broadcasting_strategy_t::per_w),
242 std::make_tuple(memory::dims {2, 1, 2}, memory::dims {1, 1, 2},
243 impl::broadcasting_strategy_t::per_w),
244 std::make_tuple(memory::dims {1, 2, 2}, memory::dims {1, 1, 2},
245 impl::broadcasting_strategy_t::per_w),
246 std::make_tuple(memory::dims {1, 1, 2}, memory::dims {1, 1, 2},
247 impl::broadcasting_strategy_t::no_broadcast),
248 std::make_tuple(memory::dims {2, 2, 2}, memory::dims {2, 1, 2},
249 impl::broadcasting_strategy_t::per_mb_w),
250 std::make_tuple(memory::dims {2, 1, 2}, memory::dims {2, 1, 2},
251 impl::broadcasting_strategy_t::no_broadcast),
252 std::make_tuple(memory::dims {2, 2, 1}, memory::dims {2, 1, 1},
253 impl::broadcasting_strategy_t::per_mb_spatial),
254 // 2d cases
255 std::make_tuple(memory::dims {2, 2}, memory::dims {1, 1},
256 impl::broadcasting_strategy_t::scalar),
257 std::make_tuple(memory::dims {2, 2}, memory::dims {1, 2},
258 impl::broadcasting_strategy_t::per_oc),
259 std::make_tuple(memory::dims {1, 2}, memory::dims {1, 2},
260 impl::broadcasting_strategy_t::no_broadcast),
261 std::make_tuple(memory::dims {2, 2}, memory::dims {2, 1},
262 impl::broadcasting_strategy_t::per_mb_spatial),
263 // 1d cases
264 std::make_tuple(memory::dims {2}, memory::dims {1},
265 impl::broadcasting_strategy_t::scalar),
266 std::make_tuple(memory::dims {2}, memory::dims {2},
267 impl::broadcasting_strategy_t::no_broadcast)));
268} // namespace dnnl
269