1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | namespace dnnl { |
23 | |
24 | using tag = memory::format_tag; |
25 | using dt = memory::data_type; |
26 | |
27 | struct shuffle_test_params_t { |
28 | dt src_dt; |
29 | dt dst_dt; |
30 | tag src_tag; |
31 | tag dst_tag; |
32 | memory::dims dims; |
33 | int axis; |
34 | memory::dim group_size; |
35 | bool expect_to_fail; |
36 | dnnl_status_t expected_status; |
37 | }; |
38 | |
39 | class shuffle_test_t : public ::testing::TestWithParam<shuffle_test_params_t> { |
40 | private: |
41 | shuffle_test_params_t p; |
42 | std::shared_ptr<shuffle_forward::primitive_desc> pd_fwd_hint; |
43 | |
44 | protected: |
45 | void SetUp() override { |
46 | p = ::testing::TestWithParam<shuffle_test_params_t>::GetParam(); |
47 | |
48 | SKIP_IF_CUDA(true, "Shuffle primitive not supported by CUDA" ); |
49 | |
50 | SKIP_IF(unsupported_data_type(p.src_dt, p.dst_dt), |
51 | "Engine does not support this data type." ); |
52 | |
53 | catch_expected_failures( |
54 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
55 | } |
56 | |
57 | void Forward(prop_kind pk) { |
58 | // shuffle specific types and values |
59 | using pd_t = shuffle_forward::primitive_desc; |
60 | |
61 | auto eng = get_test_engine(); |
62 | auto strm = make_stream(eng); |
63 | |
64 | auto aa = allows_attr_t {false}; |
65 | |
66 | auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
67 | auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
68 | |
69 | // default pd ctor |
70 | auto pd = pd_t(); |
71 | // regular pd ctor |
72 | pd = pd_t(eng, pk, src_md, dst_md, p.axis, p.group_size); |
73 | // test all pd ctors |
74 | test_fwd_pd_constructors<pd_t>( |
75 | pd, aa, pk, src_md, dst_md, p.axis, p.group_size); |
76 | pd_fwd_hint = std::make_shared<pd_t>(pd); |
77 | |
78 | EXPECT_ANY_THROW(shuffle_forward(pd, {})); |
79 | // default primitive ctor |
80 | auto shuffle = shuffle_forward(); |
81 | // regular primitive ctor |
82 | shuffle = shuffle_forward(pd); |
83 | |
84 | // check primitive kind is shuffle |
85 | ASSERT_TRUE(shuffle.get_kind() == primitive::kind::shuffle); |
86 | // query for descs from pd |
87 | const auto src_desc = pd.src_desc(); |
88 | const auto dst_desc = pd.dst_desc(); |
89 | // query for src_desc via exec arg |
90 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc); |
91 | if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); } |
92 | // query for dst_desc via exec arg |
93 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc); |
94 | if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); } |
95 | |
96 | // query primitive parameters |
97 | ASSERT_EQ(pd.get_prop_kind(), pk); |
98 | ASSERT_EQ(pd.get_axis(), p.axis); |
99 | ASSERT_EQ(pd.get_group_size(), p.group_size); |
100 | |
101 | // check primitive returns zero_md for all rest md |
102 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
103 | ASSERT_TRUE(pd.diff_src_desc().is_zero()); |
104 | ASSERT_TRUE(pd.diff_dst_desc().is_zero()); |
105 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
106 | |
107 | auto src = test::make_memory(src_desc, eng); |
108 | auto dst = test::make_memory(dst_desc, eng); |
109 | |
110 | fill_data(p.src_dt, src, 1, 1); |
111 | // test out-place mode |
112 | shuffle.execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); |
113 | strm.wait(); |
114 | } |
115 | |
116 | void Backward() { |
117 | // shuffle specific types and values |
118 | using pd_t = shuffle_backward::primitive_desc; |
119 | using hint_pd_t = shuffle_forward::primitive_desc; |
120 | allows_attr_t aa {false}; // doesn't support anything |
121 | |
122 | auto eng = get_test_engine(); |
123 | auto strm = make_stream(eng); |
124 | |
125 | auto diff_src_md = memory::desc(p.dims, p.src_dt, p.src_tag); |
126 | auto diff_dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag); |
127 | |
128 | // default pd ctor |
129 | auto pd = pd_t(); |
130 | // regular pd ctor |
131 | pd = pd_t(eng, diff_src_md, diff_dst_md, p.axis, p.group_size, |
132 | *pd_fwd_hint); |
133 | // test all pd ctors |
134 | test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa, |
135 | diff_src_md, diff_dst_md, p.axis, p.group_size); |
136 | |
137 | EXPECT_ANY_THROW(shuffle_backward(pd, {})); |
138 | // default primitive ctor |
139 | auto shuffle = shuffle_backward(); |
140 | // regular primitive ctor |
141 | shuffle = shuffle_backward(pd); |
142 | |
143 | // check primitive kind is shuffle |
144 | ASSERT_TRUE(shuffle.get_kind() == primitive::kind::shuffle); |
145 | |
146 | // query for descs from pd |
147 | const auto diff_src_desc = pd.diff_src_desc(); |
148 | const auto diff_dst_desc = pd.diff_dst_desc(); |
149 | // query for diff_src_desc via exec arg |
150 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
151 | == diff_src_desc); |
152 | if (p.src_tag != tag::any) { |
153 | ASSERT_TRUE(diff_src_md == diff_src_desc); |
154 | } |
155 | // query for diff_dst_desc via exec arg |
156 | ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST) |
157 | == diff_dst_desc); |
158 | if (p.dst_tag != tag::any) { |
159 | ASSERT_TRUE(diff_dst_md == diff_dst_desc); |
160 | } |
161 | // query primitive parameters |
162 | ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward_data); |
163 | ASSERT_EQ(pd.get_axis(), p.axis); |
164 | ASSERT_EQ(pd.get_group_size(), p.group_size); |
165 | |
166 | // check primitive returns zero_md for all rest md |
167 | ASSERT_TRUE(pd.src_desc().is_zero()); |
168 | ASSERT_TRUE(pd.weights_desc().is_zero()); |
169 | ASSERT_TRUE(pd.dst_desc().is_zero()); |
170 | ASSERT_TRUE(pd.diff_weights_desc().is_zero()); |
171 | |
172 | auto diff_src = test::make_memory(diff_src_desc, eng); |
173 | auto diff_dst = test::make_memory(diff_dst_desc, eng); |
174 | |
175 | fill_data(p.dst_dt, diff_dst, 2, 2); |
176 | |
177 | // test out-place mode |
178 | shuffle.execute(strm, |
179 | {{DNNL_ARG_DIFF_DST, diff_dst}, {DNNL_ARG_DIFF_SRC, diff_src}}); |
180 | strm.wait(); |
181 | } |
182 | |
183 | void Test() { |
184 | const bool is_int8 = p.src_dt == dt::s8 || p.src_dt == dt::u8; |
185 | std::vector<prop_kind> pks = {is_int8 ? prop_kind::forward_inference |
186 | : prop_kind::forward_training}; |
187 | |
188 | for (auto pk : pks) { |
189 | Forward(pk); |
190 | |
191 | bool to_continue = pk != prop_kind::forward_training; |
192 | if (to_continue) continue; |
193 | |
194 | Backward(); |
195 | } |
196 | } |
197 | |
198 | bool is_fwd(prop_kind pk) const { |
199 | return pk == prop_kind::forward_training |
200 | || pk == prop_kind::forward_inference; |
201 | } |
202 | }; |
203 | |
204 | using tp = shuffle_test_params_t; |
205 | |
206 | TEST_P(shuffle_test_t, TestsShuffle) {} |
207 | |
208 | INSTANTIATE_TEST_SUITE_P(Test_Shuffle_EF, shuffle_test_t, |
209 | ::testing::Values( |
210 | // Negative dims |
211 | tp {dt::f32, dt::f32, tag::nchw, tag::nchw, {2, -4, 128, 256}, |
212 | 1, 4, true, dnnl_invalid_arguments}, |
213 | // Non-divisible group_size |
214 | tp {dt::f32, dt::f32, tag::nchw, tag::nchw, {2, 4, 128, 256}, 1, |
215 | 3, true, dnnl_invalid_arguments}, |
216 | // Axis exceeds ndims |
217 | tp {dt::f32, dt::f32, tag::nchw, tag::nchw, {2, 4, 128, 256}, 4, |
218 | 4, true, dnnl_invalid_arguments}, |
219 | // Tag for src on forward is not specified |
220 | tp {dt::f32, dt::f32, tag::any, tag::nchw, {2, 4, 128, 256}, 1, |
221 | 4, true, dnnl_invalid_arguments}, |
222 | // Data type for src is not specified |
223 | tp {dt::undef, dt::f32, tag::nchw, tag::nchw, {2, 4, 128, 256}, |
224 | 1, 4, true, dnnl_invalid_arguments}, |
225 | // Different data types are not supported |
226 | tp {dt::f32, dt::bf16, tag::nchw, tag::nchw, {2, 4, 128, 256}, |
227 | 1, 4, true, dnnl_unimplemented}, |
228 | // Different memory formats are not supported |
229 | tp {dt::f32, dt::f32, tag::nchw, tag::nhwc, {2, 4, 128, 256}, 1, |
230 | 4, true, dnnl_unimplemented})); |
231 | |
232 | static auto all_cases = [](dt src_dt, dt dst_dt) { |
233 | return ::testing::Values( |
234 | tp {src_dt, dst_dt, tag::nwc, tag::nwc, {2, 16, 10}, 1, 4}, |
235 | tp {src_dt, dst_dt, tag::ncw, tag::ncw, {2, 64, 27}, 1, 4}, |
236 | tp {src_dt, dst_dt, tag::nhwc, tag::nhwc, {2, 15, 10, 8}, 1, 3}, |
237 | tp {src_dt, dst_dt, tag::nchw, tag::nchw, {2, 64, 27, 27}, 0, 2}, |
238 | tp {src_dt, dst_dt, tag::nChw8c, tag::nChw8c, {2, 16, 16, 8}, 1, 4}, |
239 | tp {src_dt, dst_dt, tag::nChw16c, tag::nChw16c, {2, 16, 4, 4}, 1, |
240 | 4}, |
241 | tp {src_dt, dst_dt, tag::ncdhw, tag::ncdhw, {2, 64, 7, 7, 7}, 2, 7}, |
242 | tp {src_dt, dst_dt, tag::ncdhw, tag::ncdhw, {10, 10, 10, 10, 10}, 0, |
243 | 5}, |
244 | tp {src_dt, dst_dt, tag::nCdhw16c, tag::nCdhw16c, {4, 16, 2, 2, 2}, |
245 | 1, 4}); |
246 | }; |
247 | |
248 | #define EXPAND_DTS(src, dst) memory::data_type::src, memory::data_type::dst |
249 | |
250 | #define INST_TEST_CASE(name, suite, ...) \ |
251 | INSTANTIATE_TEST_SUITE_P(name, shuffle_test_t, suite(__VA_ARGS__)); |
252 | |
253 | #define CPU_INST_TEST_CASE(name, suite, ...) \ |
254 | CPU_INSTANTIATE_TEST_SUITE_P(name, shuffle_test_t, suite(__VA_ARGS__)); |
255 | |
256 | #define GPU_INST_TEST_CASE(name, suite, ...) \ |
257 | GPU_INSTANTIATE_TEST_SUITE_P(name, shuffle_test_t, suite(__VA_ARGS__)); |
258 | |
259 | INST_TEST_CASE(ShuffleSimpleF32, all_cases, EXPAND_DTS(f32, f32)); |
260 | INST_TEST_CASE(ShuffleSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16)); |
261 | INST_TEST_CASE(ShuffleSimpleF16, all_cases, EXPAND_DTS(f16, f16)); |
262 | INST_TEST_CASE(ShuffleSimpleU8, all_cases, EXPAND_DTS(u8, u8)); |
263 | INST_TEST_CASE(ShuffleSimpleS8, all_cases, EXPAND_DTS(s8, s8)); |
264 | |
265 | } // namespace dnnl |
266 | |