1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <cstdio>
#include <cstring>
#include <memory>
#include <vector>
20 | |
21 | #include "dnnl_test_common.hpp" |
22 | #include "gtest/gtest.h" |
23 | |
24 | #include "oneapi/dnnl/dnnl.hpp" |
25 | |
// Set to 1 (here, or via -DDEBUG_TEST_MEMORY_DESC_OPS_CPP=1 on the compiler
// command line) to print the descriptors involved in each reshape /
// permute_axes check to stdout.
#ifndef DEBUG_TEST_MEMORY_DESC_OPS_CPP
#define DEBUG_TEST_MEMORY_DESC_OPS_CPP 0
#endif
27 | |
28 | namespace dnnl { |
29 | namespace memory_desc_ops { |
30 | |
31 | namespace debug { |
32 | #if DEBUG_TEST_MEMORY_DESC_OPS_CPP |
33 | template <typename T> |
34 | void print_vec(const char *str, const T &vec) { |
35 | printf("%s" , str); |
36 | for (int d = 0; d < (int)vec.size(); ++d) |
37 | printf("%d " , (int)vec[d]); |
38 | printf("\n" ); |
39 | } |
40 | void print_md(const char *str, const memory::desc &md) { |
41 | const auto &o_bd = md.format_desc.blocking; |
42 | |
43 | printf("%s\n" , str); |
44 | |
45 | print_vec("\tdims : " , md.get_dims()); |
46 | print_vec("\tpdims: " , md.get_padded_dims()); |
47 | print_vec("\toffs : " , md.get_padded_offsets()); |
48 | print_vec("\tstrs : " , o_bd.get_strides()); |
49 | |
50 | printf("\t\tnblks : %d\n" , o_bd.get_inner_nblks()); |
51 | print_vec("\t\tidxs : " , o_bd.get_get_inner_idxs()); |
52 | print_vec("\t\tblks : " , o_bd.get_inner_blks()); |
53 | } |
54 | #else // DEBUG_TEST_MEMORY_DESC_OPS_CPP |
55 | template <typename T> |
56 | void print_vec(const char *, const T &) {} |
57 | void print_md(const char *, const memory::desc &) {} |
58 | #endif // DEBUG_TEST_MEMORY_DESC_OPS_CPP |
59 | } // namespace debug |
60 | |
61 | // A proxy to memory::desc with fixed data type (f32) |
62 | struct memory_desc_proxy_t { |
63 | memory::desc md; |
64 | memory_desc_proxy_t() = default; |
65 | memory_desc_proxy_t(const memory::desc &md) : md(md) {} |
66 | |
67 | memory_desc_proxy_t(const memory::dims &dims, memory::format_tag tag) |
68 | : md(dims, memory::data_type::f32, tag) {} |
69 | memory_desc_proxy_t(const memory::dims &dims, const memory::dims &strides) |
70 | : md(dims, memory::data_type::f32, strides) {} |
71 | |
72 | memory_desc_proxy_t(const memory::dims &dims, const memory::dims &strides, |
73 | const memory::dims &padded_dims) |
74 | : md(dims, memory::data_type::f32, strides) { |
75 | for (int d = 0; d < md.get_ndims(); ++d) |
76 | md.get()->padded_dims[d] = padded_dims[d]; |
77 | } |
78 | }; |
79 | |
// Whether a test case is also checked in reverse (out -> in).
// A params_t whose test_direction is omitted from its initializer is
// value-initialized to 0, i.e. BI_DIRECTION.
enum test_direction_t {
    BI_DIRECTION = 0, // check both in -> out and out -> in (default)
    UNI_DIRECTION = 1, // check only in -> out
};
81 | |
82 | namespace properties { |
83 | |
84 | using fmt = dnnl::memory::format_tag; |
85 | |
86 | TEST(memory_desc_properties_test, TestMemoryDescSize) { |
87 | auto md1_simple = memory_desc_proxy_t {{1, 1, 1, 1}, {1, 1, 1, 1}}.md; |
88 | auto md1_strided = memory_desc_proxy_t {{1, 1, 1, 1}, {8, 4, 2, 1}}.md; |
89 | auto md2_blocked = memory_desc_proxy_t {{1, 4, 1, 1}, fmt::nChw8c}.md; |
90 | |
91 | ASSERT_EQ(md1_simple, md1_strided); |
92 | ASSERT_NE(md2_blocked, md1_simple); |
93 | ASSERT_NE(md2_blocked, md1_strided); |
94 | |
95 | ASSERT_EQ(md1_simple.get_size(), 1 * sizeof(float)); |
96 | ASSERT_EQ(md1_strided.get_size(), 1 * sizeof(float)); |
97 | ASSERT_EQ(md2_blocked.get_size(), 8 * sizeof(float)); |
98 | } |
99 | |
100 | } // namespace properties |
101 | |
102 | namespace reshape { |
103 | |
// One reshape test case: `in` is reshaped to the dims of `out`, and the
// result must compare equal to `out`.
// Members are brace-initialized positionally in the case tables below, so
// their order matters. Trailing members omitted from an initializer are
// value-initialized (zero).
struct params_t {
    memory_desc_proxy_t in;
    memory_desc_proxy_t out;
    // UNI_DIRECTION skips the reverse (out -> in) check.
    test_direction_t test_direction;
    // Status the check is expected to produce. Any value other than
    // dnnl_success marks the case as expected to fail.
    dnnl_status_t expected_status;
};
110 | |
111 | class reshape_test_t : public ::testing::TestWithParam<params_t> { |
112 | protected: |
113 | void Test(const memory::desc &in_md, const memory::desc &out_md) { |
114 | memory::desc get_out_md = in_md.reshape(out_md.get_dims()); |
115 | |
116 | debug::print_md("in_md" , in_md); |
117 | debug::print_md("out_md" , get_out_md); |
118 | debug::print_md("expect_out_md" , out_md); |
119 | |
120 | ASSERT_EQ(get_out_md, out_md); |
121 | } |
122 | }; |
123 | TEST_P(reshape_test_t, TestsReshape) { |
124 | params_t p = ::testing::TestWithParam<decltype(p)>::GetParam(); |
125 | catch_expected_failures([=]() { Test(p.in.md, p.out.md); }, |
126 | p.expected_status != dnnl_success, p.expected_status); |
127 | if (p.test_direction == UNI_DIRECTION) return; |
128 | catch_expected_failures([=]() { Test(p.out.md, p.in.md); }, |
129 | p.expected_status != dnnl_success, p.expected_status); |
130 | } |
131 | |
using fmt = dnnl::memory::format_tag;

// Case tables. Each entry is {in, out[, test_direction[, expected_status]]}.
// Omitted trailing members are zero, i.e. BI_DIRECTION and an
// expected_status of zero. These cases rely on zero meaning dnnl_success.
// For the forward check only the dims of `out` are used as the reshape
// target; the full `out` descriptor is the expected result.
// clang-format off
// Cases where reshape must fail with dnnl_invalid_arguments.
auto cases_expect_to_fail = ::testing::Values(
    // volume mismatch
    params_t {{{2, 2, 1, 1}, fmt::abcd}, {{2, 2, 2, 1, 1}, fmt::abcde}, BI_DIRECTION, dnnl_invalid_arguments},
    // volume mismatch
    params_t {{{2, 1}, {1, 1}}, {{2, 1, 2}, {2, 2, 1}}, BI_DIRECTION, dnnl_invalid_arguments},
    // volume mismatch
    params_t {{{6, 2}, fmt::ab}, {{6}, fmt::a}, BI_DIRECTION, dnnl_invalid_arguments},
    // joining axes are not contiguous in memory (`cdab` would be OK)
    params_t {{{2, 3, 0, 2}, fmt::cdba}, {{6, 0, 2}, fmt::bca}, UNI_DIRECTION, dnnl_invalid_arguments},
    // joining axes are not contiguous in memory
    params_t {{{6, 2}, fmt::ba}, {{12}, fmt::a}, UNI_DIRECTION, dnnl_invalid_arguments},
    // joining axes are not contiguous in memory (strides {2, 1} would be OK)
    params_t {{{6, 2}, {3, 1}}, {{12}, fmt::a}, UNI_DIRECTION, dnnl_invalid_arguments},
    // removing an axis of size `1` that has padding is not allowed
    params_t {{{6, 1, 2}, {4, 2, 1}, {6, 2, 2}}, {{6, 2}, fmt::any}, UNI_DIRECTION, dnnl_invalid_arguments},
    // joining axes where one has padding is not allowed
    params_t {{{6, 2, 2}, {6, 2, 1}, {6, 3, 2}}, {{6, 4}, fmt::any}, UNI_DIRECTION, dnnl_invalid_arguments},
    // splitting an axis that has padding is not allowed
    params_t {{{6}, {1}, {12}}, {{2, 3}, fmt::any}, UNI_DIRECTION, dnnl_invalid_arguments},
    // joining axes are not contiguous (partially, due to the blocking)
    params_t {{{2, 8, 3, 4}, fmt::aBcd8b}, {{2, 8 * 3 * 4}, fmt::ab}, UNI_DIRECTION, dnnl_invalid_arguments},
    // nothing can be done with a zero (default-constructed) memory desc
    params_t {{}, {}, UNI_DIRECTION, dnnl_invalid_arguments},
    // run-time dims are not supported
    params_t {{{DNNL_RUNTIME_DIM_VAL}, {1}}, {{DNNL_RUNTIME_DIM_VAL}, {1}}, UNI_DIRECTION, dnnl_invalid_arguments}
);

// Shapes that contain a zero-sized dimension; checked in both directions.
auto cases_zero_dim = ::testing::Values(
    params_t {{{2, 0, 2}, fmt::abc}, {{2, 0, 2, 1}, fmt::abcd}},
    params_t {{{2, 0, 2}, fmt::abc}, {{2, 0, 1, 2, 1}, fmt::abcde}},
    params_t {{{2, 1, 0, 2}, fmt::abcd}, {{2, 0, 2, 1}, fmt::abcd}},
    params_t {{{31, 1, 0, 2}, fmt::Abcd16a}, {{1, 31, 0, 2, 1}, fmt::aBcde16b}},
    params_t {{{2, 3, 0, 2}, {6, 2, 2, 1}}, {{6, 0, 2}, {2, 2, 1}}}
);

// Reshapes expected to succeed.
auto cases_generic = ::testing::Values(
    // add and/or remove axes of size `1`
    params_t {{{2, 1}, {2, 2}}, {{2}, {2}}},
    params_t {{{2, 1}, {2, 2}}, {{2, 1, 1}, {2, 2, 1}}},
    params_t {{{2, 1}, {2, 2}}, {{2, 1, 1}, {2, 1, 2}}},
    params_t {{{2, 1}, {2, 2}}, {{2, 1, 1}, {2, 2, 2}}},
    params_t {{{2, 2}, fmt::ab}, {{2, 2, 1}, fmt::abc}},
    params_t {{{2, 1}, fmt::ab}, {{1, 2, 1, 1}, fmt::abcd}},
    params_t {{{1, 2, 1}, fmt::abc}, {{2}, fmt::a}},
    params_t {{{3, 4, 5, 6}, fmt::ABcd16b16a}, {{1, 3, 4, 5, 6}, fmt::aBCde16c16b}},
    // UNI_DIRECTION: the reverse reshape is ambiguous, because it adds a
    // size-1 axis where another axis of size 1 already exists
    params_t {{{2, 1, 1}, {2, 1, 1}, {2, 2, 1}}, {{2, 1}, {2, 1}, {2, 2}}, UNI_DIRECTION},
    params_t {{{2, 1, 1}, {2, 2, 1}, {2, 1, 2}}, {{2, 1}, {2, 1}, {2, 2}}},
    // split axes; the BI_DIRECTION reverse check covers joining them back
    params_t {{{6, 2}, fmt::ab}, {{3, 2, 2}, fmt::abc}},
    params_t {{{6, 2}, fmt::ab}, {{2, 3, 2}, fmt::abc}},
    params_t {{{6, 2}, fmt::ba}, {{2, 3, 2}, /* fmt::cab: */ {3, 1, 6}}},
    params_t {{{6, 2}, {4, 1}, {6, 4}}, {{2, 3, 2}, {12, 4, 1}, {2, 3, 4}}},
    params_t {{{1, 15, 12}, fmt::aBc8b}, {{1, 15, 3, 4}, fmt::aBcd8b}},
    params_t {{{1, 15, 12}, fmt::aBc8b}, {{1, 15, 2, 3, 2}, fmt::aBcde8b}},
    // combined cases
    params_t {{{15, 3, 4}, fmt::abc}, {{3, 5, 6, 1, 2}, fmt::abcde}},
    params_t {{{15, 3, 4}, fmt::bca}, {{3, 5, 6, 1, 2}, /* fmt::cdeab */ {5, 1, 30, 30, 15}}}
);
// clang-format on

INSTANTIATE_TEST_SUITE_P(TestReshapeEF, reshape_test_t, cases_expect_to_fail);
INSTANTIATE_TEST_SUITE_P(TestReshapeZeroDim, reshape_test_t, cases_zero_dim);
INSTANTIATE_TEST_SUITE_P(TestReshapeOK, reshape_test_t, cases_generic);
199 | |
200 | } // namespace reshape |
201 | |
202 | namespace permute_axes { |
203 | |
// One permute_axes test case: `in` permuted by `perm` must compare equal
// to `out`.
// Members are brace-initialized positionally in the case tables below.
// Trailing members omitted from an initializer are value-initialized
// (zero).
struct params_t {
    memory_desc_proxy_t in;
    memory_desc_proxy_t out;
    // Passed to memory::desc::permute_axes(). For BI_DIRECTION cases, its
    // inverse is applied to `out` as the reverse check.
    std::vector<int> perm;
    // UNI_DIRECTION skips the reverse (out -> in) check.
    test_direction_t test_direction;
    // Status the check is expected to produce. Any value other than
    // dnnl_success marks the case as expected to fail.
    dnnl_status_t expected_status;
};
211 | |
212 | class permute_axes_test_t : public ::testing::TestWithParam<params_t> { |
213 | protected: |
214 | void Test(const memory::desc &in_md, const memory::desc &out_md, |
215 | const std::vector<int> &perm) { |
216 | memory::desc get_out_md = in_md.permute_axes(perm); |
217 | |
218 | debug::print_md("in_md" , in_md); |
219 | debug::print_vec("perm : " , perm); |
220 | debug::print_md("out_md" , get_out_md); |
221 | debug::print_md("expect_out_md" , out_md); |
222 | |
223 | ASSERT_EQ(get_out_md, out_md); |
224 | } |
225 | }; |
226 | TEST_P(permute_axes_test_t, TestsPermuteAxes) { |
227 | params_t p = ::testing::TestWithParam<decltype(p)>::GetParam(); |
228 | catch_expected_failures([=]() { Test(p.in.md, p.out.md, p.perm); }, |
229 | p.expected_status != dnnl_success, p.expected_status); |
230 | if (p.test_direction == UNI_DIRECTION) return; |
231 | |
232 | std::vector<int> inv_perm(p.perm.size()); |
233 | for (int i = 0; i < (int)p.perm.size(); ++i) |
234 | inv_perm[p.perm[i]] = i; |
235 | catch_expected_failures([=]() { Test(p.out.md, p.in.md, inv_perm); }, |
236 | p.expected_status != dnnl_success, p.expected_status); |
237 | } |
238 | |
using fmt = dnnl::memory::format_tag;

// Case tables. Each entry is
// {in, out, perm[, test_direction[, expected_status]]}.
// Omitted trailing members are zero, i.e. BI_DIRECTION and an
// expected_status of zero. These cases rely on zero meaning dnnl_success.
// clang-format off
// Cases where permute_axes must fail with dnnl_invalid_arguments.
auto cases_expect_to_fail = ::testing::Values(
    // incorrect permutation: axis 2 is repeated
    params_t {{{2, 2, 1, 1}, fmt::abcd}, {{2, 2, 1, 1}, fmt::abcd}, {0, 1, 2, 2}, UNI_DIRECTION, dnnl_invalid_arguments},
    // incorrect permutation: axis 4 is out of range for a 4D desc
    params_t {{{2, 2, 1, 1}, fmt::abcd}, {{2, 2, 1, 1}, fmt::abcd}, {0, 1, 2, 4}, UNI_DIRECTION, dnnl_invalid_arguments},
    // incorrect permutation: negative axis
    params_t {{{2, 2, 1, 1}, fmt::abcd}, {{2, 2, 1, 1}, fmt::abcd}, {0, 1, 2, -1}, UNI_DIRECTION, dnnl_invalid_arguments},
    // nothing can be done with a zero (default-constructed) memory desc
    params_t {{}, {}, {}, UNI_DIRECTION, dnnl_invalid_arguments},
    // run-time dims are not supported
    params_t {{{DNNL_RUNTIME_DIM_VAL}, {1}}, {{DNNL_RUNTIME_DIM_VAL}, {1}}, {0}, UNI_DIRECTION, dnnl_invalid_arguments}
);

// Permutations expected to succeed. All are BI_DIRECTION, so each is also
// checked in reverse with the inverse permutation.
auto cases_generic = ::testing::Values(
    params_t {{{2, 1}, fmt::ab}, {{2, 1}, fmt::ab}, {0, 1}},
    params_t {{{2, 1}, fmt::ab}, {{1, 2}, fmt::ba}, {1, 0}},
    params_t {{{2, 1}, fmt::ba}, {{1, 2}, fmt::ab}, {1, 0}},
    params_t {{{2, 3}, {4, 1}, {2, 4}}, {{3, 2}, {1, 4}, {4, 2}}, {1, 0}},
    params_t {{{3, 2}, {2, 30}}, {{2, 3}, {30, 2}}, {1, 0}},
    params_t {{{2, 3, 4, 5}, fmt::acdb}, {{2, 4, 5, 3}, fmt::abcd}, {0, 3, 1, 2}},
    params_t {{{2, 3, 4, 5}, fmt::cdba}, {{4, 5, 3, 2}, fmt::abcd}, {3, 2, 0, 1}},
    params_t {{{2, 15, 3, 4}, fmt::ABcd16b16a}, {{15, 2, 3, 4}, fmt::BAcd16a16b}, {1, 0, 2, 3}},
    params_t {{{3, 2, 15, 3, 4, 5}, fmt::aBCdef16b16c}, {{3, 15, 2, 3, 4, 5}, fmt::aCBdef16c16b}, {0, 2, 1, 3, 4, 5}}
);
// clang-format on

INSTANTIATE_TEST_SUITE_P(
        TestPermuteAxesEF, permute_axes_test_t, cases_expect_to_fail);
INSTANTIATE_TEST_SUITE_P(TestPermuteAxesOK, permute_axes_test_t, cases_generic);
271 | |
272 | } // namespace permute_axes |
273 | |
274 | } // namespace memory_desc_ops |
275 | } // namespace dnnl |
276 | |