1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef TEST_CONVOLUTION_BACKWARD_DATA_COMMON_H
18#define TEST_CONVOLUTION_BACKWARD_DATA_COMMON_H
19
#include <initializer_list>

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"
24
25namespace dnnl {
26
27template <typename data_t_diff_dst, typename data_t_wei, typename data_t_acc,
28 typename data_t_diff_src>
29void compute_ref_conv_bwd_data(const test_convolution_sizes_t &c,
30 const memory &diff_src, const memory &weights, const memory &diff_dst) {
31 auto diff_dst_data = map_memory<data_t_diff_dst>(diff_dst);
32 auto weights_data = map_memory<data_t_wei>(weights);
33 auto diff_src_data = map_memory<data_t_diff_src>(diff_src);
34
35 const memory::desc diff_src_d = diff_src.get_desc();
36 const memory::desc weights_d = weights.get_desc();
37 const memory::desc diff_dst_d = diff_dst.get_desc();
38
39 auto padded_ic = diff_src_d.get_padded_dims()[1];
40 auto padded_oc = diff_dst_d.get_padded_dims()[1];
41
42 const dnnl::impl::memory_desc_wrapper diff_src_mdw(diff_src_d.get());
43 const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.get());
44 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.get());
45
46 dnnl::impl::parallel_nd(c.mb, c.ng, c.ic / c.ng, c.ih, c.iw,
47 [&](memory::dim mb, memory::dim g, memory::dim ic, memory::dim ih,
48 memory::dim iw) {
49 memory::dim sidx = mb * padded_ic * c.ih * c.iw
50 + g * padded_ic / c.ng * c.ih * c.iw + ic * c.ih * c.iw
51 + ih * c.iw + iw;
52 data_t_acc a = data_t_acc(0);
53 for (memory::dim oc = 0; oc < c.oc / c.ng; oc++) {
54 for (memory::dim kh = 0; kh < c.kh; kh++) {
55 for (memory::dim kw = 0; kw < c.kw; kw++) {
56 if (iw + c.padw < kw * (1 + c.dilw)
57 || ih + c.padh < kh * (1 + c.dilh))
58 continue;
59 memory::dim ow = iw - kw * (1 + c.dilw) + c.padw;
60 memory::dim oh = ih - kh * (1 + c.dilh) + c.padh;
61 if (ow % c.strw != 0 || oh % c.strh != 0) continue;
62 ow /= c.strw;
63 oh /= c.strh;
64 if (oh < c.oh && ow < c.ow) {
65 memory::dim didx = mb * padded_oc * c.oh * c.ow
66 + g * padded_oc / c.ng * c.oh * c.ow
67 + oc * c.oh * c.ow + oh * c.ow + ow;
68 memory::dim widx = g * padded_oc / c.ng
69 * padded_ic / c.ng * c.kh * c.kw
70 + oc * padded_ic / c.ng * c.kh * c.kw
71 + ic * c.kh * c.kw + kh * c.kw + kw;
72
73 a += (data_t_acc)(
74 diff_dst_data[diff_dst_mdw.off_l(
75 didx, true)]
76 * weights_data[weights_mdw.off_l(
77 widx, true)]);
78 }
79 }
80 }
81 }
82 diff_src_data[diff_src_mdw.off_l(sidx, true)]
83 = (data_t_diff_src)a;
84 });
85}
86
87template <typename data_t_diff_dst, typename data_t_wei, typename data_t_acc,
88 typename data_t_diff_src>
89class convolution_backward_data_test
90 : public ::testing::TestWithParam<test_convolution_params_t> {
91protected:
92 virtual void SetUp() {
93 auto p = ::testing::TestWithParam<
94 test_convolution_params_t>::GetParam();
95
96 SKIP_IF_CUDA(
97 !(cuda_check_format_tags(p.formats.src_format)
98 && cuda_check_format_tags(p.formats.dst_format)
99 && (cuda_check_format_tags(p.formats.weights_format)
100 || (impl::utils::one_of(
101 p.formats.weights_format,
102 /* weights formats */
103 memory::format_tag::gowi,
104 memory::format_tag::gohwi,
105 memory::format_tag::godhwi,
106 memory::format_tag::owi,
107 memory::format_tag::ohwi,
108 memory::format_tag::odhwi)))
109 && data_traits<data_t_diff_src>::data_type
110 == memory::data_type::f32
111 && data_traits<data_t_diff_dst>::data_type
112 == memory::data_type::f32
113 && data_traits<data_t_wei>::data_type
114 == memory::data_type::f32
115 && check_cuda_alg_format(p.formats.dst_format,
116 p.formats.weights_format, p.aalgorithm)),
117 "format is not supported.");
118
119 catch_expected_failures(
120 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
121 }
122
123 bool cuda_check_format_tags(memory::format_tag tag) {
124 return impl::utils::one_of(tag, memory::format_tag::ab,
125 memory::format_tag::abc, memory::format_tag::abcd,
126 memory::format_tag::abcde, memory::format_tag::abcdef,
127 memory::format_tag::acb, memory::format_tag::acdb,
128 memory::format_tag::acdeb);
129 }
130
131 bool check_cuda_alg_format(memory::format_tag dst_fmt,
132 memory::format_tag wei_fmt, algorithm alg) {
133 bool res = dst_fmt == wei_fmt;
134 if (alg == dnnl::algorithm::convolution_winograd) {
135 res = res
136 && impl::utils::one_of(wei_fmt, memory::format_tag::ab,
137 memory::format_tag::abc, memory::format_tag::abcd,
138 memory::format_tag::abcde,
139 memory::format_tag::abcdef);
140 }
141 return res;
142 }
143
144 void Test() {
145 auto p = ::testing::TestWithParam<
146 test_convolution_params_t>::GetParam();
147 ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
148 auto eng = get_test_engine();
149 auto strm = stream(eng);
150 auto data_type_diff_src = data_traits<data_t_diff_src>::data_type;
151 auto data_type_diff_dst = data_traits<data_t_diff_dst>::data_type;
152 auto data_type_wei = data_traits<data_t_wei>::data_type;
153
154 test_convolution_sizes_t cd = p.sizes;
155
156 auto c_src_desc = create_md({cd.mb, cd.ic, cd.ih, cd.iw},
157 data_type_diff_src, p.formats.src_format);
158 auto c_weights_desc = cd.ng > 1
159 ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw},
160 data_type_wei, p.formats.weights_format)
161 : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type_wei,
162 p.formats.weights_format);
163 auto c_dst_desc = create_md({cd.mb, cd.oc, cd.oh, cd.ow},
164 data_type_diff_dst, p.formats.dst_format);
165 auto c_src_desc_f = create_md({cd.mb, cd.ic, cd.ih, cd.iw},
166 data_type_diff_dst, p.formats.src_format);
167 auto c_dst_desc_f = create_md({cd.mb, cd.oc, cd.oh, cd.ow},
168 data_type_diff_src, p.formats.dst_format);
169
170 auto c_diff_src = test_memory(c_src_desc, eng);
171 auto c_weights = test_memory(c_weights_desc, eng);
172 auto c_diff_dst = test_memory(c_dst_desc, eng);
173
174 memory::dims padR = {
175 right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
176 right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)};
177
178 // Only true for dense format
179 fill_data<data_t_wei>(
180 c_weights.get_size() / sizeof(data_t_wei), c_weights.get());
181 fill_data<data_t_diff_dst>(
182 c_diff_dst.get_size() / sizeof(data_t_diff_dst),
183 c_diff_dst.get());
184 fill_data<data_t_diff_src>(
185 c_diff_src.get_size() / sizeof(data_t_diff_src),
186 c_diff_src.get());
187 check_zero_tail<data_t_diff_dst>(1, c_diff_dst.get());
188 check_zero_tail<data_t_wei>(1, c_weights.get());
189 check_zero_tail<data_t_diff_src>(1, c_diff_src.get());
190
191 auto conv_primitive_desc = convolution_forward::primitive_desc(eng,
192 prop_kind::forward_training, p.aalgorithm, c_src_desc_f,
193 c_weights_desc, c_dst_desc_f, {cd.strh, cd.strw},
194 {cd.dilh, cd.dilw}, {cd.padh, cd.padw}, padR);
195
196 auto conv_bwd_data_primitive_desc
197 = convolution_backward_data::primitive_desc(eng, p.aalgorithm,
198 c_src_desc, c_weights_desc, c_dst_desc,
199 {cd.strh, cd.strw}, {cd.dilh, cd.dilw},
200 {cd.padh, cd.padw}, padR, conv_primitive_desc);
201 conv_bwd_data_primitive_desc
202 = convolution_backward_data::primitive_desc(
203 conv_bwd_data_primitive_desc
204 .get()); // test construction from a C pd
205
206 ASSERT_TRUE(conv_bwd_data_primitive_desc.query_md(
207 query::exec_arg_md, DNNL_ARG_DIFF_SRC)
208 == conv_bwd_data_primitive_desc.diff_src_desc());
209 ASSERT_TRUE(conv_bwd_data_primitive_desc.query_md(
210 query::exec_arg_md, DNNL_ARG_DIFF_DST)
211 == conv_bwd_data_primitive_desc.diff_dst_desc());
212 ASSERT_TRUE(conv_bwd_data_primitive_desc.query_md(
213 query::exec_arg_md, DNNL_ARG_WEIGHTS)
214 == conv_bwd_data_primitive_desc.weights_desc());
215
216 EXPECT_ANY_THROW(
217 convolution_backward_data(conv_bwd_data_primitive_desc, {}));
218 convolution_backward_data(conv_bwd_data_primitive_desc)
219 .execute(strm,
220 {{DNNL_ARG_DIFF_DST, c_diff_dst.get()},
221 {DNNL_ARG_WEIGHTS, c_weights.get()},
222 {DNNL_ARG_DIFF_SRC, c_diff_src.get()}});
223 strm.wait();
224
225 auto ref_memory = test::make_memory(c_src_desc, eng);
226 compute_ref_conv_bwd_data<data_t_diff_dst, data_t_wei, data_t_acc,
227 data_t_diff_src>(
228 cd, ref_memory, c_weights.get(), c_diff_dst.get());
229 check_zero_tail<data_t_diff_src>(1, ref_memory);
230
231 compare_data<data_t_diff_src>(ref_memory, c_diff_src.get());
232 check_zero_tail<data_t_diff_src>(0, c_diff_src.get());
233 }
234};
235
236} // namespace dnnl
237#endif
238