/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "src/common/broadcast_strategy.cpp" |
19 | #include "gtest/gtest.h" |
20 | |
21 | #include "oneapi/dnnl/dnnl.hpp" |
22 | |
// Emits a single `case` label mapping a tensor rank to the
// memory::format_tag enumerator spelled by `name` (e.g. CASE(2, ab)).
#define CASE(rank, name) case rank: return memory::format_tag::name;
25 | |
26 | namespace dnnl { |
27 | |
28 | memory::format_tag plain_format_tag(size_t ndims) { |
29 | assert(ndims <= 12); |
30 | switch (ndims) { |
31 | CASE(1, a) |
32 | CASE(2, ab) |
33 | CASE(3, abc) |
34 | CASE(4, abcd) |
35 | CASE(5, abcde) |
36 | CASE(6, abcdef) |
37 | CASE(7, abcdefg) |
38 | CASE(8, abcdefgh) |
39 | CASE(9, abcdefghi) |
40 | CASE(10, abcdefghij) |
41 | CASE(11, abcdefghijk) |
42 | CASE(12, abcdefghijkl) |
43 | default: return memory::format_tag::any; |
44 | } |
45 | } |
46 | |
47 | #undef CASE |
48 | |
49 | struct bcast_strategy_test_t |
50 | : public ::testing::TestWithParam<std::tuple<memory::dims, memory::dims, |
51 | impl::broadcasting_strategy_t>> {}; |
52 | |
53 | HANDLE_EXCEPTIONS_FOR_TEST_P(bcast_strategy_test_t, TestBroadcastStrategy) { |
54 | const auto &dst_dims = std::get<0>(GetParam()); |
55 | const auto &rhs_arg_dims = std::get<1>(GetParam()); |
56 | ASSERT_EQ(dst_dims.size(), rhs_arg_dims.size()); |
57 | |
58 | const size_t ndims = dst_dims.size(); |
59 | constexpr auto defualt_dt = memory::data_type::f32; |
60 | const auto default_format = plain_format_tag(ndims); |
61 | auto rhs_md = memory::desc(rhs_arg_dims, defualt_dt, default_format, true); |
62 | auto dst_md = memory::desc(dst_dims, defualt_dt, default_format, true); |
63 | auto dst_mdw = impl::memory_desc_wrapper(dst_md.get()); |
64 | const auto bcast_type |
65 | = impl::get_rhs_arg_broadcasting_strategy(*rhs_md.get(), dst_mdw); |
66 | const auto expected_bcast_type = std::get<2>(GetParam()); |
67 | ASSERT_EQ(bcast_type, expected_bcast_type); |
68 | } |
69 | |
70 | INSTANTIATE_TEST_SUITE_P(SupportedStrategies, bcast_strategy_test_t, |
71 | ::testing::Values( |
72 | // 5d cases |
73 | std::make_tuple(memory::dims {2, 2, 2, 2, 2}, |
74 | memory::dims {1, 1, 1, 1, 1}, |
75 | impl::broadcasting_strategy_t::scalar), |
76 | std::make_tuple(memory::dims {2, 2, 2, 2, 2}, |
77 | memory::dims {1, 2, 1, 1, 1}, |
78 | impl::broadcasting_strategy_t::per_oc_spatial), |
79 | std::make_tuple(memory::dims {1, 2, 2, 2, 2}, |
80 | memory::dims {1, 2, 1, 1, 1}, |
81 | impl::broadcasting_strategy_t::per_oc_spatial), |
82 | std::make_tuple(memory::dims {1, 2, 1, 2, 2}, |
83 | memory::dims {1, 2, 1, 1, 1}, |
84 | impl::broadcasting_strategy_t::per_oc_spatial), |
85 | std::make_tuple(memory::dims {1, 2, 1, 1, 2}, |
86 | memory::dims {1, 2, 1, 1, 1}, |
87 | impl::broadcasting_strategy_t::per_oc_spatial), |
88 | std::make_tuple(memory::dims {2, 2, 1, 1, 2}, |
89 | memory::dims {1, 2, 1, 1, 1}, |
90 | impl::broadcasting_strategy_t::per_oc_spatial), |
91 | std::make_tuple(memory::dims {2, 2, 1, 2, 2}, |
92 | memory::dims {1, 2, 1, 1, 1}, |
93 | impl::broadcasting_strategy_t::per_oc_spatial), |
94 | std::make_tuple(memory::dims {1, 2, 1, 1, 1}, |
95 | memory::dims {1, 2, 1, 1, 1}, |
96 | impl::broadcasting_strategy_t::no_broadcast), |
97 | std::make_tuple(memory::dims {2, 2, 2, 2, 2}, |
98 | memory::dims {2, 2, 2, 2, 2}, |
99 | impl::broadcasting_strategy_t::no_broadcast), |
100 | std::make_tuple(memory::dims {2, 2, 2, 2, 2}, |
101 | memory::dims {1, 1, 1, 1, 2}, |
102 | impl::broadcasting_strategy_t::per_w), |
103 | std::make_tuple(memory::dims {2, 2, 2, 1, 2}, |
104 | memory::dims {1, 1, 1, 1, 2}, |
105 | impl::broadcasting_strategy_t::per_w), |
106 | std::make_tuple(memory::dims {2, 2, 1, 1, 2}, |
107 | memory::dims {1, 1, 1, 1, 2}, |
108 | impl::broadcasting_strategy_t::per_w), |
109 | std::make_tuple(memory::dims {2, 1, 1, 1, 2}, |
110 | memory::dims {1, 1, 1, 1, 2}, |
111 | impl::broadcasting_strategy_t::per_w), |
112 | std::make_tuple(memory::dims {1, 2, 1, 1, 2}, |
113 | memory::dims {1, 1, 1, 1, 2}, |
114 | impl::broadcasting_strategy_t::per_w), |
115 | std::make_tuple(memory::dims {1, 1, 2, 1, 2}, |
116 | memory::dims {1, 1, 1, 1, 2}, |
117 | impl::broadcasting_strategy_t::per_w), |
118 | std::make_tuple(memory::dims {1, 1, 1, 2, 2}, |
119 | memory::dims {1, 1, 1, 1, 2}, |
120 | impl::broadcasting_strategy_t::per_w), |
121 | std::make_tuple(memory::dims {1, 1, 1, 1, 2}, |
122 | memory::dims {1, 1, 1, 1, 2}, |
123 | impl::broadcasting_strategy_t::no_broadcast), |
124 | std::make_tuple(memory::dims {2, 2, 2, 2, 2}, |
125 | memory::dims {2, 1, 1, 1, 2}, |
126 | impl::broadcasting_strategy_t::per_mb_w), |
127 | std::make_tuple(memory::dims {2, 2, 2, 1, 2}, |
128 | memory::dims {2, 1, 1, 1, 2}, |
129 | impl::broadcasting_strategy_t::per_mb_w), |
130 | std::make_tuple(memory::dims {2, 2, 1, 1, 2}, |
131 | memory::dims {2, 1, 1, 1, 2}, |
132 | impl::broadcasting_strategy_t::per_mb_w), |
133 | std::make_tuple(memory::dims {2, 1, 1, 1, 2}, |
134 | memory::dims {2, 1, 1, 1, 2}, |
135 | impl::broadcasting_strategy_t::no_broadcast), |
136 | std::make_tuple(memory::dims {2, 1, 2, 2, 2}, |
137 | memory::dims {2, 1, 2, 2, 2}, |
138 | impl::broadcasting_strategy_t::no_broadcast), |
139 | std::make_tuple(memory::dims {2, 2, 2, 2, 2}, |
140 | memory::dims {2, 1, 2, 2, 2}, |
141 | impl::broadcasting_strategy_t::per_mb_spatial), |
142 | std::make_tuple(memory::dims {1, 2, 2, 2, 2}, |
143 | memory::dims {1, 1, 2, 2, 2}, |
144 | impl::broadcasting_strategy_t::per_mb_spatial), |
145 | std::make_tuple(memory::dims {2, 2, 1, 2, 2}, |
146 | memory::dims {2, 1, 1, 2, 2}, |
147 | impl::broadcasting_strategy_t::per_mb_spatial), |
148 | std::make_tuple(memory::dims {2, 2, 2, 1, 2}, |
149 | memory::dims {2, 1, 2, 1, 2}, |
150 | impl::broadcasting_strategy_t::per_mb_spatial), |
151 | std::make_tuple(memory::dims {2, 2, 2, 2, 1}, |
152 | memory::dims {2, 1, 2, 2, 1}, |
153 | impl::broadcasting_strategy_t::per_mb_spatial), |
154 | std::make_tuple(memory::dims {2, 2, 1, 2, 1}, |
155 | memory::dims {2, 1, 1, 2, 1}, |
156 | impl::broadcasting_strategy_t::per_mb_spatial), |
157 | std::make_tuple(memory::dims {1, 2, 1, 2, 1}, |
158 | memory::dims {1, 1, 1, 2, 1}, |
159 | impl::broadcasting_strategy_t::per_mb_spatial), |
160 | // 4d cases |
161 | std::make_tuple(memory::dims {2, 2, 2, 2}, |
162 | memory::dims {1, 1, 1, 1}, |
163 | impl::broadcasting_strategy_t::scalar), |
164 | std::make_tuple(memory::dims {2, 2, 2, 2}, |
165 | memory::dims {1, 2, 1, 1}, |
166 | impl::broadcasting_strategy_t::per_oc_spatial), |
167 | std::make_tuple(memory::dims {1, 2, 2, 2}, |
168 | memory::dims {1, 2, 1, 1}, |
169 | impl::broadcasting_strategy_t::per_oc_spatial), |
170 | std::make_tuple(memory::dims {1, 2, 1, 2}, |
171 | memory::dims {1, 2, 1, 1}, |
172 | impl::broadcasting_strategy_t::per_oc_spatial), |
173 | std::make_tuple(memory::dims {2, 2, 1, 2}, |
174 | memory::dims {1, 2, 1, 1}, |
175 | impl::broadcasting_strategy_t::per_oc_spatial), |
176 | std::make_tuple(memory::dims {2, 2, 2, 1}, |
177 | memory::dims {1, 2, 1, 1}, |
178 | impl::broadcasting_strategy_t::per_oc_spatial), |
179 | std::make_tuple(memory::dims {1, 2, 1, 1}, |
180 | memory::dims {1, 2, 1, 1}, |
181 | impl::broadcasting_strategy_t::no_broadcast), |
182 | std::make_tuple(memory::dims {2, 2, 2, 2}, |
183 | memory::dims {2, 2, 2, 2}, |
184 | impl::broadcasting_strategy_t::no_broadcast), |
185 | std::make_tuple(memory::dims {2, 2, 2, 2}, |
186 | memory::dims {1, 1, 1, 2}, |
187 | impl::broadcasting_strategy_t::per_w), |
188 | std::make_tuple(memory::dims {2, 2, 1, 2}, |
189 | memory::dims {1, 1, 1, 2}, |
190 | impl::broadcasting_strategy_t::per_w), |
191 | std::make_tuple(memory::dims {2, 1, 1, 2}, |
192 | memory::dims {1, 1, 1, 2}, |
193 | impl::broadcasting_strategy_t::per_w), |
194 | std::make_tuple(memory::dims {1, 2, 1, 2}, |
195 | memory::dims {1, 1, 1, 2}, |
196 | impl::broadcasting_strategy_t::per_w), |
197 | std::make_tuple(memory::dims {1, 1, 2, 2}, |
198 | memory::dims {1, 1, 1, 2}, |
199 | impl::broadcasting_strategy_t::per_w), |
200 | std::make_tuple(memory::dims {1, 1, 1, 2}, |
201 | memory::dims {1, 1, 1, 2}, |
202 | impl::broadcasting_strategy_t::no_broadcast), |
203 | std::make_tuple(memory::dims {2, 2, 2, 2}, |
204 | memory::dims {2, 1, 1, 2}, |
205 | impl::broadcasting_strategy_t::per_mb_w), |
206 | std::make_tuple(memory::dims {2, 2, 1, 2}, |
207 | memory::dims {2, 1, 1, 2}, |
208 | impl::broadcasting_strategy_t::per_mb_w), |
209 | std::make_tuple(memory::dims {2, 1, 1, 2}, |
210 | memory::dims {2, 1, 1, 2}, |
211 | impl::broadcasting_strategy_t::no_broadcast), |
212 | std::make_tuple(memory::dims {2, 1, 2, 2}, |
213 | memory::dims {2, 1, 2, 2}, |
214 | impl::broadcasting_strategy_t::no_broadcast), |
215 | std::make_tuple(memory::dims {2, 2, 2, 2}, |
216 | memory::dims {2, 1, 2, 2}, |
217 | impl::broadcasting_strategy_t::per_mb_spatial), |
218 | std::make_tuple(memory::dims {1, 2, 2, 2}, |
219 | memory::dims {1, 1, 2, 2}, |
220 | impl::broadcasting_strategy_t::per_mb_spatial), |
221 | std::make_tuple(memory::dims {1, 2, 2, 1}, |
222 | memory::dims {1, 1, 2, 1}, |
223 | impl::broadcasting_strategy_t::per_mb_spatial), |
224 | std::make_tuple(memory::dims {2, 2, 2, 1}, |
225 | memory::dims {2, 1, 2, 1}, |
226 | impl::broadcasting_strategy_t::per_mb_spatial), |
227 | // 3d cases |
228 | std::make_tuple(memory::dims {2, 2, 2}, memory::dims {1, 1, 1}, |
229 | impl::broadcasting_strategy_t::scalar), |
230 | std::make_tuple(memory::dims {2, 2, 2}, memory::dims {1, 2, 1}, |
231 | impl::broadcasting_strategy_t::per_oc_spatial), |
232 | std::make_tuple(memory::dims {2, 2, 1}, memory::dims {1, 2, 1}, |
233 | impl::broadcasting_strategy_t::per_oc_spatial), |
234 | std::make_tuple(memory::dims {1, 2, 2}, memory::dims {1, 2, 1}, |
235 | impl::broadcasting_strategy_t::per_oc_spatial), |
236 | std::make_tuple(memory::dims {1, 2, 1}, memory::dims {1, 2, 1}, |
237 | impl::broadcasting_strategy_t::no_broadcast), |
238 | std::make_tuple(memory::dims {2, 2, 2}, memory::dims {2, 2, 2}, |
239 | impl::broadcasting_strategy_t::no_broadcast), |
240 | std::make_tuple(memory::dims {2, 2, 2}, memory::dims {1, 1, 2}, |
241 | impl::broadcasting_strategy_t::per_w), |
242 | std::make_tuple(memory::dims {2, 1, 2}, memory::dims {1, 1, 2}, |
243 | impl::broadcasting_strategy_t::per_w), |
244 | std::make_tuple(memory::dims {1, 2, 2}, memory::dims {1, 1, 2}, |
245 | impl::broadcasting_strategy_t::per_w), |
246 | std::make_tuple(memory::dims {1, 1, 2}, memory::dims {1, 1, 2}, |
247 | impl::broadcasting_strategy_t::no_broadcast), |
248 | std::make_tuple(memory::dims {2, 2, 2}, memory::dims {2, 1, 2}, |
249 | impl::broadcasting_strategy_t::per_mb_w), |
250 | std::make_tuple(memory::dims {2, 1, 2}, memory::dims {2, 1, 2}, |
251 | impl::broadcasting_strategy_t::no_broadcast), |
252 | std::make_tuple(memory::dims {2, 2, 1}, memory::dims {2, 1, 1}, |
253 | impl::broadcasting_strategy_t::per_mb_spatial), |
254 | // 2d cases |
255 | std::make_tuple(memory::dims {2, 2}, memory::dims {1, 1}, |
256 | impl::broadcasting_strategy_t::scalar), |
257 | std::make_tuple(memory::dims {2, 2}, memory::dims {1, 2}, |
258 | impl::broadcasting_strategy_t::per_oc), |
259 | std::make_tuple(memory::dims {1, 2}, memory::dims {1, 2}, |
260 | impl::broadcasting_strategy_t::no_broadcast), |
261 | std::make_tuple(memory::dims {2, 2}, memory::dims {2, 1}, |
262 | impl::broadcasting_strategy_t::per_mb_spatial), |
263 | // 1d cases |
264 | std::make_tuple(memory::dims {2}, memory::dims {1}, |
265 | impl::broadcasting_strategy_t::scalar), |
266 | std::make_tuple(memory::dims {2}, memory::dims {2}, |
267 | impl::broadcasting_strategy_t::no_broadcast))); |
268 | } // namespace dnnl |
269 | |