1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_REF_DECONVOLUTION_HPP |
18 | #define GPU_OCL_REF_DECONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | #include "gpu/compute/compute.hpp" |
25 | #include "gpu/gpu_convolution_pd.hpp" |
26 | #include "gpu/gpu_deconvolution_pd.hpp" |
27 | #include "gpu/gpu_primitive.hpp" |
28 | #include "gpu/gpu_resource.hpp" |
29 | #include "gpu/ocl/ocl_stream.hpp" |
30 | #include "gpu/primitive_conf.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace gpu { |
35 | namespace ocl { |
36 | |
37 | static status_t weights_axes_permutation( |
38 | memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) { |
39 | int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation |
40 | for (int d = 0; d < DNNL_MAX_NDIMS; ++d) |
41 | perm[d] = d; |
42 | nstl::swap(perm[0 + with_groups], perm[1 + with_groups]); |
43 | |
44 | return memory_desc_permute_axes(*o_md, *i_md, perm); |
45 | } |
46 | |
47 | static status_t conv_descr_create( |
48 | const deconvolution_desc_t *dd, convolution_desc_t *cd) { |
49 | using namespace prop_kind; |
50 | alg_kind_t alg_kind = alg_kind::convolution_direct; |
51 | |
52 | const memory_desc_t *src_md, *dst_md, *d_weights_d; |
53 | prop_kind_t prop_kind; |
54 | |
55 | switch (dd->prop_kind) { |
56 | case forward: |
57 | case forward_inference: |
58 | prop_kind = backward_data; |
59 | src_md = &dd->dst_desc; |
60 | dst_md = &dd->src_desc; |
61 | d_weights_d = &dd->weights_desc; |
62 | break; |
63 | case backward_data: |
64 | prop_kind = forward_training; |
65 | src_md = &dd->diff_dst_desc; |
66 | dst_md = &dd->diff_src_desc; |
67 | d_weights_d = &dd->weights_desc; |
68 | break; |
69 | case backward_weights: |
70 | prop_kind = dd->prop_kind; |
71 | src_md = &dd->diff_dst_desc; |
72 | dst_md = &dd->src_desc; |
73 | d_weights_d = &dd->diff_weights_desc; |
74 | break; |
75 | default: assert(!"unknown prop kind" ); return status::invalid_arguments; |
76 | } |
77 | |
78 | // Create weights desc for convolution |
79 | memory_desc_t c_weights_d; |
80 | const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; |
81 | CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups)); |
82 | |
83 | return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, |
84 | prop_kind != backward_weights ? &dd->bias_desc : nullptr, dst_md, |
85 | dd->strides, dd->dilates, dd->padding[0], dd->padding[1]); |
86 | } |
87 | |
// Reference GPU deconvolution forward primitive: delegates all computation to
// a nested convolution backward-data primitive (see conv_descr_create()).
struct ref_deconvolution_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_deconvolution_fwd_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : gpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other) = default;

        ~pd_t() = default;

        // The implementation name is forwarded from the nested convolution pd
        // (valid only after init_convolution() succeeded).
        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);

        // Creates the pd of the conv bwd_data primitive that implements this
        // deconvolution; takes the first implementation the iterator yields.
        status_t init_convolution(engine_t *engine) {
            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;
            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            conv_pd_ = *(++it);
            // Guard against an exhausted iterator (no conv implementation).
            return (conv_pd_) ? status::success : status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            using sm = primitive_attr_t::skip_mask_t;

            // Attributes supported by this implementation: post-ops, runtime
            // zero points and runtime scales.
            const auto attr_skip_mask = sm::post_ops | sm::zero_points_runtime
                    | sm::scales_runtime;

            // Supported data-type combinations (src/wei -> dst):
            //   f32/f32 -> f32; {f16,f32,bf16}/{same} -> {f16,u8,s8};
            //   bf16/bf16 -> {f32,bf16}; {u8,s8}/s8 -> anything but f64.
            // NOTE(review): the grouping of '&&'/'||' below is intricate;
            // keep parentheses intact when modifying.
            bool ok = is_fwd()
                    && desc()->alg_kind == alg_kind::deconvolution_direct
                    && attr()->has_default_values(attr_skip_mask)
                    && post_ops_with_binary_ok(
                            attr(), desc()->dst_desc.data_type, ndims())
                    && (utils::everyone_is(data_type::f32,
                                desc()->src_desc.data_type,
                                desc()->weights_desc.data_type,
                                desc()->dst_desc.data_type)
                            || ((utils::everyone_is(data_type::f16,
                                         desc()->src_desc.data_type,
                                         desc()->weights_desc.data_type)
                                        || utils::everyone_is(data_type::f32,
                                                desc()->src_desc.data_type,
                                                desc()->weights_desc.data_type)
                                        || utils::everyone_is(data_type::bf16,
                                                desc()->src_desc.data_type,
                                                desc()->weights_desc.data_type))
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f16, data_type::u8,
                                            data_type::s8))
                            || (utils::everyone_is(data_type::bf16,
                                        desc()->src_desc.data_type,
                                        desc()->weights_desc.data_type)
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f32, data_type::bf16))
                            || (desc()->weights_desc.data_type == data_type::s8
                                    && utils::one_of(desc()->src_desc.data_type,
                                            data_type::u8, data_type::s8)
                                    && desc()->dst_desc.data_type
                                            != data_type::f64));
            if (ok) {
                CHECK(init_convolution(engine));
                // Resolve "any" formats from the nested convolution. Note the
                // cross mapping: deconv src/dst correspond to the conv's
                // diff_dst/diff_src respectively.
                if (weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&weights_md_,
                            conv_pd_->weights_md(), with_groups()));
                if (src_md_.format_kind == format_kind::any)
                    src_md_ = *conv_pd_->diff_dst_md();
                if (dst_md_.format_kind == format_kind::any)
                    dst_md_ = *conv_pd_->diff_src_md();
                if (bias_md_.format_kind == format_kind::any)
                    CHECK(memory_desc_init_by_tag(bias_md_, x));
                init_scratchpad();
                CHECK(attr_.set_default_formats(dst_md(0)));

                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_;

    private:
        // Books the nested convolution's scratchpad under key_nested.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    status_t init(engine_t *engine) override {
        return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        const auto &args = ctx.args();
        exec_args_t conv_args;
        // Remap deconv args to the conv bwd_data argument slots:
        // SRC -> DIFF_DST, DST -> DIFF_SRC; weights/bias pass through.
        conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
        conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
        conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST);
        if (pd()->with_bias())
            conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

        // Forward binary/prelu post-op source tensors to the convolution.
        for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) {
            if (pd()->attr()->post_ops_.entry_[idx].is_binary()) {
                conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1]
                        = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | DNNL_ARG_SRC_1);
            } else if (pd()->attr()->post_ops_.entry_[idx].is_prelu()) {
                conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                        | DNNL_ARG_WEIGHTS]
                        = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | DNNL_ARG_WEIGHTS);
            }
        }
        // Forward runtime zero points when the caller supplied them.
        const auto z_src = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC;
        const auto z_dst = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST;
        if (args.find(z_src) != args.end()) conv_args[z_src] = args.at(z_src);
        if (args.find(z_dst) != args.end()) conv_args[z_dst] = args.at(z_dst);

        // Forward runtime scales when the caller supplied them.
        for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
            int key = DNNL_ARG_ATTR_SCALES | arg;
            if (args.find(key) != args.end()) conv_args[key] = args.at(key);
        }

        exec_ctx_t conv_ctx(ctx, std::move(conv_args));

        nested_scratchpad_t ns(ctx, key_nested, conv_p_);
        conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Executing the convolution kernel
        return conv_p_->execute(conv_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_;
};
229 | |
230 | struct ref_deconvolution_bwd_data_t : public gpu_primitive_t { |
231 | using gpu_primitive_t::gpu_primitive_t; |
232 | struct pd_t : public gpu_deconvolution_bwd_data_pd_t { |
233 | pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, |
234 | const deconvolution_fwd_pd_t *hint_fwd_pd) |
235 | : gpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) |
236 | , conv_pd_(nullptr) {} |
237 | |
238 | pd_t(const pd_t &other) = default; |
239 | |
240 | ~pd_t() = default; |
241 | |
242 | DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t); |
243 | |
244 | status_t init_convolution(engine_t *engine) { |
245 | convolution_desc_t cd; |
246 | CHECK(conv_descr_create(desc(), &cd)); |
247 | primitive_attr_t conv_attr(*attr()); |
248 | if (!conv_attr.is_initialized()) return status::out_of_memory; |
249 | primitive_desc_iterator_t it( |
250 | engine, (op_desc_t *)&cd, &conv_attr, nullptr); |
251 | if (!it.is_initialized()) return status::out_of_memory; |
252 | conv_pd_ = *(++it); |
253 | return status::success; |
254 | } |
255 | |
256 | status_t init(engine_t *engine) { |
257 | bool ok = desc()->prop_kind == prop_kind::backward_data |
258 | && (utils::everyone_is(data_type::f32, |
259 | desc()->diff_src_desc.data_type, |
260 | desc()->weights_desc.data_type, |
261 | desc()->diff_dst_desc.data_type) |
262 | || utils::everyone_is(data_type::bf16, |
263 | desc()->weights_desc.data_type, |
264 | desc()->diff_dst_desc.data_type)) |
265 | && utils::one_of(desc()->diff_src_desc.data_type, |
266 | data_type::bf16, data_type::f32) |
267 | && desc()->alg_kind == alg_kind::deconvolution_direct |
268 | && attr()->has_default_values(); |
269 | |
270 | if (ok) { |
271 | CHECK(init_convolution(engine)); |
272 | if (weights_md_.format_kind == format_kind::any) |
273 | CHECK(weights_axes_permutation(&weights_md_, |
274 | conv_pd_->weights_md(), with_groups())); |
275 | if (diff_src_md_.format_kind == format_kind::any) |
276 | diff_src_md_ = *conv_pd_->dst_md(); |
277 | if (diff_dst_md_.format_kind == format_kind::any) |
278 | diff_dst_md_ = *conv_pd_->src_md(); |
279 | init_scratchpad(); |
280 | |
281 | return status::success; |
282 | } |
283 | |
284 | return status::unimplemented; |
285 | } |
286 | |
287 | std::shared_ptr<primitive_desc_t> conv_pd_; |
288 | |
289 | private: |
290 | void init_scratchpad() { |
291 | auto scratchpad = scratchpad_registry().registrar(); |
292 | scratchpad.book(memory_tracking::names::key_nested, |
293 | conv_pd_->scratchpad_registry()); |
294 | } |
295 | }; |
296 | |
297 | status_t init(engine_t *engine) override { |
298 | return create_nested_primitive(conv_p_, pd()->conv_pd_, engine); |
299 | } |
300 | |
301 | status_t execute(const exec_ctx_t &ctx) const override { |
302 | using namespace memory_tracking::names; |
303 | const auto &args = ctx.args(); |
304 | exec_args_t conv_args; |
305 | conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST); |
306 | conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); |
307 | conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC); |
308 | if (!types::is_zero_md(pd()->scratchpad_md())) |
309 | conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD); |
310 | exec_ctx_t conv_ctx(ctx, std::move(conv_args)); |
311 | |
312 | nested_scratchpad_t ns(ctx, key_nested, conv_p_); |
313 | conv_ctx.set_scratchpad_grantor(ns.grantor()); |
314 | // Executing the convolution kernel |
315 | return conv_p_->execute(conv_ctx); |
316 | } |
317 | |
318 | private: |
319 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
320 | std::shared_ptr<primitive_t> conv_p_; |
321 | }; |
322 | |
323 | struct ref_deconvolution_bwd_weights_t : public gpu_primitive_t { |
324 | using gpu_primitive_t::gpu_primitive_t; |
325 | struct pd_t : public gpu_deconvolution_bwd_weights_pd_t { |
326 | pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, |
327 | const deconvolution_fwd_pd_t *hint_fwd_pd) |
328 | : gpu_deconvolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {} |
329 | |
330 | pd_t(const pd_t &other) = default; |
331 | |
332 | ~pd_t() = default; |
333 | |
334 | DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t); |
335 | |
336 | status_t init_convolution(engine_t *engine) { |
337 | convolution_desc_t cd; |
338 | CHECK(conv_descr_create(desc(), &cd)); |
339 | primitive_attr_t conv_attr(*attr()); |
340 | if (!conv_attr.is_initialized()) return status::out_of_memory; |
341 | primitive_desc_iterator_t it( |
342 | engine, (op_desc_t *)&cd, &conv_attr, nullptr); |
343 | if (!it.is_initialized()) return status::out_of_memory; |
344 | conv_pd_ = *(++it); |
345 | return status::success; |
346 | } |
347 | |
348 | status_t init(engine_t *engine) { |
349 | using namespace format_tag; |
350 | bool ok = desc()->prop_kind == prop_kind::backward_weights |
351 | && (utils::everyone_is(data_type::f32, |
352 | desc()->src_desc.data_type, |
353 | desc()->diff_weights_desc.data_type, |
354 | desc()->diff_dst_desc.data_type) |
355 | || utils::everyone_is(data_type::bf16, |
356 | desc()->diff_dst_desc.data_type, |
357 | desc()->src_desc.data_type)) |
358 | && utils::one_of( |
359 | desc()->alg_kind, alg_kind::deconvolution_direct) |
360 | && attr()->has_default_values() |
361 | && utils::one_of(desc()->diff_weights_desc.data_type, |
362 | data_type::bf16, data_type::f32); |
363 | if (ok) { |
364 | CHECK(init_convolution(engine)); |
365 | if (diff_weights_md_.format_kind == format_kind::any) |
366 | CHECK(weights_axes_permutation(&diff_weights_md_, |
367 | conv_pd_->diff_weights_md(), with_groups())); |
368 | if (src_md_.format_kind == format_kind::any) |
369 | src_md_ = *conv_pd_->diff_dst_md(); |
370 | if (diff_dst_md_.format_kind == format_kind::any) |
371 | diff_dst_md_ = *conv_pd_->src_md(); |
372 | if (diff_bias_md_.format_kind == format_kind::any) |
373 | CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); |
374 | init_scratchpad(); |
375 | |
376 | return status::success; |
377 | } |
378 | |
379 | return status::unimplemented; |
380 | } |
381 | |
382 | std::shared_ptr<primitive_desc_t> conv_pd_; |
383 | |
384 | private: |
385 | void init_scratchpad() { |
386 | auto scratchpad = scratchpad_registry().registrar(); |
387 | scratchpad.book(memory_tracking::names::key_nested, |
388 | conv_pd_->scratchpad_registry()); |
389 | } |
390 | }; |
391 | |
392 | status_t init(engine_t *engine) override { |
393 | // Creating convolution primitve |
394 | CHECK(create_nested_primitive(conv_p_, pd()->conv_pd_, engine)); |
395 | |
396 | if (!pd()->with_bias()) return status::success; |
397 | // Initializing values for the deconv bias kernel |
398 | compute::kernel_ctx_t kernel_ctx; |
399 | |
400 | memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md()); |
401 | kernel_ctx.set_data_type(pd()->diff_dst_md()->data_type); |
402 | offsets_t off; |
403 | set_offsets(diff_dst_mdw, off.dst_off); |
404 | def_offsets(off.dst_off, kernel_ctx, "DST" , |
405 | pd()->desc()->diff_dst_desc.ndims); |
406 | |
407 | kernel_ctx.define_int("MB" , pd()->MB()); |
408 | kernel_ctx.define_int("OH" , pd()->OH()); |
409 | kernel_ctx.define_int("OW" , pd()->OW()); |
410 | kernel_ctx.define_int("OD" , pd()->OD()); |
411 | kernel_ctx.define_int("OC" , pd()->OC() / pd()->G()); |
412 | kernel_ctx.define_int("NDIMS" , pd()->desc()->src_desc.ndims); |
413 | |
414 | gws[0] = pd()->OC(); |
415 | gws[1] = 1; |
416 | gws[2] = 1; |
417 | |
418 | dst_data_type = pd()->diff_dst_md()->data_type; |
419 | bias_data_type = pd()->diff_weights_md(1)->data_type; |
420 | accum_data_type = pd()->desc()->accum_data_type; |
421 | |
422 | def_data_type(kernel_ctx, dst_data_type, "DST" ); |
423 | def_data_type(kernel_ctx, bias_data_type, "BIA" ); |
424 | def_data_type(kernel_ctx, accum_data_type, "ACC" ); |
425 | |
426 | create_kernel( |
427 | engine, &bias_kernel_, "ref_deconv_backward_bias" , kernel_ctx); |
428 | if (!bias_kernel_) return status::runtime_error; |
429 | |
430 | return status::success; |
431 | } |
432 | |
433 | status_t execute(const exec_ctx_t &ctx) const override { |
434 | using namespace memory_tracking::names; |
435 | |
436 | const auto &args = ctx.args(); |
437 | exec_args_t conv_args; |
438 | conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC); |
439 | conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST); |
440 | conv_args[DNNL_ARG_DIFF_WEIGHTS] = args.at(DNNL_ARG_DIFF_WEIGHTS); |
441 | if (!types::is_zero_md(pd()->scratchpad_md())) |
442 | conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD); |
443 | exec_ctx_t conv_ctx(ctx, std::move(conv_args)); |
444 | |
445 | nested_scratchpad_t ns(ctx, key_nested, conv_p_); |
446 | conv_ctx.set_scratchpad_grantor(ns.grantor()); |
447 | |
448 | status_t status = conv_p_->execute(conv_ctx); |
449 | if (status != status::success) return status; |
450 | |
451 | if (pd()->with_bias()) { |
452 | // Calling the bias kernel if bias=1 |
453 | auto &diff_bias = CTX_OUT_STORAGE(DNNL_ARG_DIFF_BIAS); |
454 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
455 | |
456 | compute::kernel_arg_list_t arg_list; |
457 | arg_list.set(0, diff_dst); |
458 | arg_list.set(1, diff_bias); |
459 | |
460 | // Setting up global work-space to {OC*G, 1, 1} |
461 | auto nd_range = compute::nd_range_t({gws[0], gws[1], gws[2]}); |
462 | status = parallel_for(ctx, nd_range, bias_kernel_, arg_list); |
463 | } |
464 | return status::success; |
465 | } |
466 | |
467 | private: |
468 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
469 | std::shared_ptr<primitive_t> conv_p_; |
470 | compute::kernel_t bias_kernel_; |
471 | size_t gws[3]; |
472 | data_type_t dst_data_type = data_type::undef; |
473 | data_type_t bias_data_type = data_type::undef; |
474 | data_type_t accum_data_type = data_type::undef; |
475 | }; |
476 | |
477 | } // namespace ocl |
478 | } // namespace gpu |
479 | } // namespace impl |
480 | } // namespace dnnl |
481 | |
482 | #endif |
483 | |