/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ref_convolution.hpp"

#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

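// Fills the common convolution config from the primitive descriptor and
// builds the GPU dispatch: forward iterates over the destination
// (MB x G/OC x OD x OH x OW), backward-data over the source
// (IC x MB x G x ID x IH x IW), and backward-weights over the weights
// (G x OC x IC x KD x KH x KW).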
static status_t init_conf_common(
        conv_conf_t &conf, const convolution_pd_t *pd, engine_t *engine) {
    const convolution_desc_t &cd = *pd->desc();
    const memory_desc_t &src_md = *pd->invariant_src_md();
    const memory_desc_t &weights_md = *pd->invariant_wei_md();
    const memory_desc_t &dst_md = *pd->invariant_dst_md();
    const memory_desc_t &bias_md = *pd->invariant_bia_md();
    const primitive_attr_t &attr = *pd->attr();

    set_default_conf(conf, cd, src_md, weights_md, dst_md, bias_md, attr);

    int oc_idx = (int)conf.with_groups;
    auto *compute_engine
            = utils::downcast<compute::compute_engine_t *>(engine);
    switch (cd.prop_kind) {
        case prop_kind::forward_training:
        case prop_kind::forward_inference: {
            conf.dispatch = compute_engine->create_dispatch(&dst_md);
            conf.dispatch.define_dim("MB", 0, conf.mb);
            conf.dispatch.define_dim("G", 1, conf.ngroups);
            conf.dispatch.define_dim("OC", 1, conf.oc);
            conf.dispatch.define_dim(
                    "OD", nstl::max(2, conf.ndims - 3), conf.od);
            conf.dispatch.define_dim(
                    "OH", nstl::max(2, conf.ndims - 2), conf.oh);
            conf.dispatch.define_dim(
                    "OW", nstl::max(2, conf.ndims - 1), conf.ow);
            conf.dispatch.generate();
            break;
        }
        case prop_kind::backward_data: {
            conf.dispatch = compute_engine->create_dispatch(&src_md);
            conf.dispatch.define_dim_with_nesting_level(
                    "IC", conf.ndims, conf.ic);
            conf.dispatch.define_dim("MB", conf.mb);
            conf.dispatch.define_dim("G", conf.ngroups);
            conf.dispatch.define_dim(
                    "ID", nstl::max(2, conf.ndims - 3), conf.id);
            conf.dispatch.define_dim(
                    "IH", nstl::max(2, conf.ndims - 2), conf.ih);
            conf.dispatch.define_dim(
                    "IW", nstl::max(2, conf.ndims - 1), conf.iw);
            conf.dispatch.generate();
            break;
        }
        case prop_kind::backward_weights: {
            conf.dispatch = compute_engine->create_dispatch(&weights_md);
            conf.dispatch.define_dim("G", 0, conf.ngroups);
            conf.dispatch.define_dim("OC", oc_idx, conf.oc);
            conf.dispatch.define_dim("IC", oc_idx + 1, conf.ic);
            conf.dispatch.define_dim(
                    "KD", oc_idx + nstl::max(2, conf.ndims - 3), conf.kd);
            conf.dispatch.define_dim(
                    "KH", oc_idx + nstl::max(2, conf.ndims - 2), conf.kh);
            conf.dispatch.define_dim(
                    "KW", oc_idx + nstl::max(2, conf.ndims - 1), conf.kw);
            conf.dispatch.generate();
            break;
        }
        default: break;
    }

    return status::success;
}

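// Translates the config into kernel compile-time definitions: problem shape,
// stride, padding and dilation macros, per-tensor memory descriptor info and
// data types, the dispatch layout, and the IS_FWD / IS_BWD_D / IS_BWD_W
// propagation flags.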
static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const conv_conf_t &conf, const post_ops_t &post_ops) {
    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("G", conf.ngroups);
    kernel_ctx.define_int("WITH_GROUPS", conf.with_groups);
    kernel_ctx.define_int("MB", conf.mb);
    kernel_ctx.define_int("IC", conf.ic);
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("OC", conf.oc);
    kernel_ctx.define_int("OD", conf.od);
    kernel_ctx.define_int("OH", conf.oh);
    kernel_ctx.define_int("OW", conf.ow);
    kernel_ctx.define_int("KD", conf.kd);
    kernel_ctx.define_int("KH", conf.kh);
    kernel_ctx.define_int("KW", conf.kw);
    kernel_ctx.define_int("SD", conf.stride_d);
    kernel_ctx.define_int("SH", conf.stride_h);
    kernel_ctx.define_int("SW", conf.stride_w);
    kernel_ctx.define_int("PD", conf.f_pad);
    kernel_ctx.define_int("PH", conf.t_pad);
    kernel_ctx.define_int("PW", conf.l_pad);
    kernel_ctx.define_int("PD_R", conf.back_pad);
    kernel_ctx.define_int("PH_R", conf.b_pad);
    kernel_ctx.define_int("PW_R", conf.r_pad);
    kernel_ctx.define_int("DD", conf.dilate_d);
    kernel_ctx.define_int("DH", conf.dilate_h);
    kernel_ctx.define_int("DW", conf.dilate_w);
    kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);

    kernel_ctx.define_int("IS_FWD",
            utils::one_of(conf.prop_kind, prop_kind::forward_inference,
                    prop_kind::forward_training));
    kernel_ctx.define_int(
            "IS_BWD_D", conf.prop_kind == prop_kind::backward_data);
    kernel_ctx.define_int(
            "IS_BWD_W", conf.prop_kind == prop_kind::backward_weights);

    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.wei_md_info, "WEI");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");

    def_dispatch(kernel_ctx, conf.dispatch);

    switch (conf.prop_kind) {
        case prop_kind::forward_training:
        case prop_kind::forward_inference:
            kernel_ctx.set_data_type(conf.dst_data_type);
            break;
        case prop_kind::backward_data:
            kernel_ctx.set_data_type(conf.src_data_type);
            break;
        case prop_kind::backward_weights:
            kernel_ctx.set_data_type(conf.weights_data_type);
            break;
        default: break;
    }

    def_data_type(kernel_ctx, conf.src_data_type, "SRC");
    def_data_type(kernel_ctx, conf.weights_data_type, "WEI");
    def_data_type(kernel_ctx, conf.bias_data_type, "BIA");
    def_data_type(kernel_ctx, conf.dst_data_type, "DST");
    def_data_type(kernel_ctx, conf.acc_data_type, "ACC");
    def_data_type(kernel_ctx,
            conf.attr_info.sum_data_type == dnnl_data_type_undef
                    ? conf.dst_data_type
                    : conf.attr_info.sum_data_type,
            "SUM");

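    // Post-op attributes broadcast against the source dimensions for backward
    // propagation and the destination dimensions for forward; remaining
    // entries are padded with 1 up to MAX_NDIMS.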
    dims_t dst_dims {};

    for (int d = 0; d < MAX_NDIMS; d++) {
        if (d < conf.ndims)
            dst_dims[d] = (conf.prop_kind & dnnl_backward)
                    ? conf.src_md_info.dims[d]
                    : conf.dst_md_info.dims[d];
        else
            dst_dims[d] = 1;
    }
    def_attr_info(kernel_ctx, conf.attr_info, post_ops, &dst_dims);
    return status::success;
}

status_t ref_convolution_fwd_t::pd_t::init_conf(engine_t *engine) {
    CHECK(init_conf_common(conf, this, engine));
    return status::success;
}

status_t ref_convolution_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_);
}

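// Forward execution: binds src/wei/bias/dst, any post-op arguments, the scale
// buffers, and optional zero-point buffers, then launches the reference kernel
// over the dispatch nd_range.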
status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {

    status_t status = status::success;
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
    auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
    auto &dst = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status);
    CHECK(status);
    auto &src_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
    auto &wei_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
    auto &dst_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
    auto &src_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
    auto &dst_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);

    auto &conf = pd()->conf;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, weights);
    arg_list.set(2, bias);
    arg_list.set(3, dst);

    unsigned arg_idx = append_post_ops_to_arg_list(
            ctx, arg_list, 4, pd()->attr()->post_ops_);

    arg_list.set(arg_idx++, src_scales);
    arg_list.set(arg_idx++, wei_scales);
    arg_list.set(arg_idx++, dst_scales);

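    // Zero-point buffers are only bound when the corresponding attribute is
    // set; otherwise an empty storage keeps the kernel argument list fixed.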
    if (conf.attr_info.with_src_zpoints)
        arg_list.set(arg_idx++, src_zpoints);
    else
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());

    if (conf.attr_info.with_dst_zpoints)
        arg_list.set(arg_idx++, dst_zpoints);
    else
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());

    auto nd_range = pd()->conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);
    return status;
}

status_t ref_convolution_bwd_data_t::pd_t::init_conf(engine_t *engine) {
    CHECK(init_conf_common(conf, this, engine));
    return status::success;
}

status_t ref_convolution_bwd_data_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_);
}

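// Backward-data execution mirrors the forward path but writes diff_src and
// reads diff_dst; scale, post-op, and zero-point arguments are bound the same
// way.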
status_t ref_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
    auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
    CHECK(status);
    auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
    auto &src_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
    auto &wei_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
    auto &dst_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
    auto &src_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
    auto &dst_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);

    auto &conf = pd()->conf;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, diff_src);
    arg_list.set(1, weights);
    arg_list.set(2, diff_dst);
    arg_list.set(3, bias);

    unsigned arg_idx = append_post_ops_to_arg_list(
            ctx, arg_list, 4, pd()->attr()->post_ops_);

    arg_list.set(arg_idx++, src_scales);
    arg_list.set(arg_idx++, wei_scales);
    arg_list.set(arg_idx++, dst_scales);

    if (conf.attr_info.with_src_zpoints)
        arg_list.set(arg_idx++, src_zpoints);
    else
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());

    if (conf.attr_info.with_dst_zpoints)
        arg_list.set(arg_idx++, dst_zpoints);
    else
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());

    auto nd_range = pd()->conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}

status_t ref_convolution_bwd_weights_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}

status_t ref_convolution_bwd_weights_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_);
}

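// Backward-weights execution: binds src, diff_weights, diff_bias, and
// diff_dst, then launches the reference kernel over the weights-based
// dispatch.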
status_t ref_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &diff_weights = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_WEIGHTS, status);
    CHECK(status);
    auto &diff_bias = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_BIAS, status);
    CHECK(status);

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, diff_weights);
    arg_list.set(2, diff_bias);
    arg_list.set(3, diff_dst);

    auto nd_range = pd()->conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl