1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22namespace dnnl {
23
24using tag = memory::format_tag;
25using dt = memory::data_type;
26
27struct shuffle_test_params_t {
28 dt src_dt;
29 dt dst_dt;
30 tag src_tag;
31 tag dst_tag;
32 memory::dims dims;
33 int axis;
34 memory::dim group_size;
35 bool expect_to_fail;
36 dnnl_status_t expected_status;
37};
38
39class shuffle_test_t : public ::testing::TestWithParam<shuffle_test_params_t> {
40private:
41 shuffle_test_params_t p;
42 std::shared_ptr<shuffle_forward::primitive_desc> pd_fwd_hint;
43
44protected:
45 void SetUp() override {
46 p = ::testing::TestWithParam<shuffle_test_params_t>::GetParam();
47
48 SKIP_IF_CUDA(true, "Shuffle primitive not supported by CUDA");
49
50 SKIP_IF(unsupported_data_type(p.src_dt, p.dst_dt),
51 "Engine does not support this data type.");
52
53 catch_expected_failures(
54 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
55 }
56
57 void Forward(prop_kind pk) {
58 // shuffle specific types and values
59 using pd_t = shuffle_forward::primitive_desc;
60
61 auto eng = get_test_engine();
62 auto strm = make_stream(eng);
63
64 auto aa = allows_attr_t {false};
65
66 auto src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
67 auto dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
68
69 // default pd ctor
70 auto pd = pd_t();
71 // regular pd ctor
72 pd = pd_t(eng, pk, src_md, dst_md, p.axis, p.group_size);
73 // test all pd ctors
74 test_fwd_pd_constructors<pd_t>(
75 pd, aa, pk, src_md, dst_md, p.axis, p.group_size);
76 pd_fwd_hint = std::make_shared<pd_t>(pd);
77
78 EXPECT_ANY_THROW(shuffle_forward(pd, {}));
79 // default primitive ctor
80 auto shuffle = shuffle_forward();
81 // regular primitive ctor
82 shuffle = shuffle_forward(pd);
83
84 // check primitive kind is shuffle
85 ASSERT_TRUE(shuffle.get_kind() == primitive::kind::shuffle);
86 // query for descs from pd
87 const auto src_desc = pd.src_desc();
88 const auto dst_desc = pd.dst_desc();
89 // query for src_desc via exec arg
90 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == src_desc);
91 if (p.src_tag != tag::any) { ASSERT_TRUE(src_md == src_desc); }
92 // query for dst_desc via exec arg
93 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == dst_desc);
94 if (p.dst_tag != tag::any) { ASSERT_TRUE(dst_md == dst_desc); }
95
96 // query primitive parameters
97 ASSERT_EQ(pd.get_prop_kind(), pk);
98 ASSERT_EQ(pd.get_axis(), p.axis);
99 ASSERT_EQ(pd.get_group_size(), p.group_size);
100
101 // check primitive returns zero_md for all rest md
102 ASSERT_TRUE(pd.weights_desc().is_zero());
103 ASSERT_TRUE(pd.diff_src_desc().is_zero());
104 ASSERT_TRUE(pd.diff_dst_desc().is_zero());
105 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
106
107 auto src = test::make_memory(src_desc, eng);
108 auto dst = test::make_memory(dst_desc, eng);
109
110 fill_data(p.src_dt, src, 1, 1);
111 // test out-place mode
112 shuffle.execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
113 strm.wait();
114 }
115
116 void Backward() {
117 // shuffle specific types and values
118 using pd_t = shuffle_backward::primitive_desc;
119 using hint_pd_t = shuffle_forward::primitive_desc;
120 allows_attr_t aa {false}; // doesn't support anything
121
122 auto eng = get_test_engine();
123 auto strm = make_stream(eng);
124
125 auto diff_src_md = memory::desc(p.dims, p.src_dt, p.src_tag);
126 auto diff_dst_md = memory::desc(p.dims, p.dst_dt, p.dst_tag);
127
128 // default pd ctor
129 auto pd = pd_t();
130 // regular pd ctor
131 pd = pd_t(eng, diff_src_md, diff_dst_md, p.axis, p.group_size,
132 *pd_fwd_hint);
133 // test all pd ctors
134 test_bwd_pd_constructors<pd_t, hint_pd_t>(pd, *pd_fwd_hint, aa,
135 diff_src_md, diff_dst_md, p.axis, p.group_size);
136
137 EXPECT_ANY_THROW(shuffle_backward(pd, {}));
138 // default primitive ctor
139 auto shuffle = shuffle_backward();
140 // regular primitive ctor
141 shuffle = shuffle_backward(pd);
142
143 // check primitive kind is shuffle
144 ASSERT_TRUE(shuffle.get_kind() == primitive::kind::shuffle);
145
146 // query for descs from pd
147 const auto diff_src_desc = pd.diff_src_desc();
148 const auto diff_dst_desc = pd.diff_dst_desc();
149 // query for diff_src_desc via exec arg
150 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
151 == diff_src_desc);
152 if (p.src_tag != tag::any) {
153 ASSERT_TRUE(diff_src_md == diff_src_desc);
154 }
155 // query for diff_dst_desc via exec arg
156 ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
157 == diff_dst_desc);
158 if (p.dst_tag != tag::any) {
159 ASSERT_TRUE(diff_dst_md == diff_dst_desc);
160 }
161 // query primitive parameters
162 ASSERT_EQ(pd.get_prop_kind(), prop_kind::backward_data);
163 ASSERT_EQ(pd.get_axis(), p.axis);
164 ASSERT_EQ(pd.get_group_size(), p.group_size);
165
166 // check primitive returns zero_md for all rest md
167 ASSERT_TRUE(pd.src_desc().is_zero());
168 ASSERT_TRUE(pd.weights_desc().is_zero());
169 ASSERT_TRUE(pd.dst_desc().is_zero());
170 ASSERT_TRUE(pd.diff_weights_desc().is_zero());
171
172 auto diff_src = test::make_memory(diff_src_desc, eng);
173 auto diff_dst = test::make_memory(diff_dst_desc, eng);
174
175 fill_data(p.dst_dt, diff_dst, 2, 2);
176
177 // test out-place mode
178 shuffle.execute(strm,
179 {{DNNL_ARG_DIFF_DST, diff_dst}, {DNNL_ARG_DIFF_SRC, diff_src}});
180 strm.wait();
181 }
182
183 void Test() {
184 const bool is_int8 = p.src_dt == dt::s8 || p.src_dt == dt::u8;
185 std::vector<prop_kind> pks = {is_int8 ? prop_kind::forward_inference
186 : prop_kind::forward_training};
187
188 for (auto pk : pks) {
189 Forward(pk);
190
191 bool to_continue = pk != prop_kind::forward_training;
192 if (to_continue) continue;
193
194 Backward();
195 }
196 }
197
198 bool is_fwd(prop_kind pk) const {
199 return pk == prop_kind::forward_training
200 || pk == prop_kind::forward_inference;
201 }
202};
203
// Short alias used to spell out parameter sets in the instantiations below.
using tp = shuffle_test_params_t;

// The test body is intentionally empty: every check for the current
// parameter set runs from shuffle_test_t::SetUp().
TEST_P(shuffle_test_t, TestsShuffle) {}
207
208INSTANTIATE_TEST_SUITE_P(Test_Shuffle_EF, shuffle_test_t,
209 ::testing::Values(
210 // Negative dims
211 tp {dt::f32, dt::f32, tag::nchw, tag::nchw, {2, -4, 128, 256},
212 1, 4, true, dnnl_invalid_arguments},
213 // Non-divisible group_size
214 tp {dt::f32, dt::f32, tag::nchw, tag::nchw, {2, 4, 128, 256}, 1,
215 3, true, dnnl_invalid_arguments},
216 // Axis exceeds ndims
217 tp {dt::f32, dt::f32, tag::nchw, tag::nchw, {2, 4, 128, 256}, 4,
218 4, true, dnnl_invalid_arguments},
219 // Tag for src on forward is not specified
220 tp {dt::f32, dt::f32, tag::any, tag::nchw, {2, 4, 128, 256}, 1,
221 4, true, dnnl_invalid_arguments},
222 // Data type for src is not specified
223 tp {dt::undef, dt::f32, tag::nchw, tag::nchw, {2, 4, 128, 256},
224 1, 4, true, dnnl_invalid_arguments},
225 // Different data types are not supported
226 tp {dt::f32, dt::bf16, tag::nchw, tag::nchw, {2, 4, 128, 256},
227 1, 4, true, dnnl_unimplemented},
228 // Different memory formats are not supported
229 tp {dt::f32, dt::f32, tag::nchw, tag::nhwc, {2, 4, 128, 256}, 1,
230 4, true, dnnl_unimplemented}));
231
232static auto all_cases = [](dt src_dt, dt dst_dt) {
233 return ::testing::Values(
234 tp {src_dt, dst_dt, tag::nwc, tag::nwc, {2, 16, 10}, 1, 4},
235 tp {src_dt, dst_dt, tag::ncw, tag::ncw, {2, 64, 27}, 1, 4},
236 tp {src_dt, dst_dt, tag::nhwc, tag::nhwc, {2, 15, 10, 8}, 1, 3},
237 tp {src_dt, dst_dt, tag::nchw, tag::nchw, {2, 64, 27, 27}, 0, 2},
238 tp {src_dt, dst_dt, tag::nChw8c, tag::nChw8c, {2, 16, 16, 8}, 1, 4},
239 tp {src_dt, dst_dt, tag::nChw16c, tag::nChw16c, {2, 16, 4, 4}, 1,
240 4},
241 tp {src_dt, dst_dt, tag::ncdhw, tag::ncdhw, {2, 64, 7, 7, 7}, 2, 7},
242 tp {src_dt, dst_dt, tag::ncdhw, tag::ncdhw, {10, 10, 10, 10, 10}, 0,
243 5},
244 tp {src_dt, dst_dt, tag::nCdhw16c, tag::nCdhw16c, {4, 16, 2, 2, 2},
245 1, 4});
246};
247
// Expands a pair of data type names (e.g. `f32, f32`) into the two
// memory::data_type arguments taken by a case generator such as all_cases.
#define EXPAND_DTS(src, dst) memory::data_type::src, memory::data_type::dst

// Instantiates shuffle_test_t under `name` with the values returned by
// `suite(...)`. These macros deliberately do not end with ';': the call site
// supplies it, so an expansion does not leave a stray empty declaration
// (`;;`) at namespace scope.
#define INST_TEST_CASE(name, suite, ...) \
    INSTANTIATE_TEST_SUITE_P(name, shuffle_test_t, suite(__VA_ARGS__))

// Variants routed through CPU_/GPU_INSTANTIATE_TEST_SUITE_P, which are
// defined outside this file (presumably engine-kind-specific instantiation;
// confirm in dnnl_test_common.hpp).
#define CPU_INST_TEST_CASE(name, suite, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(name, shuffle_test_t, suite(__VA_ARGS__))

#define GPU_INST_TEST_CASE(name, suite, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P(name, shuffle_test_t, suite(__VA_ARGS__))
258
// Run every case from all_cases once per data type (src and dst share the
// type). Cases whose data type the engine does not support are skipped at
// runtime by shuffle_test_t::SetUp().
INST_TEST_CASE(ShuffleSimpleF32, all_cases, EXPAND_DTS(f32, f32));
INST_TEST_CASE(ShuffleSimpleBF16, all_cases, EXPAND_DTS(bf16, bf16));
INST_TEST_CASE(ShuffleSimpleF16, all_cases, EXPAND_DTS(f16, f16));
INST_TEST_CASE(ShuffleSimpleU8, all_cases, EXPAND_DTS(u8, u8));
INST_TEST_CASE(ShuffleSimpleS8, all_cases, EXPAND_DTS(s8, s8));
264
265} // namespace dnnl
266