1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef TEST_CONVOLUTION_BACKWARD_DATA_COMMON_H |
18 | #define TEST_CONVOLUTION_BACKWARD_DATA_COMMON_H |
19 | |
20 | #include "dnnl_test_common.hpp" |
21 | #include "gtest/gtest.h" |
22 | |
23 | #include "oneapi/dnnl/dnnl.hpp" |
24 | |
25 | namespace dnnl { |
26 | |
27 | template <typename data_t_diff_dst, typename data_t_wei, typename data_t_acc, |
28 | typename data_t_diff_src> |
29 | void compute_ref_conv_bwd_data(const test_convolution_sizes_t &c, |
30 | const memory &diff_src, const memory &weights, const memory &diff_dst) { |
31 | auto diff_dst_data = map_memory<data_t_diff_dst>(diff_dst); |
32 | auto weights_data = map_memory<data_t_wei>(weights); |
33 | auto diff_src_data = map_memory<data_t_diff_src>(diff_src); |
34 | |
35 | const memory::desc diff_src_d = diff_src.get_desc(); |
36 | const memory::desc weights_d = weights.get_desc(); |
37 | const memory::desc diff_dst_d = diff_dst.get_desc(); |
38 | |
39 | auto padded_ic = diff_src_d.get_padded_dims()[1]; |
40 | auto padded_oc = diff_dst_d.get_padded_dims()[1]; |
41 | |
42 | const dnnl::impl::memory_desc_wrapper diff_src_mdw(diff_src_d.get()); |
43 | const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get()); |
44 | const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get()); |
45 | |
46 | dnnl::impl::parallel_nd(c.mb, c.ng, c.ic / c.ng, c.ih, c.iw, |
47 | [&](memory::dim mb, memory::dim g, memory::dim ic, memory::dim ih, |
48 | memory::dim iw) { |
49 | memory::dim sidx = mb * padded_ic * c.ih * c.iw |
50 | + g * padded_ic / c.ng * c.ih * c.iw + ic * c.ih * c.iw |
51 | + ih * c.iw + iw; |
52 | data_t_acc a = data_t_acc(0); |
53 | for (memory::dim oc = 0; oc < c.oc / c.ng; oc++) { |
54 | for (memory::dim kh = 0; kh < c.kh; kh++) { |
55 | for (memory::dim kw = 0; kw < c.kw; kw++) { |
56 | if (iw + c.padw < kw * (1 + c.dilw) |
57 | || ih + c.padh < kh * (1 + c.dilh)) |
58 | continue; |
59 | memory::dim ow = iw - kw * (1 + c.dilw) + c.padw; |
60 | memory::dim oh = ih - kh * (1 + c.dilh) + c.padh; |
61 | if (ow % c.strw != 0 || oh % c.strh != 0) continue; |
62 | ow /= c.strw; |
63 | oh /= c.strh; |
64 | if (oh < c.oh && ow < c.ow) { |
65 | memory::dim didx = mb * padded_oc * c.oh * c.ow |
66 | + g * padded_oc / c.ng * c.oh * c.ow |
67 | + oc * c.oh * c.ow + oh * c.ow + ow; |
68 | memory::dim widx = g * padded_oc / c.ng |
69 | * padded_ic / c.ng * c.kh * c.kw |
70 | + oc * padded_ic / c.ng * c.kh * c.kw |
71 | + ic * c.kh * c.kw + kh * c.kw + kw; |
72 | |
73 | a += (data_t_acc)( |
74 | diff_dst_data[diff_dst_mdw.off_l( |
75 | didx, true)] |
76 | * weights_data[weights_mdw.off_l( |
77 | widx, true)]); |
78 | } |
79 | } |
80 | } |
81 | } |
82 | diff_src_data[diff_src_mdw.off_l(sidx, true)] |
83 | = (data_t_diff_src)a; |
84 | }); |
85 | } |
86 | |
87 | template <typename data_t_diff_dst, typename data_t_wei, typename data_t_acc, |
88 | typename data_t_diff_src> |
89 | class convolution_backward_data_test |
90 | : public ::testing::TestWithParam<test_convolution_params_t> { |
91 | protected: |
92 | virtual void SetUp() { |
93 | auto p = ::testing::TestWithParam< |
94 | test_convolution_params_t>::GetParam(); |
95 | |
96 | SKIP_IF_CUDA( |
97 | !(cuda_check_format_tags(p.formats.src_format) |
98 | && cuda_check_format_tags(p.formats.dst_format) |
99 | && (cuda_check_format_tags(p.formats.weights_format) |
100 | || (impl::utils::one_of( |
101 | p.formats.weights_format, |
102 | /* weights formats */ |
103 | memory::format_tag::gowi, |
104 | memory::format_tag::gohwi, |
105 | memory::format_tag::godhwi, |
106 | memory::format_tag::owi, |
107 | memory::format_tag::ohwi, |
108 | memory::format_tag::odhwi))) |
109 | && data_traits<data_t_diff_src>::data_type |
110 | == memory::data_type::f32 |
111 | && data_traits<data_t_diff_dst>::data_type |
112 | == memory::data_type::f32 |
113 | && data_traits<data_t_wei>::data_type |
114 | == memory::data_type::f32 |
115 | && check_cuda_alg_format(p.formats.dst_format, |
116 | p.formats.weights_format, p.aalgorithm)), |
117 | "format is not supported." ); |
118 | |
119 | catch_expected_failures( |
120 | [=]() { Test(); }, p.expect_to_fail, p.expected_status); |
121 | } |
122 | |
123 | bool cuda_check_format_tags(memory::format_tag tag) { |
124 | return impl::utils::one_of(tag, memory::format_tag::ab, |
125 | memory::format_tag::abc, memory::format_tag::abcd, |
126 | memory::format_tag::abcde, memory::format_tag::abcdef, |
127 | memory::format_tag::acb, memory::format_tag::acdb, |
128 | memory::format_tag::acdeb); |
129 | } |
130 | |
131 | bool check_cuda_alg_format(memory::format_tag dst_fmt, |
132 | memory::format_tag wei_fmt, algorithm alg) { |
133 | bool res = dst_fmt == wei_fmt; |
134 | if (alg == dnnl::algorithm::convolution_winograd) { |
135 | res = res |
136 | && impl::utils::one_of(wei_fmt, memory::format_tag::ab, |
137 | memory::format_tag::abc, memory::format_tag::abcd, |
138 | memory::format_tag::abcde, |
139 | memory::format_tag::abcdef); |
140 | } |
141 | return res; |
142 | } |
143 | |
144 | void Test() { |
145 | auto p = ::testing::TestWithParam< |
146 | test_convolution_params_t>::GetParam(); |
147 | ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct); |
148 | auto eng = get_test_engine(); |
149 | auto strm = stream(eng); |
150 | auto data_type_diff_src = data_traits<data_t_diff_src>::data_type; |
151 | auto data_type_diff_dst = data_traits<data_t_diff_dst>::data_type; |
152 | auto data_type_wei = data_traits<data_t_wei>::data_type; |
153 | |
154 | test_convolution_sizes_t cd = p.sizes; |
155 | |
156 | auto c_src_desc = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, |
157 | data_type_diff_src, p.formats.src_format); |
158 | auto c_weights_desc = cd.ng > 1 |
159 | ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw}, |
160 | data_type_wei, p.formats.weights_format) |
161 | : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type_wei, |
162 | p.formats.weights_format); |
163 | auto c_dst_desc = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, |
164 | data_type_diff_dst, p.formats.dst_format); |
165 | auto c_src_desc_f = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, |
166 | data_type_diff_dst, p.formats.src_format); |
167 | auto c_dst_desc_f = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, |
168 | data_type_diff_src, p.formats.dst_format); |
169 | |
170 | auto c_diff_src = test_memory(c_src_desc, eng); |
171 | auto c_weights = test_memory(c_weights_desc, eng); |
172 | auto c_diff_dst = test_memory(c_dst_desc, eng); |
173 | |
174 | memory::dims padR = { |
175 | right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh), |
176 | right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)}; |
177 | |
178 | // Only true for dense format |
179 | fill_data<data_t_wei>( |
180 | c_weights.get_size() / sizeof(data_t_wei), c_weights.get()); |
181 | fill_data<data_t_diff_dst>( |
182 | c_diff_dst.get_size() / sizeof(data_t_diff_dst), |
183 | c_diff_dst.get()); |
184 | fill_data<data_t_diff_src>( |
185 | c_diff_src.get_size() / sizeof(data_t_diff_src), |
186 | c_diff_src.get()); |
187 | check_zero_tail<data_t_diff_dst>(1, c_diff_dst.get()); |
188 | check_zero_tail<data_t_wei>(1, c_weights.get()); |
189 | check_zero_tail<data_t_diff_src>(1, c_diff_src.get()); |
190 | |
191 | auto conv_primitive_desc = convolution_forward::primitive_desc(eng, |
192 | prop_kind::forward_training, p.aalgorithm, c_src_desc_f, |
193 | c_weights_desc, c_dst_desc_f, {cd.strh, cd.strw}, |
194 | {cd.dilh, cd.dilw}, {cd.padh, cd.padw}, padR); |
195 | |
196 | auto conv_bwd_data_primitive_desc |
197 | = convolution_backward_data::primitive_desc(eng, p.aalgorithm, |
198 | c_src_desc, c_weights_desc, c_dst_desc, |
199 | {cd.strh, cd.strw}, {cd.dilh, cd.dilw}, |
200 | {cd.padh, cd.padw}, padR, conv_primitive_desc); |
201 | conv_bwd_data_primitive_desc |
202 | = convolution_backward_data::primitive_desc( |
203 | conv_bwd_data_primitive_desc |
204 | .get()); // test construction from a C pd |
205 | |
206 | ASSERT_TRUE(conv_bwd_data_primitive_desc.query_md( |
207 | query::exec_arg_md, DNNL_ARG_DIFF_SRC) |
208 | == conv_bwd_data_primitive_desc.diff_src_desc()); |
209 | ASSERT_TRUE(conv_bwd_data_primitive_desc.query_md( |
210 | query::exec_arg_md, DNNL_ARG_DIFF_DST) |
211 | == conv_bwd_data_primitive_desc.diff_dst_desc()); |
212 | ASSERT_TRUE(conv_bwd_data_primitive_desc.query_md( |
213 | query::exec_arg_md, DNNL_ARG_WEIGHTS) |
214 | == conv_bwd_data_primitive_desc.weights_desc()); |
215 | |
216 | EXPECT_ANY_THROW( |
217 | convolution_backward_data(conv_bwd_data_primitive_desc, {})); |
218 | convolution_backward_data(conv_bwd_data_primitive_desc) |
219 | .execute(strm, |
220 | {{DNNL_ARG_DIFF_DST, c_diff_dst.get()}, |
221 | {DNNL_ARG_WEIGHTS, c_weights.get()}, |
222 | {DNNL_ARG_DIFF_SRC, c_diff_src.get()}}); |
223 | strm.wait(); |
224 | |
225 | auto ref_memory = test::make_memory(c_src_desc, eng); |
226 | compute_ref_conv_bwd_data<data_t_diff_dst, data_t_wei, data_t_acc, |
227 | data_t_diff_src>( |
228 | cd, ref_memory, c_weights.get(), c_diff_dst.get()); |
229 | check_zero_tail<data_t_diff_src>(1, ref_memory); |
230 | |
231 | compare_data<data_t_diff_src>(ref_memory, c_diff_src.get()); |
232 | check_zero_tail<data_t_diff_src>(0, c_diff_src.get()); |
233 | } |
234 | }; |
235 | |
236 | } // namespace dnnl |
237 | #endif |
238 | |