1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_convolution.hpp" |
18 | |
19 | #include "gpu/primitive_conf.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace ocl { |
25 | |
26 | static status_t init_conf_common( |
27 | conv_conf_t &conf, const convolution_pd_t *pd, engine_t *engine) { |
28 | const convolution_desc_t &cd = *pd->desc(); |
29 | const memory_desc_t &src_md = *pd->invariant_src_md(); |
30 | const memory_desc_t &weights_md = *pd->invariant_wei_md(); |
31 | const memory_desc_t &dst_md = *pd->invariant_dst_md(); |
32 | const memory_desc_t &bias_md = *pd->invariant_bia_md(); |
33 | const primitive_attr_t &attr = *pd->attr(); |
34 | |
35 | set_default_conf(conf, cd, src_md, weights_md, dst_md, bias_md, attr); |
36 | |
37 | int oc_idx = (int)conf.with_groups; |
38 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
39 | switch (cd.prop_kind) { |
40 | case prop_kind::forward_training: |
41 | case prop_kind::forward_inference: { |
42 | conf.dispatch = compute_engine->create_dispatch(&dst_md); |
43 | conf.dispatch.define_dim("MB" , 0, conf.mb); |
44 | conf.dispatch.define_dim("G" , 1, conf.ngroups); |
45 | conf.dispatch.define_dim("OC" , 1, conf.oc); |
46 | conf.dispatch.define_dim( |
47 | "OD" , nstl::max(2, conf.ndims - 3), conf.od); |
48 | conf.dispatch.define_dim( |
49 | "OH" , nstl::max(2, conf.ndims - 2), conf.oh); |
50 | conf.dispatch.define_dim( |
51 | "OW" , nstl::max(2, conf.ndims - 1), conf.ow); |
52 | conf.dispatch.generate(); |
53 | break; |
54 | } |
55 | case prop_kind::backward_data: { |
56 | conf.dispatch = compute_engine->create_dispatch(&src_md); |
57 | conf.dispatch.define_dim_with_nesting_level( |
58 | "IC" , conf.ndims, conf.ic); |
59 | conf.dispatch.define_dim("MB" , conf.mb); |
60 | conf.dispatch.define_dim("G" , conf.ngroups); |
61 | conf.dispatch.define_dim( |
62 | "ID" , nstl::max(2, conf.ndims - 3), conf.id); |
63 | conf.dispatch.define_dim( |
64 | "IH" , nstl::max(2, conf.ndims - 2), conf.ih); |
65 | conf.dispatch.define_dim( |
66 | "IW" , nstl::max(2, conf.ndims - 1), conf.iw); |
67 | conf.dispatch.generate(); |
68 | break; |
69 | } |
70 | case prop_kind::backward_weights: { |
71 | conf.dispatch = compute_engine->create_dispatch(&weights_md); |
72 | conf.dispatch.define_dim("G" , 0, conf.ngroups); |
73 | conf.dispatch.define_dim("OC" , oc_idx, conf.oc); |
74 | conf.dispatch.define_dim("IC" , oc_idx + 1, conf.ic); |
75 | conf.dispatch.define_dim( |
76 | "KD" , oc_idx + nstl::max(2, conf.ndims - 3), conf.kd); |
77 | conf.dispatch.define_dim( |
78 | "KH" , oc_idx + nstl::max(2, conf.ndims - 2), conf.kh); |
79 | conf.dispatch.define_dim( |
80 | "KW" , oc_idx + nstl::max(2, conf.ndims - 1), conf.kw); |
81 | conf.dispatch.generate(); |
82 | break; |
83 | } |
84 | default: break; |
85 | } |
86 | |
87 | return status::success; |
88 | } |
89 | |
90 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
91 | const conv_conf_t &conf, const post_ops_t &post_ops) { |
92 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
93 | kernel_ctx.define_int("G" , conf.ngroups); |
94 | kernel_ctx.define_int("WITH_GROUPS" , conf.with_groups); |
95 | kernel_ctx.define_int("MB" , conf.mb); |
96 | kernel_ctx.define_int("IC" , conf.ic); |
97 | kernel_ctx.define_int("ID" , conf.id); |
98 | kernel_ctx.define_int("IH" , conf.ih); |
99 | kernel_ctx.define_int("IW" , conf.iw); |
100 | kernel_ctx.define_int("OC" , conf.oc); |
101 | kernel_ctx.define_int("OD" , conf.od); |
102 | kernel_ctx.define_int("OH" , conf.oh); |
103 | kernel_ctx.define_int("OW" , conf.ow); |
104 | kernel_ctx.define_int("KD" , conf.kd); |
105 | kernel_ctx.define_int("KH" , conf.kh); |
106 | kernel_ctx.define_int("KW" , conf.kw); |
107 | kernel_ctx.define_int("SD" , conf.stride_d); |
108 | kernel_ctx.define_int("SH" , conf.stride_h); |
109 | kernel_ctx.define_int("SW" , conf.stride_w); |
110 | kernel_ctx.define_int("PD" , conf.f_pad); |
111 | kernel_ctx.define_int("PH" , conf.t_pad); |
112 | kernel_ctx.define_int("PW" , conf.l_pad); |
113 | kernel_ctx.define_int("PD_R" , conf.back_pad); |
114 | kernel_ctx.define_int("PH_R" , conf.b_pad); |
115 | kernel_ctx.define_int("PW_R" , conf.r_pad); |
116 | kernel_ctx.define_int("DD" , conf.dilate_d); |
117 | kernel_ctx.define_int("DH" , conf.dilate_h); |
118 | kernel_ctx.define_int("DW" , conf.dilate_w); |
119 | kernel_ctx.define_int("WITH_BIAS" , conf.with_bias); |
120 | kernel_ctx.define_int("SUB_GROUP_SIZE" , conf.sub_group_size); |
121 | |
122 | kernel_ctx.define_int("IS_FWD" , |
123 | utils::one_of(conf.prop_kind, prop_kind::forward_inference, |
124 | prop_kind::forward_training)); |
125 | kernel_ctx.define_int( |
126 | "IS_BWD_D" , conf.prop_kind == prop_kind::backward_data); |
127 | kernel_ctx.define_int( |
128 | "IS_BWD_W" , conf.prop_kind == prop_kind::backward_weights); |
129 | |
130 | def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC" ); |
131 | def_memory_desc_info(kernel_ctx, conf.wei_md_info, "WEI" ); |
132 | def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST" ); |
133 | |
134 | def_dispatch(kernel_ctx, conf.dispatch); |
135 | |
136 | switch (conf.prop_kind) { |
137 | case prop_kind::forward_training: |
138 | case prop_kind::forward_inference: |
139 | kernel_ctx.set_data_type(conf.dst_data_type); |
140 | break; |
141 | case prop_kind::backward_data: |
142 | kernel_ctx.set_data_type(conf.src_data_type); |
143 | break; |
144 | case prop_kind::backward_weights: |
145 | kernel_ctx.set_data_type(conf.weights_data_type); |
146 | break; |
147 | default: break; |
148 | } |
149 | |
150 | def_data_type(kernel_ctx, conf.src_data_type, "SRC" ); |
151 | def_data_type(kernel_ctx, conf.weights_data_type, "WEI" ); |
152 | def_data_type(kernel_ctx, conf.bias_data_type, "BIA" ); |
153 | def_data_type(kernel_ctx, conf.dst_data_type, "DST" ); |
154 | def_data_type(kernel_ctx, conf.acc_data_type, "ACC" ); |
155 | def_data_type(kernel_ctx, |
156 | conf.attr_info.sum_data_type == dnnl_data_type_undef |
157 | ? conf.dst_data_type |
158 | : conf.attr_info.sum_data_type, |
159 | "SUM" ); |
160 | |
161 | dims_t dst_dims {}; |
162 | |
163 | for (int d = 0; d < MAX_NDIMS; d++) { |
164 | if (d < conf.ndims) |
165 | dst_dims[d] = (conf.prop_kind & dnnl_backward) |
166 | ? conf.src_md_info.dims[d] |
167 | : conf.dst_md_info.dims[d]; |
168 | else |
169 | dst_dims[d] = 1; |
170 | } |
171 | def_attr_info(kernel_ctx, conf.attr_info, post_ops, &dst_dims); |
172 | return status::success; |
173 | } |
174 | |
175 | status_t ref_convolution_fwd_t::pd_t::init_conf(engine_t *engine) { |
176 | CHECK(init_conf_common(conf, this, engine)); |
177 | return status::success; |
178 | } |
179 | |
180 | status_t ref_convolution_fwd_t::pd_t::init_kernel_ctx( |
181 | compute::kernel_ctx_t &kernel_ctx) const { |
182 | return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_); |
183 | } |
184 | |
185 | status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
186 | |
187 | status_t status = status::success; |
188 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
189 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
190 | auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
191 | auto &dst = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status); |
192 | CHECK(status); |
193 | auto &src_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
194 | auto &wei_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); |
195 | auto &dst_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
196 | auto &src_zpoints |
197 | = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); |
198 | auto &dst_zpoints |
199 | = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); |
200 | |
201 | auto &conf = pd()->conf; |
202 | |
203 | compute::kernel_arg_list_t arg_list; |
204 | arg_list.set(0, src); |
205 | arg_list.set(1, weights); |
206 | arg_list.set(2, bias); |
207 | arg_list.set(3, dst); |
208 | |
209 | unsigned arg_idx = append_post_ops_to_arg_list( |
210 | ctx, arg_list, 4, pd()->attr()->post_ops_); |
211 | |
212 | arg_list.set(arg_idx++, src_scales); |
213 | arg_list.set(arg_idx++, wei_scales); |
214 | arg_list.set(arg_idx++, dst_scales); |
215 | |
216 | if (conf.attr_info.with_src_zpoints) |
217 | arg_list.set(arg_idx++, src_zpoints); |
218 | else |
219 | arg_list.set(arg_idx++, memory_storage_t::empty_storage()); |
220 | |
221 | if (conf.attr_info.with_dst_zpoints) |
222 | arg_list.set(arg_idx++, dst_zpoints); |
223 | else |
224 | arg_list.set(arg_idx++, memory_storage_t::empty_storage()); |
225 | |
226 | auto nd_range = pd()->conf.dispatch.nd_range(); |
227 | |
228 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
229 | return status; |
230 | } |
231 | |
232 | status_t ref_convolution_bwd_data_t::pd_t::init_conf(engine_t *engine) { |
233 | CHECK(init_conf_common(conf, this, engine)); |
234 | return status::success; |
235 | } |
236 | |
237 | status_t ref_convolution_bwd_data_t::pd_t::init_kernel_ctx( |
238 | compute::kernel_ctx_t &kernel_ctx) const { |
239 | return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_); |
240 | } |
241 | |
242 | status_t ref_convolution_bwd_data_t::execute_backward_data( |
243 | const exec_ctx_t &ctx) const { |
244 | |
245 | status_t status = status::success; |
246 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
247 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
248 | auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status); |
249 | CHECK(status); |
250 | auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
251 | auto &src_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
252 | auto &wei_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); |
253 | auto &dst_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
254 | auto &src_zpoints |
255 | = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); |
256 | auto &dst_zpoints |
257 | = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); |
258 | |
259 | auto &conf = pd()->conf; |
260 | |
261 | compute::kernel_arg_list_t arg_list; |
262 | arg_list.set(0, diff_src); |
263 | arg_list.set(1, weights); |
264 | arg_list.set(2, diff_dst); |
265 | arg_list.set(3, bias); |
266 | |
267 | unsigned arg_idx = append_post_ops_to_arg_list( |
268 | ctx, arg_list, 4, pd()->attr()->post_ops_); |
269 | |
270 | arg_list.set(arg_idx++, src_scales); |
271 | arg_list.set(arg_idx++, wei_scales); |
272 | arg_list.set(arg_idx++, dst_scales); |
273 | |
274 | if (conf.attr_info.with_src_zpoints) |
275 | arg_list.set(arg_idx++, src_zpoints); |
276 | else |
277 | arg_list.set(arg_idx++, memory_storage_t::empty_storage()); |
278 | |
279 | if (conf.attr_info.with_dst_zpoints) |
280 | arg_list.set(arg_idx++, dst_zpoints); |
281 | else |
282 | arg_list.set(arg_idx++, memory_storage_t::empty_storage()); |
283 | |
284 | auto nd_range = pd()->conf.dispatch.nd_range(); |
285 | |
286 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
287 | |
288 | return status; |
289 | } |
290 | |
291 | status_t ref_convolution_bwd_weights_t::pd_t::init_conf(engine_t *engine) { |
292 | return init_conf_common(conf, this, engine); |
293 | } |
294 | |
295 | status_t ref_convolution_bwd_weights_t::pd_t::init_kernel_ctx( |
296 | compute::kernel_ctx_t &kernel_ctx) const { |
297 | return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_); |
298 | } |
299 | |
300 | status_t ref_convolution_bwd_weights_t::execute_backward_weights( |
301 | const exec_ctx_t &ctx) const { |
302 | |
303 | status_t status = status::success; |
304 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
305 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
306 | auto &diff_weights = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_WEIGHTS, status); |
307 | CHECK(status); |
308 | auto &diff_bias = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_BIAS, status); |
309 | CHECK(status); |
310 | |
311 | compute::kernel_arg_list_t arg_list; |
312 | arg_list.set(0, src); |
313 | arg_list.set(1, diff_weights); |
314 | arg_list.set(2, diff_bias); |
315 | arg_list.set(3, diff_dst); |
316 | |
317 | auto nd_range = pd()->conf.dispatch.nd_range(); |
318 | |
319 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
320 | |
321 | return status; |
322 | } |
323 | |
324 | } // namespace ocl |
325 | } // namespace gpu |
326 | } // namespace impl |
327 | } // namespace dnnl |
328 | |