1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <cstring>
18#include <memory>
19#include <vector>
20
21#include "dnnl_test_common.hpp"
22#include "gtest/gtest.h"
23
24#include "oneapi/dnnl/dnnl.hpp"
25
// Set to 1 (here, or with -DDEBUG_TEST_MEMORY_DESC_OPS_CPP=1 on the compiler
// command line) to make the debug:: helpers print memory descriptors. With
// the default value of 0 those helpers are empty functions.
#ifndef DEBUG_TEST_MEMORY_DESC_OPS_CPP
#define DEBUG_TEST_MEMORY_DESC_OPS_CPP 0
#endif
27
28namespace dnnl {
29namespace memory_desc_ops {
30
31namespace debug {
32#if DEBUG_TEST_MEMORY_DESC_OPS_CPP
33template <typename T>
34void print_vec(const char *str, const T &vec) {
35 printf("%s", str);
36 for (int d = 0; d < (int)vec.size(); ++d)
37 printf("%d ", (int)vec[d]);
38 printf("\n");
39}
40void print_md(const char *str, const memory::desc &md) {
41 const auto &o_bd = md.format_desc.blocking;
42
43 printf("%s\n", str);
44
45 print_vec("\tdims : ", md.get_dims());
46 print_vec("\tpdims: ", md.get_padded_dims());
47 print_vec("\toffs : ", md.get_padded_offsets());
48 print_vec("\tstrs : ", o_bd.get_strides());
49
50 printf("\t\tnblks : %d\n", o_bd.get_inner_nblks());
51 print_vec("\t\tidxs : ", o_bd.get_get_inner_idxs());
52 print_vec("\t\tblks : ", o_bd.get_inner_blks());
53}
54#else // DEBUG_TEST_MEMORY_DESC_OPS_CPP
55template <typename T>
56void print_vec(const char *, const T &) {}
57void print_md(const char *, const memory::desc &) {}
58#endif // DEBUG_TEST_MEMORY_DESC_OPS_CPP
59} // namespace debug
60
61// A proxy to memory::desc with fixed data type (f32)
62struct memory_desc_proxy_t {
63 memory::desc md;
64 memory_desc_proxy_t() = default;
65 memory_desc_proxy_t(const memory::desc &md) : md(md) {}
66
67 memory_desc_proxy_t(const memory::dims &dims, memory::format_tag tag)
68 : md(dims, memory::data_type::f32, tag) {}
69 memory_desc_proxy_t(const memory::dims &dims, const memory::dims &strides)
70 : md(dims, memory::data_type::f32, strides) {}
71
72 memory_desc_proxy_t(const memory::dims &dims, const memory::dims &strides,
73 const memory::dims &padded_dims)
74 : md(dims, memory::data_type::f32, strides) {
75 for (int d = 0; d < md.get_ndims(); ++d)
76 md.get()->padded_dims[d] = padded_dims[d];
77 }
78};
79
// Selects whether a parameterized case is checked only as in -> out, or also
// in the reverse direction (out -> in).
enum test_direction_t {
    BI_DIRECTION = 0, // default: check both directions
    UNI_DIRECTION = 1 // check in -> out only
};
81
82namespace properties {
83
84using fmt = dnnl::memory::format_tag;
85
86TEST(memory_desc_properties_test, TestMemoryDescSize) {
87 auto md1_simple = memory_desc_proxy_t {{1, 1, 1, 1}, {1, 1, 1, 1}}.md;
88 auto md1_strided = memory_desc_proxy_t {{1, 1, 1, 1}, {8, 4, 2, 1}}.md;
89 auto md2_blocked = memory_desc_proxy_t {{1, 4, 1, 1}, fmt::nChw8c}.md;
90
91 ASSERT_EQ(md1_simple, md1_strided);
92 ASSERT_NE(md2_blocked, md1_simple);
93 ASSERT_NE(md2_blocked, md1_strided);
94
95 ASSERT_EQ(md1_simple.get_size(), 1 * sizeof(float));
96 ASSERT_EQ(md1_strided.get_size(), 1 * sizeof(float));
97 ASSERT_EQ(md2_blocked.get_size(), 8 * sizeof(float));
98}
99
100} // namespace properties
101
102namespace reshape {
103
// One reshape test case. Members are filled positionally by aggregate
// initialization; trailing members that are omitted are value-initialized,
// which gives BI_DIRECTION (== 0) for `test_direction`.
struct params_t {
    // Descriptor that gets reshaped.
    memory_desc_proxy_t in;
    // Descriptor the reshape result is compared against; its dims are also
    // the target dims of the reshape.
    memory_desc_proxy_t out;
    // UNI_DIRECTION: only in -> out is checked; otherwise out -> in as well.
    test_direction_t test_direction;
    // Status passed to catch_expected_failures(); any value other than
    // dnnl_success marks the case as expected to fail.
    dnnl_status_t expected_status;
};
110
111class reshape_test_t : public ::testing::TestWithParam<params_t> {
112protected:
113 void Test(const memory::desc &in_md, const memory::desc &out_md) {
114 memory::desc get_out_md = in_md.reshape(out_md.get_dims());
115
116 debug::print_md("in_md", in_md);
117 debug::print_md("out_md", get_out_md);
118 debug::print_md("expect_out_md", out_md);
119
120 ASSERT_EQ(get_out_md, out_md);
121 }
122};
123TEST_P(reshape_test_t, TestsReshape) {
124 params_t p = ::testing::TestWithParam<decltype(p)>::GetParam();
125 catch_expected_failures([=]() { Test(p.in.md, p.out.md); },
126 p.expected_status != dnnl_success, p.expected_status);
127 if (p.test_direction == UNI_DIRECTION) return;
128 catch_expected_failures([=]() { Test(p.out.md, p.in.md); },
129 p.expected_status != dnnl_success, p.expected_status);
130}
131
132using fmt = dnnl::memory::format_tag;
133
134// clang-format off
// Reshape cases that are expected to be rejected with dnnl_invalid_arguments.
// Each entry is {in, out, test direction, expected status}.
auto cases_expect_to_fail = ::testing::Values(
        // volume mismatch (4 vs 8 elements)
        params_t {{{2, 2, 1, 1}, fmt::abcd}, {{2, 2, 2, 1, 1}, fmt::abcde}, BI_DIRECTION, dnnl_invalid_arguments},
        // volume mismatch (2 vs 4 elements)
        params_t {{{2, 1}, {1, 1}}, {{2, 1, 2}, {2, 2, 1}}, BI_DIRECTION, dnnl_invalid_arguments},
        // volume mismatch (12 vs 6 elements)
        params_t {{{6, 2}, fmt::ab}, {{6}, fmt::a}, BI_DIRECTION, dnnl_invalid_arguments},
        // joining axes are not contiguous in memory (`cdab` would be OK)
        params_t {{{2, 3, 0, 2}, fmt::cdba}, {{6, 0, 2}, fmt::bca}, UNI_DIRECTION, dnnl_invalid_arguments},
        // joining axes are not contiguous in memory
        params_t {{{6, 2}, fmt::ba}, {{12}, fmt::a}, UNI_DIRECTION, dnnl_invalid_arguments},
        // joining axes are not contiguous in memory (strides {2, 1} would be OK)
        params_t {{{6, 2}, {3, 1}}, {{12}, fmt::a}, UNI_DIRECTION, dnnl_invalid_arguments},
        // removing an axis of size `1` that has padding is not allowed
        params_t {{{6, 1, 2}, {4, 2, 1}, {6, 2, 2}}, {{6, 2}, fmt::any}, UNI_DIRECTION, dnnl_invalid_arguments},
        // joining axes where one has padding is not allowed
        params_t {{{6, 2, 2}, {6, 2, 1}, {6, 3, 2}}, {{6, 4}, fmt::any}, UNI_DIRECTION, dnnl_invalid_arguments},
        // splitting an axis that has padding is not allowed
        params_t {{{6}, {1}, {12}}, {{2, 3}, fmt::any}, UNI_DIRECTION, dnnl_invalid_arguments},
        // joining axes are not contiguous (partially, due to the blocking)
        params_t {{{2, 8, 3, 4}, fmt::aBcd8b}, {{2, 8 * 3 * 4}, fmt::ab}, UNI_DIRECTION, dnnl_invalid_arguments},
        // nothing can be done with a zero (default-constructed) memory desc
        params_t {{}, {}, UNI_DIRECTION, dnnl_invalid_arguments},
        // run-time dims are not supported
        params_t {{{DNNL_RUNTIME_DIM_VAL}, {1}}, {{DNNL_RUNTIME_DIM_VAL}, {1}}, UNI_DIRECTION, dnnl_invalid_arguments}
        );
161
// Reshape cases whose descriptors contain a zero-sized dimension. Direction
// and expected status are omitted, so they are value-initialized:
// BI_DIRECTION (== 0) and a zero status (presumably dnnl_success; confirm
// against the dnnl_status_t definition).
auto cases_zero_dim = ::testing::Values(
        params_t {{{2, 0, 2}, fmt::abc}, {{2, 0, 2, 1}, fmt::abcd}},
        params_t {{{2, 0, 2}, fmt::abc}, {{2, 0, 1, 2, 1}, fmt::abcde}},
        params_t {{{2, 1, 0, 2}, fmt::abcd}, {{2, 0, 2, 1}, fmt::abcd}},
        params_t {{{31, 1, 0, 2}, fmt::Abcd16a}, {{1, 31, 0, 2, 1}, fmt::aBcde16b}},
        params_t {{{2, 3, 0, 2}, {6, 2, 2, 1}}, {{6, 0, 2}, {2, 2, 1}}}
        );
169
// Reshape cases that are expected to succeed. Entries without an explicit
// direction are value-initialized to BI_DIRECTION (== 0), i.e. they are also
// checked out -> in.
auto cases_generic = ::testing::Values(
        // add and/or remove axes of size `1`
        params_t {{{2, 1}, {2, 2}}, {{2}, {2}}},
        params_t {{{2, 1}, {2, 2}}, {{2, 1, 1}, {2, 2, 1}}},
        params_t {{{2, 1}, {2, 2}}, {{2, 1, 1}, {2, 1, 2}}},
        params_t {{{2, 1}, {2, 2}}, {{2, 1, 1}, {2, 2, 2}}},
        params_t {{{2, 2}, fmt::ab}, {{2, 2, 1}, fmt::abc}},
        params_t {{{2, 1}, fmt::ab}, {{1, 2, 1, 1}, fmt::abcd}},
        params_t {{{1, 2, 1}, fmt::abc}, {{2}, fmt::a}},
        params_t {{{3, 4, 5, 6}, fmt::ABcd16b16a}, {{1, 3, 4, 5, 6}, fmt::aBCde16c16b}},
        // UNI_DIRECTION because adding an axis of size 1 is ambiguous when another axis of size 1 already exists
        params_t {{{2, 1, 1}, {2, 1, 1}, {2, 2, 1}}, {{2, 1}, {2, 1}, {2, 2}}, UNI_DIRECTION},
        params_t {{{2, 1, 1}, {2, 2, 1}, {2, 1, 2}}, {{2, 1}, {2, 1}, {2, 2}}},
        // split and join axes (both are covered because test_direction == BI_DIRECTION)
        params_t {{{6, 2}, fmt::ab}, {{3, 2, 2}, fmt::abc}},
        params_t {{{6, 2}, fmt::ab}, {{2, 3, 2}, fmt::abc}},
        params_t {{{6, 2}, fmt::ba}, {{2, 3, 2}, /* fmt::cab: */ {3, 1, 6}}},
        params_t {{{6, 2}, {4, 1}, {6, 4}}, {{2, 3, 2}, {12, 4, 1}, {2, 3, 4}}},
        params_t {{{1, 15, 12}, fmt::aBc8b}, {{1, 15, 3, 4}, fmt::aBcd8b}},
        params_t {{{1, 15, 12}, fmt::aBc8b}, {{1, 15, 2, 3, 2}, fmt::aBcde8b}},
        // combined cases
        params_t {{{15, 3, 4}, fmt::abc}, {{3, 5, 6, 1, 2}, fmt::abcde}},
        params_t {{{15, 3, 4}, fmt::bca}, {{3, 5, 6, 1, 2}, /* fmt::cdeab */ {5, 1, 30, 30, 15}}}
        );
194// clang-format on
195
// Instantiate reshape_test_t once per parameter set: expected failures (EF),
// zero-sized dims, and generic cases that are expected to succeed (OK).
INSTANTIATE_TEST_SUITE_P(TestReshapeEF, reshape_test_t, cases_expect_to_fail);
INSTANTIATE_TEST_SUITE_P(TestReshapeZeroDim, reshape_test_t, cases_zero_dim);
INSTANTIATE_TEST_SUITE_P(TestReshapeOK, reshape_test_t, cases_generic);
199
200} // namespace reshape
201
202namespace permute_axes {
203
// One permute_axes test case. Members are filled positionally by aggregate
// initialization; trailing members that are omitted are value-initialized,
// which gives BI_DIRECTION (== 0) for `test_direction`.
struct params_t {
    // Descriptor whose axes get permuted.
    memory_desc_proxy_t in;
    // Descriptor the permuted result is compared against.
    memory_desc_proxy_t out;
    // Permutation passed to memory::desc::permute_axes().
    std::vector<int> perm;
    // UNI_DIRECTION: only in -> out is checked; otherwise out -> in is checked
    // as well, using the inverse of `perm`.
    test_direction_t test_direction;
    // Status passed to catch_expected_failures(); any value other than
    // dnnl_success marks the case as expected to fail.
    dnnl_status_t expected_status;
};
211
212class permute_axes_test_t : public ::testing::TestWithParam<params_t> {
213protected:
214 void Test(const memory::desc &in_md, const memory::desc &out_md,
215 const std::vector<int> &perm) {
216 memory::desc get_out_md = in_md.permute_axes(perm);
217
218 debug::print_md("in_md", in_md);
219 debug::print_vec("perm : ", perm);
220 debug::print_md("out_md", get_out_md);
221 debug::print_md("expect_out_md", out_md);
222
223 ASSERT_EQ(get_out_md, out_md);
224 }
225};
226TEST_P(permute_axes_test_t, TestsPermuteAxes) {
227 params_t p = ::testing::TestWithParam<decltype(p)>::GetParam();
228 catch_expected_failures([=]() { Test(p.in.md, p.out.md, p.perm); },
229 p.expected_status != dnnl_success, p.expected_status);
230 if (p.test_direction == UNI_DIRECTION) return;
231
232 std::vector<int> inv_perm(p.perm.size());
233 for (int i = 0; i < (int)p.perm.size(); ++i)
234 inv_perm[p.perm[i]] = i;
235 catch_expected_failures([=]() { Test(p.out.md, p.in.md, inv_perm); },
236 p.expected_status != dnnl_success, p.expected_status);
237}
238
239using fmt = dnnl::memory::format_tag;
240
241// clang-format off
// permute_axes cases that are expected to be rejected with
// dnnl_invalid_arguments. Each entry is
// {in, out, perm, test direction, expected status}.
auto cases_expect_to_fail = ::testing::Values(
        // incorrect permutation (index 2 is repeated)
        params_t {{{2, 2, 1, 1}, fmt::abcd}, {{2, 2, 1, 1}, fmt::abcd}, {0, 1, 2, 2}, UNI_DIRECTION, dnnl_invalid_arguments},
        // incorrect permutation (index 4 is out of range for 4 dims)
        params_t {{{2, 2, 1, 1}, fmt::abcd}, {{2, 2, 1, 1}, fmt::abcd}, {0, 1, 2, 4}, UNI_DIRECTION, dnnl_invalid_arguments},
        // incorrect permutation (negative index)
        params_t {{{2, 2, 1, 1}, fmt::abcd}, {{2, 2, 1, 1}, fmt::abcd}, {0, 1, 2, -1}, UNI_DIRECTION, dnnl_invalid_arguments},
        // nothing can be done with a zero (default-constructed) memory desc
        params_t {{}, {}, {}, UNI_DIRECTION, dnnl_invalid_arguments},
        // run-time dims are not supported
        params_t {{{DNNL_RUNTIME_DIM_VAL}, {1}}, {{DNNL_RUNTIME_DIM_VAL}, {1}}, {0}, UNI_DIRECTION, dnnl_invalid_arguments}
        );
254
// permute_axes cases that are expected to succeed. Each entry is
// {in, out, perm}; the omitted direction is value-initialized to BI_DIRECTION
// (== 0), so every case is also checked out -> in with the inverse of `perm`.
auto cases_generic = ::testing::Values(
        params_t {{{2, 1}, fmt::ab}, {{2, 1}, fmt::ab}, {0, 1}},
        params_t {{{2, 1}, fmt::ab}, {{1, 2}, fmt::ba}, {1, 0}},
        params_t {{{2, 1}, fmt::ba}, {{1, 2}, fmt::ab}, {1, 0}},
        params_t {{{2, 3}, {4, 1}, {2, 4}}, {{3, 2}, {1, 4}, {4, 2}}, {1, 0}},
        params_t {{{3, 2}, {2, 30}}, {{2, 3}, {30, 2}}, {1, 0}},
        params_t {{{2, 3, 4, 5}, fmt::acdb}, {{2, 4, 5, 3}, fmt::abcd}, {0, 3, 1, 2}},
        params_t {{{2, 3, 4, 5}, fmt::cdba}, {{4, 5, 3, 2}, fmt::abcd}, {3, 2, 0, 1}},
        params_t {{{2, 15, 3, 4}, fmt::ABcd16b16a}, {{15, 2, 3, 4}, fmt::BAcd16a16b}, {1, 0, 2, 3}},
        params_t {{{3, 2, 15, 3, 4, 5}, fmt::aBCdef16b16c}, {{3, 15, 2, 3, 4, 5}, fmt::aCBdef16c16b}, {0, 2, 1, 3, 4, 5}}
        );
266// clang-format on
267
// Instantiate permute_axes_test_t once per parameter set: expected failures
// (EF) and generic cases that are expected to succeed (OK).
INSTANTIATE_TEST_SUITE_P(
        TestPermuteAxesEF, permute_axes_test_t, cases_expect_to_fail);
INSTANTIATE_TEST_SUITE_P(TestPermuteAxesOK, permute_axes_test_t, cases_generic);
271
272} // namespace permute_axes
273
274} // namespace memory_desc_ops
275} // namespace dnnl
276