1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef TEST_CONVOLUTION_FORWARD_COMMON_H |
18 | #define TEST_CONVOLUTION_FORWARD_COMMON_H |
19 | |
20 | #include "dnnl_test_common.hpp" |
21 | #include "gtest/gtest.h" |
22 | |
23 | #include <stdint.h> |
24 | #include "oneapi/dnnl/dnnl.hpp" |
25 | |
26 | #include <math.h> |
27 | |
28 | namespace dnnl { |
29 | |
30 | template <typename data_t_src, typename data_t_wei, typename data_t_acc, |
31 | typename data_t_dst> |
32 | void compute_ref_conv_fwd(const test_convolution_sizes_t &c, |
33 | const test_convolution_attr_t &attr, const memory::desc &src_d, |
34 | const memory::desc &weights_d, const memory::desc &bias_d, |
35 | const memory::desc &dst_d, const memory &src, const memory &weights, |
36 | const memory &bias, const memory &dst) { |
37 | const bool w_bias = bias_d.get_ndims() != 0; |
38 | auto src_data = map_memory<data_t_src>(src); |
39 | auto weights_data = map_memory<data_t_wei>(weights); |
40 | |
41 | auto bias_data = w_bias ? map_memory<data_t_dst>(bias) : nullptr; |
42 | auto dst_data = map_memory<data_t_dst>(dst); |
43 | |
44 | auto padded_ic = src_d.get_padded_dims()[1]; |
45 | auto padded_oc = dst_d.get_padded_dims()[1]; |
46 | |
47 | const dnnl::impl::memory_desc_wrapper src_mdw(src_d.get()); |
48 | const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.get()); |
49 | const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get()); |
50 | const dnnl::impl::memory_desc_wrapper bias_mdw(bias_d.get()); |
51 | |
52 | dnnl::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow, |
53 | [&](memory::dim n, memory::dim g, memory::dim oc, memory::dim oh, |
54 | memory::dim ow) { |
55 | data_t_acc a = 0; |
56 | for (memory::dim ic = 0; ic < c.ic / c.ng; ic++) { |
57 | for (memory::dim kh = 0; kh < c.kh; kh++) { |
58 | for (memory::dim kw = 0; kw < c.kw; kw++) { |
59 | memory::dim iw |
60 | = ow * c.strw - c.padw + kw * (1 + c.dilw); |
61 | memory::dim ih |
62 | = oh * c.strh - c.padh + kh * (1 + c.dilh); |
63 | if (iw < 0 || iw >= c.iw) continue; |
64 | if (ih < 0 || ih >= c.ih) continue; |
65 | memory::dim iidx = n * padded_ic * c.ih * c.iw |
66 | + g * padded_ic / c.ng * c.ih * c.iw |
67 | + ic * c.ih * c.iw + ih * c.iw + iw; |
68 | memory::dim widx = g * padded_oc / c.ng * padded_ic |
69 | / c.ng * c.kh * c.kw |
70 | + oc * padded_ic / c.ng * c.kh * c.kw |
71 | + ic * c.kh * c.kw + kh * c.kw + kw; |
72 | a += ((data_t_acc)src_data[src_mdw.off_l( |
73 | iidx, true)]) |
74 | * weights_data[weights_mdw.off_l( |
75 | widx, true)]; |
76 | } |
77 | } |
78 | } |
79 | |
80 | float a_fp = (float)a; |
81 | |
82 | if (attr.src_scale.is_def()) { |
83 | const auto &s = attr.src_scale; |
84 | using P = test_convolution_attr_t::scale_t; |
85 | if (s.policy == P::policy_t::COMMON) { a_fp *= s.scale; } |
86 | } |
87 | |
88 | if (attr.wei_scale.is_def()) { |
89 | const auto &s = attr.wei_scale; |
90 | using P = test_convolution_attr_t::scale_t; |
91 | if (s.policy == P::policy_t::COMMON) { a_fp *= s.scale; } |
92 | } |
93 | |
94 | a_fp += (float)(bias_data ? bias_data[bias_mdw.off_l( |
95 | g * c.oc / c.ng + oc, true)] |
96 | : 0); |
97 | |
98 | if (attr.dst_scale.is_def()) { |
99 | const auto &s = attr.dst_scale; |
100 | using P = test_convolution_attr_t::scale_t; |
101 | if (s.policy == P::policy_t::COMMON) { a_fp /= s.scale; } |
102 | } |
103 | |
104 | a_fp = out_round<data_t_dst>(a_fp); |
105 | |
106 | memory::dim oidx = n * padded_oc * c.oh * c.ow |
107 | + g * padded_oc / c.ng * c.oh * c.ow + oc * c.oh * c.ow |
108 | + oh * c.ow + ow; |
109 | dst_data[dst_mdw.off_l(oidx, true)] = (data_t_dst)a_fp; |
110 | }); |
111 | } |
112 | |
113 | template <typename data_t_src, typename data_t_wei, typename data_t_acc, |
114 | typename data_t_dst> |
115 | class convolution_forward_test |
116 | : public ::testing::TestWithParam<test_convolution_params_t> { |
117 | protected: |
118 | virtual void SetUp() { |
119 | memory::data_type data_type_src = data_traits<data_t_src>::data_type; |
120 | memory::data_type data_type_dst = data_traits<data_t_dst>::data_type; |
121 | memory::data_type data_type_wei = data_traits<data_t_wei>::data_type; |
122 | |
123 | SKIP_IF(unsupported_data_type(data_type_src), |
124 | "Engine does not support this data type." ); |
125 | SKIP_IF(unsupported_data_type(data_type_dst), |
126 | "Engine does not support this data type." ); |
127 | SKIP_IF(unsupported_data_type(data_type_wei), |
128 | "Engine does not support this data type." ); |
129 | |
130 | auto p = ::testing::TestWithParam< |
131 | test_convolution_params_t>::GetParam(); |
132 | |
133 | SKIP_IF_CUDA( |
134 | !(cuda_check_format_tags(p.formats.src_format, data_type_src) |
135 | && cuda_check_format_tags( |
136 | p.formats.dst_format, data_type_dst) |
137 | && (cuda_check_format_tags( |
138 | p.formats.weights_format, data_type_wei) |
139 | || impl::utils::one_of(p.formats.weights_format, |
140 | /* weights formats */ |
141 | memory::format_tag::gowi, |
142 | memory::format_tag::gohwi, |
143 | memory::format_tag::godhwi, |
144 | memory::format_tag::owi, |
145 | memory::format_tag::ohwi, |
146 | memory::format_tag::odhwi))), |
147 | "Format is not supported." ); |
148 | |
149 | catch_expected_failures( |
150 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
151 | } |
152 | |
153 | bool cuda_check_format_tags(memory::format_tag tag, memory::data_type dt) { |
154 | return ((impl::utils::one_of(tag, memory::format_tag::ab, |
155 | memory::format_tag::abc, memory::format_tag::abcd, |
156 | memory::format_tag::abcde, memory::format_tag::abcdef, |
157 | memory::format_tag::acb, memory::format_tag::acdb, |
158 | memory::format_tag::acdeb)) |
159 | || (dt == memory::data_type::s8 |
160 | && impl::utils::one_of(tag, memory::format_tag::aBcd4b, |
161 | memory::format_tag::aBcde4b))); |
162 | } |
163 | |
164 | void Test() { |
165 | auto p = ::testing::TestWithParam< |
166 | test_convolution_params_t>::GetParam(); |
167 | ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct); |
168 | auto eng = get_test_engine(); |
169 | auto strm = stream(eng); |
170 | |
171 | memory::data_type data_type_src = data_traits<data_t_src>::data_type; |
172 | memory::data_type data_type_dst = data_traits<data_t_dst>::data_type; |
173 | memory::data_type data_type_wei = data_traits<data_t_wei>::data_type; |
174 | |
175 | test_convolution_sizes_t cd = p.sizes; |
176 | |
177 | test_convolution_attr_t attr = p.attr; |
178 | attr.dnnl_attr_recreate(); |
179 | |
180 | auto aprop_kind = prop_kind::forward; |
181 | bool with_bias = p.formats.bias_format != memory::format_tag::undef; |
182 | bool with_src_scales = attr.src_scale.is_def(); |
183 | bool with_wei_scales = attr.wei_scale.is_def(); |
184 | bool with_dst_scales = attr.dst_scale.is_def(); |
185 | |
186 | auto c_src_desc = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, data_type_src, |
187 | p.formats.src_format); |
188 | auto c_weights_desc = cd.ng > 1 |
189 | ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw}, |
190 | data_type_wei, p.formats.weights_format) |
191 | : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type_wei, |
192 | p.formats.weights_format); |
193 | auto c_dst_desc = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, data_type_dst, |
194 | p.formats.dst_format); |
195 | auto c_bias_desc = with_bias |
196 | ? create_md({cd.oc}, data_type_dst, p.formats.bias_format) |
197 | : create_md({}, data_type_dst, p.formats.bias_format); |
198 | auto c_src_scales_desc = with_src_scales |
199 | ? create_md({1}, memory::data_type::f32, memory::format_tag::x) |
200 | : create_md({}, memory::data_type::f32, memory::format_tag::x); |
201 | auto c_wei_scales_desc = with_wei_scales |
202 | ? create_md({1}, memory::data_type::f32, memory::format_tag::x) |
203 | : create_md({}, memory::data_type::f32, memory::format_tag::x); |
204 | auto c_dst_scales_desc = with_dst_scales |
205 | ? create_md({1}, memory::data_type::f32, memory::format_tag::x) |
206 | : create_md({}, memory::data_type::f32, memory::format_tag::x); |
207 | |
208 | auto c_src = test_memory(c_src_desc, eng); |
209 | auto c_weights = test_memory(c_weights_desc, eng); |
210 | auto c_bias = test_memory(c_bias_desc, eng); |
211 | auto c_dst = test_memory(c_dst_desc, eng); |
212 | auto c_src_scales = test_memory(c_src_scales_desc, eng); |
213 | auto c_wei_scales = test_memory(c_wei_scales_desc, eng); |
214 | auto c_dst_scales = test_memory(c_dst_scales_desc, eng); |
215 | |
216 | // Only true for dense format |
217 | fill_data<data_t_dst>( |
218 | c_dst.get_size() / sizeof(data_t_dst), c_dst.get()); |
219 | fill_data<data_t_src>( |
220 | c_src.get_size() / sizeof(data_t_src), c_src.get()); |
221 | fill_data<data_t_wei>( |
222 | c_weights.get_size() / sizeof(data_t_wei), c_weights.get()); |
223 | if (with_bias) { |
224 | fill_data<data_t_dst>( |
225 | c_bias.get_size() / sizeof(data_t_dst), c_bias.get()); |
226 | } |
227 | if (with_src_scales) { |
228 | fill_data<float>(c_src_scales.get_size() / sizeof(float), |
229 | c_src_scales.get(), attr.src_scale.scale, 0.0f); |
230 | } |
231 | if (with_wei_scales) { |
232 | fill_data<float>(c_wei_scales.get_size() / sizeof(float), |
233 | c_wei_scales.get(), attr.wei_scale.scale, 0.0f); |
234 | } |
235 | if (with_dst_scales) { |
236 | fill_data<float>(c_dst_scales.get_size() / sizeof(float), |
237 | c_dst_scales.get(), attr.dst_scale.scale, 0.0f); |
238 | } |
239 | |
240 | check_zero_tail<data_t_src>(1, c_src.get()); |
241 | check_zero_tail<data_t_wei>(1, c_weights.get()); |
242 | check_zero_tail<data_t_dst>(1, c_dst.get()); |
243 | |
244 | memory::dims strides = {cd.strh, cd.strw}; |
245 | memory::dims dilations = {cd.dilh, cd.dilw}; |
246 | memory::dims padL = {cd.padh, cd.padw}; |
247 | memory::dims padR = { |
248 | right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh), |
249 | right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)}; |
250 | |
251 | auto conv_primitive_desc = with_bias |
252 | ? convolution_forward::primitive_desc(eng, aprop_kind, |
253 | p.aalgorithm, c_src_desc, c_weights_desc, c_bias_desc, |
254 | c_dst_desc, strides, dilations, padL, padR, |
255 | attr.dnnl_attr) |
256 | : convolution_forward::primitive_desc(eng, aprop_kind, |
257 | p.aalgorithm, c_src_desc, c_weights_desc, c_dst_desc, |
258 | strides, dilations, padL, padR, attr.dnnl_attr); |
259 | |
260 | conv_primitive_desc = convolution_forward::primitive_desc( |
261 | conv_primitive_desc.get()); // test construction from a C pd |
262 | |
263 | ASSERT_TRUE( |
264 | conv_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC) |
265 | == conv_primitive_desc.src_desc()); |
266 | ASSERT_TRUE( |
267 | conv_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_DST) |
268 | == conv_primitive_desc.dst_desc()); |
269 | ASSERT_TRUE(conv_primitive_desc.query_md( |
270 | query::exec_arg_md, DNNL_ARG_WEIGHTS) |
271 | == conv_primitive_desc.weights_desc()); |
272 | ASSERT_TRUE( |
273 | conv_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_BIAS) |
274 | == conv_primitive_desc.bias_desc()); |
275 | |
276 | ASSERT_EQ(conv_primitive_desc.get_algorithm(), p.aalgorithm); |
277 | ASSERT_EQ(conv_primitive_desc.get_prop_kind(), aprop_kind); |
278 | ASSERT_EQ(conv_primitive_desc.get_strides(), strides); |
279 | ASSERT_EQ(conv_primitive_desc.get_dilations(), dilations); |
280 | ASSERT_EQ(conv_primitive_desc.get_padding_l(), padL); |
281 | ASSERT_EQ(conv_primitive_desc.get_padding_r(), padR); |
282 | |
283 | EXPECT_ANY_THROW(convolution_forward(conv_primitive_desc, {})); |
284 | convolution_forward(conv_primitive_desc) |
285 | .execute(strm, |
286 | {{DNNL_ARG_SRC, c_src.get()}, |
287 | {DNNL_ARG_WEIGHTS, c_weights.get()}, |
288 | {DNNL_ARG_BIAS, c_bias.get()}, |
289 | {DNNL_ARG_DST, c_dst.get()}, |
290 | {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, |
291 | c_src_scales.get()}, |
292 | {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, |
293 | c_wei_scales.get()}, |
294 | {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, |
295 | c_dst_scales.get()}}); |
296 | strm.wait(); |
297 | |
298 | auto ref_memory = test::make_memory(c_dst_desc, eng); |
299 | compute_ref_conv_fwd<data_t_src, data_t_wei, data_t_acc, data_t_dst>(cd, |
300 | attr, c_src_desc, c_weights_desc, c_bias_desc, c_dst_desc, |
301 | c_src.get(), c_weights.get(), c_bias.get(), ref_memory); |
302 | check_zero_tail<data_t_dst>(1, ref_memory); |
303 | |
304 | compare_data<data_t_dst>(ref_memory, c_dst.get()); |
305 | check_zero_tail<data_t_dst>(0, c_dst.get()); |
306 | } |
307 | }; |
308 | |
309 | } // namespace dnnl |
310 | #endif |
311 | |