1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef TEST_REORDER_COMMON_HPP
18#define TEST_REORDER_COMMON_HPP
19
#include <functional>
#include <memory>
#include <numeric>
#include <type_traits>
#include <utility>
24
25#include "dnnl_test_common.hpp"
26#include "gtest/gtest.h"
27
28#include "oneapi/dnnl/dnnl.hpp"
29
30namespace dnnl {
31
32template <typename data_i_t, typename data_o_t>
33inline void check_reorder(const memory::desc &md_i, const memory::desc &md_o,
34 memory &src, memory &dst) {
35 auto src_data = map_memory<data_i_t>(src);
36 auto dst_data = map_memory<data_o_t>(dst);
37
38 const auto dims = md_i.get_dims();
39 const size_t nelems = std::accumulate(
40 dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
41
42 const dnnl::impl::memory_desc_wrapper mdw_i(md_i.get());
43 const dnnl::impl::memory_desc_wrapper mdw_o(md_o.get());
44 for (size_t i = 0; i < nelems; ++i) {
45 data_i_t s_raw = src_data[mdw_i.off_l(i, false)];
46 data_o_t s = static_cast<data_o_t>(s_raw);
47 data_o_t d = dst_data[mdw_o.off_l(i, false)];
48 ASSERT_EQ(s, d) << "mismatch at position " << i;
49 }
50}
51
52template <typename reorder_types>
53struct test_simple_params {
54 memory::format_tag fmt_i;
55 memory::format_tag fmt_o;
56 memory::dims dims;
57 bool expect_to_fail;
58 dnnl_status_t expected_status;
59};
60
61template <typename reorder_types>
62class reorder_simple_test
63 : public ::testing::TestWithParam<test_simple_params<reorder_types>> {
64protected:
65#ifdef DNNL_TEST_WITH_ENGINE_PARAM
66 void Test() {
67 using data_i_t = typename reorder_types::first_type;
68 using data_o_t = typename reorder_types::second_type;
69 memory::data_type prec_i = data_traits<data_i_t>::data_type;
70 memory::data_type prec_o = data_traits<data_o_t>::data_type;
71
72 SKIP_IF(unsupported_data_type(prec_i),
73 "Engine does not support this data type.");
74 SKIP_IF(unsupported_data_type(prec_o),
75 "Engine does not support this data type.");
76
77 test_simple_params<reorder_types> p
78 = ::testing::TestWithParam<decltype(p)>::GetParam();
79
80 SKIP_IF_CUDA(!((supported_format(p.fmt_i)
81 || supported_blocking(prec_i, p.fmt_i))
82 && (supported_format(p.fmt_o)
83 || supported_blocking(prec_o, p.fmt_o))),
84 "Unsupported cuda format tag/ data type");
85
86 catch_expected_failures(
87 [=]() {
88 engine eng = get_test_engine();
89 RunTest(eng, eng);
90 },
91 p.expect_to_fail, p.expected_status);
92 }
93#endif
94 bool supported_format(memory::format_tag fmt) {
95 return impl::utils::one_of(fmt, memory::format_tag::abcde,
96 memory::format_tag::acdeb, memory::format_tag::abcd,
97 memory::format_tag::acdb, memory::format_tag::abc,
98 memory::format_tag::acb, memory::format_tag::ab,
99 memory::format_tag::ba, memory::format_tag::a,
100 memory::format_tag::any);
101 }
102
103 bool supported_blocking(memory::data_type dt, memory::format_tag fmt) {
104 return (dt == dnnl_u8
105 && impl::utils::one_of(fmt, dnnl_aBcd4b, dnnl_aBcde4b));
106 }
107
108 void Test(engine &eng_i, engine &eng_o) {
109 using data_i_t = typename reorder_types::first_type;
110 using data_o_t = typename reorder_types::second_type;
111 memory::data_type prec_i = data_traits<data_i_t>::data_type;
112 memory::data_type prec_o = data_traits<data_o_t>::data_type;
113
114 SKIP_IF(unsupported_data_type(prec_i, eng_i),
115 "Engine does not support this data type.");
116 SKIP_IF(unsupported_data_type(prec_o, eng_o),
117 "Engine does not support this data type.");
118
119 test_simple_params<reorder_types> p
120 = ::testing::TestWithParam<decltype(p)>::GetParam();
121
122#ifdef DNNL_SYCL_CUDA
123 SKIP_IF(!((supported_format(p.fmt_i)
124 || supported_blocking(prec_i, p.fmt_i))
125 && (supported_format(p.fmt_o)
126 || supported_blocking(prec_o, p.fmt_o))),
127 "Unsupported cuda format tag/ data type");
128#endif
129
130 catch_expected_failures([&]() { RunTest(eng_i, eng_o); },
131 p.expect_to_fail, p.expected_status);
132 }
133
134 void RunTest(engine &eng_i, engine &eng_o) {
135 using data_i_t = typename reorder_types::first_type;
136 using data_o_t = typename reorder_types::second_type;
137
138 test_simple_params<reorder_types> p
139 = ::testing::TestWithParam<decltype(p)>::GetParam();
140
141 const size_t nelems = std::accumulate(p.dims.begin(), p.dims.end(),
142 size_t(1), std::multiplies<size_t>());
143
144 memory::data_type prec_i = data_traits<data_i_t>::data_type;
145 memory::data_type prec_o = data_traits<data_o_t>::data_type;
146 auto md_i = memory::desc(p.dims, prec_i, p.fmt_i);
147 auto md_o = memory::desc(p.dims, prec_o, p.fmt_o);
148
149 reorder::primitive_desc r_pd(
150 eng_i, md_i, eng_o, md_o, primitive_attr());
151 // test construction from a C pd
152 r_pd = reorder::primitive_desc(r_pd.get());
153
154 ASSERT_TRUE(r_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
155 == r_pd.src_desc());
156 ASSERT_TRUE(r_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
157 == r_pd.dst_desc());
158 if (p.fmt_i != memory::format_tag::any) {
159 ASSERT_TRUE(md_i == r_pd.src_desc());
160 }
161
162 auto src = test::make_memory(r_pd.src_desc(), eng_i);
163 auto dst = test::make_memory(r_pd.dst_desc(), eng_o);
164
165 /* initialize input data */
166 const dnnl::impl::memory_desc_wrapper mdw_i(md_i.get());
167 {
168 auto src_data = map_memory<data_i_t>(src);
169 for (size_t i = 0; i < nelems; ++i)
170 src_data[mdw_i.off_l(i, false)] = data_i_t(i);
171 }
172
173 EXPECT_ANY_THROW(reorder(r_pd, {}));
174 auto r = reorder(r_pd);
175 auto strm = make_stream(r_pd.get_engine());
176 r.execute(strm, src, dst);
177 strm.wait();
178
179 check_reorder<data_i_t, data_o_t>(md_i, md_o, src, dst);
180 check_zero_tail<data_o_t>(0, dst);
181 }
182};
183
184} // namespace dnnl
185
186#endif
187