/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_REF_DECONVOLUTION_HPP
#define GPU_OCL_REF_DECONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_convolution_pd.hpp"
#include "gpu/gpu_deconvolution_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

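// Deconvolution is implemented in this file by wrapping the existing
// convolution primitives: each ref_deconvolution_* primitive below creates a
// nested convolution with a swapped propagation kind and permuted weights,
// and forwards execution to it. The two helpers that follow permute the
// weights axes and build the matching convolution descriptor.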
static status_t weights_axes_permutation(
        memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) {
    int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation
    for (int d = 0; d < DNNL_MAX_NDIMS; ++d)
        perm[d] = d;
    nstl::swap(perm[0 + with_groups], perm[1 + with_groups]);

    return memory_desc_permute_axes(*o_md, *i_md, perm);
}
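
// Illustrative sketch (not part of the implementation): with groups, the
// swap above exchanges dims 1 and 2, so grouped deconvolution weights of
// logical shape (g, oc, ic, kh, kw) map onto convolution weights of shape
// (g, ic, oc, kh, kw). Here deconv_wei_md stands for the deconvolution
// weights descriptor:
//
//     memory_desc_t conv_wei_md;
//     CHECK(weights_axes_permutation(
//             &conv_wei_md, deconv_wei_md, /*with_groups=*/true));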

static status_t conv_descr_create(
        const deconvolution_desc_t *dd, convolution_desc_t *cd) {
    using namespace prop_kind;
    alg_kind_t alg_kind = alg_kind::convolution_direct;

    const memory_desc_t *src_md, *dst_md, *d_weights_d;
    prop_kind_t prop_kind;

    switch (dd->prop_kind) {
        case forward:
        case forward_inference:
            prop_kind = backward_data;
            src_md = &dd->dst_desc;
            dst_md = &dd->src_desc;
            d_weights_d = &dd->weights_desc;
            break;
        case backward_data:
            prop_kind = forward_training;
            src_md = &dd->diff_dst_desc;
            dst_md = &dd->diff_src_desc;
            d_weights_d = &dd->weights_desc;
            break;
        case backward_weights:
            prop_kind = dd->prop_kind;
            src_md = &dd->diff_dst_desc;
            dst_md = &dd->src_desc;
            d_weights_d = &dd->diff_weights_desc;
            break;
        default: assert(!"unknown prop kind"); return status::invalid_arguments;
    }

    // Create weights desc for convolution
    memory_desc_t c_weights_d;
    const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
    CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups));

    return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
            prop_kind != backward_weights ? &dd->bias_desc : nullptr, dst_md,
            dd->strides, dd->dilates, dd->padding[0], dd->padding[1]);
}
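
// Summary of the deconvolution-to-convolution mapping performed above:
//
//   deconv prop_kind      conv prop_kind     conv "src"    conv "dst"
//   -------------------   ----------------   -----------   -----------
//   forward[_inference]   backward_data      dst_desc      src_desc
//   backward_data         forward_training   diff_dst      diff_src
//   backward_weights      backward_weights   diff_dst      src_desc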

struct ref_deconvolution_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_deconvolution_fwd_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : gpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other) = default;

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);

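        // Set up the convolution primitive descriptor that does the actual
        // work; the first increment of the iterator positions it at the
        // first implementation that accepts this descriptor/attributes
        // combination.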
        status_t init_convolution(engine_t *engine) {
            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;
            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            conv_pd_ = *(++it);
            return (conv_pd_) ? status::success : status::unimplemented;
        }

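        // Accepted data type combinations, as checked below:
        //   * f32 src/weights/dst
        //   * f16, f32, or bf16 src+weights with f16, u8, or s8 dst
        //   * bf16 src+weights with f32 or bf16 dst
        //   * u8 or s8 src with s8 weights and any non-f64 dst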
        status_t init(engine_t *engine) {
            using namespace format_tag;
            using sm = primitive_attr_t::skip_mask_t;

            const auto attr_skip_mask = sm::post_ops | sm::zero_points_runtime
                    | sm::scales_runtime;

            bool ok = is_fwd()
                    && desc()->alg_kind == alg_kind::deconvolution_direct
                    && attr()->has_default_values(attr_skip_mask)
                    && post_ops_with_binary_ok(
                            attr(), desc()->dst_desc.data_type, ndims())
                    && (utils::everyone_is(data_type::f32,
                                desc()->src_desc.data_type,
                                desc()->weights_desc.data_type,
                                desc()->dst_desc.data_type)
                            || ((utils::everyone_is(data_type::f16,
                                         desc()->src_desc.data_type,
                                         desc()->weights_desc.data_type)
                                        || utils::everyone_is(data_type::f32,
                                                desc()->src_desc.data_type,
                                                desc()->weights_desc.data_type)
                                        || utils::everyone_is(data_type::bf16,
                                                desc()->src_desc.data_type,
                                                desc()->weights_desc.data_type))
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f16, data_type::u8,
                                            data_type::s8))
                            || (utils::everyone_is(data_type::bf16,
                                        desc()->src_desc.data_type,
                                        desc()->weights_desc.data_type)
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f32, data_type::bf16))
                            || (desc()->weights_desc.data_type == data_type::s8
                                    && utils::one_of(desc()->src_desc.data_type,
                                            data_type::u8, data_type::s8)
                                    && desc()->dst_desc.data_type
                                            != data_type::f64));
            if (ok) {
                CHECK(init_convolution(engine));
                if (weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&weights_md_,
                            conv_pd_->weights_md(), with_groups()));
                if (src_md_.format_kind == format_kind::any)
                    src_md_ = *conv_pd_->diff_dst_md();
                if (dst_md_.format_kind == format_kind::any)
                    dst_md_ = *conv_pd_->diff_src_md();
                if (bias_md_.format_kind == format_kind::any)
                    CHECK(memory_desc_init_by_tag(bias_md_, x));
                init_scratchpad();
                CHECK(attr_.set_default_formats(dst_md(0)));

                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_;

    private:
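        // Book the nested convolution's scratchpad so a single grantor can
        // serve both the deconvolution and the underlying convolution at
        // execution time.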
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    status_t init(engine_t *engine) override {
        return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        const auto &args = ctx.args();
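        // Forward deconvolution runs as convolution backward-by-data, so the
        // deconvolution src acts as the convolution diff_dst and the
        // deconvolution dst as the convolution diff_src.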
        exec_args_t conv_args;
        conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
        conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
        conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST);
        if (pd()->with_bias())
            conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

        for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) {
            if (pd()->attr()->post_ops_.entry_[idx].is_binary()) {
                conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1]
                        = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | DNNL_ARG_SRC_1);
            } else if (pd()->attr()->post_ops_.entry_[idx].is_prelu()) {
                conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                        | DNNL_ARG_WEIGHTS]
                        = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | DNNL_ARG_WEIGHTS);
            }
        }
        const auto z_src = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC;
        const auto z_dst = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST;
        if (args.find(z_src) != args.end()) conv_args[z_src] = args.at(z_src);
        if (args.find(z_dst) != args.end()) conv_args[z_dst] = args.at(z_dst);

        for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
            int key = DNNL_ARG_ATTR_SCALES | arg;
            if (args.find(key) != args.end()) conv_args[key] = args.at(key);
        }

        exec_ctx_t conv_ctx(ctx, std::move(conv_args));

        nested_scratchpad_t ns(ctx, key_nested, conv_p_);
        conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Execute the nested convolution primitive
        return conv_p_->execute(conv_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_;
};

struct ref_deconvolution_bwd_data_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_deconvolution_bwd_data_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : gpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr) {}

        pd_t(const pd_t &other) = default;

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);

        status_t init_convolution(engine_t *engine) {
            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;
            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            conv_pd_ = *(++it);
            // Guard against an exhausted iterator: a null conv_pd_ would be
            // dereferenced later by name() and init().
            return (conv_pd_) ? status::success : status::unimplemented;
        }

        status_t init(engine_t *engine) {
            bool ok = desc()->prop_kind == prop_kind::backward_data
                    && (utils::everyone_is(data_type::f32,
                                desc()->diff_src_desc.data_type,
                                desc()->weights_desc.data_type,
                                desc()->diff_dst_desc.data_type)
                            || utils::everyone_is(data_type::bf16,
                                    desc()->weights_desc.data_type,
                                    desc()->diff_dst_desc.data_type))
                    && utils::one_of(desc()->diff_src_desc.data_type,
                            data_type::bf16, data_type::f32)
                    && desc()->alg_kind == alg_kind::deconvolution_direct
                    && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution(engine));
                if (weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&weights_md_,
                            conv_pd_->weights_md(), with_groups()));
                if (diff_src_md_.format_kind == format_kind::any)
                    diff_src_md_ = *conv_pd_->dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                init_scratchpad();

                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_;

    private:
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    status_t init(engine_t *engine) override {
        return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        const auto &args = ctx.args();
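        // Backward-by-data deconvolution runs as forward convolution: the
        // deconvolution diff_dst is the convolution src and the deconvolution
        // diff_src is the convolution dst.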
        exec_args_t conv_args;
        conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
        conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
        conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
        if (!types::is_zero_md(pd()->scratchpad_md()))
            conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD);
        exec_ctx_t conv_ctx(ctx, std::move(conv_args));

        nested_scratchpad_t ns(ctx, key_nested, conv_p_);
        conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Execute the nested convolution primitive
        return conv_p_->execute(conv_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_;
};

struct ref_deconvolution_bwd_weights_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_deconvolution_bwd_weights_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : gpu_deconvolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other) = default;

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t);

        status_t init_convolution(engine_t *engine) {
            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;
            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            conv_pd_ = *(++it);
            // Guard against an exhausted iterator, as in the other pds
            return (conv_pd_) ? status::success : status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace format_tag;
            bool ok = desc()->prop_kind == prop_kind::backward_weights
                    && (utils::everyone_is(data_type::f32,
                                desc()->src_desc.data_type,
                                desc()->diff_weights_desc.data_type,
                                desc()->diff_dst_desc.data_type)
                            || utils::everyone_is(data_type::bf16,
                                    desc()->diff_dst_desc.data_type,
                                    desc()->src_desc.data_type))
                    && utils::one_of(
                            desc()->alg_kind, alg_kind::deconvolution_direct)
                    && attr()->has_default_values()
                    && utils::one_of(desc()->diff_weights_desc.data_type,
                            data_type::bf16, data_type::f32);
            if (ok) {
                CHECK(init_convolution(engine));
                if (diff_weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&diff_weights_md_,
                            conv_pd_->diff_weights_md(), with_groups()));
                if (src_md_.format_kind == format_kind::any)
                    src_md_ = *conv_pd_->diff_dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                if (diff_bias_md_.format_kind == format_kind::any)
                    CHECK(memory_desc_init_by_tag(diff_bias_md_, x));
                init_scratchpad();

                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_;

    private:
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    status_t init(engine_t *engine) override {
        // Create the nested convolution primitive
        CHECK(create_nested_primitive(conv_p_, pd()->conv_pd_, engine));

        if (!pd()->with_bias()) return status::success;
        // Initialize values for the deconvolution bias kernel
        compute::kernel_ctx_t kernel_ctx;

        memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md());
        kernel_ctx.set_data_type(pd()->diff_dst_md()->data_type);
        offsets_t off;
        set_offsets(diff_dst_mdw, off.dst_off);
        def_offsets(off.dst_off, kernel_ctx, "DST",
                pd()->desc()->diff_dst_desc.ndims);

        kernel_ctx.define_int("MB", pd()->MB());
        kernel_ctx.define_int("OH", pd()->OH());
        kernel_ctx.define_int("OW", pd()->OW());
        kernel_ctx.define_int("OD", pd()->OD());
        kernel_ctx.define_int("OC", pd()->OC() / pd()->G());
        kernel_ctx.define_int("NDIMS", pd()->desc()->src_desc.ndims);

        gws[0] = pd()->OC();
        gws[1] = 1;
        gws[2] = 1;

        dst_data_type = pd()->diff_dst_md()->data_type;
        bias_data_type = pd()->diff_weights_md(1)->data_type;
        accum_data_type = pd()->desc()->accum_data_type;

        def_data_type(kernel_ctx, dst_data_type, "DST");
        def_data_type(kernel_ctx, bias_data_type, "BIA");
        def_data_type(kernel_ctx, accum_data_type, "ACC");

        // Propagate kernel-creation failures instead of discarding them
        CHECK(create_kernel(engine, &bias_kernel_, "ref_deconv_backward_bias",
                kernel_ctx));
        if (!bias_kernel_) return status::runtime_error;

        return status::success;
    }

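    // A sketch of what the ref_deconv_backward_bias kernel computes
    // (host-side pseudocode, not the actual OpenCL source; off() stands in
    // for the DST offset macros defined via def_offsets above). One
    // work-item per output channel (gws = {OC, 1, 1}) reduces diff_dst over
    // the minibatch and spatial dimensions:
    //
    //     for (dim_t oc = 0; oc < OC_total; ++oc) { // OC_total = pd()->OC()
    //         float acc = 0; // accumulated in the ACC type in the kernel
    //         for (dim_t mb = 0; mb < MB; ++mb)
    //         for (dim_t od = 0; od < OD; ++od)
    //         for (dim_t oh = 0; oh < OH; ++oh)
    //         for (dim_t ow = 0; ow < OW; ++ow)
    //             acc += diff_dst[off(mb, oc, od, oh, ow)];
    //         diff_bias[oc] = acc;
    //     }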
    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;

        const auto &args = ctx.args();
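        // Backward-by-weights maps onto convolution backward-by-weights with
        // src and diff_dst swapped, matching conv_descr_create() above.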
        exec_args_t conv_args;
        conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
        conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
        conv_args[DNNL_ARG_DIFF_WEIGHTS] = args.at(DNNL_ARG_DIFF_WEIGHTS);
        if (!types::is_zero_md(pd()->scratchpad_md()))
            conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD);
        exec_ctx_t conv_ctx(ctx, std::move(conv_args));

        nested_scratchpad_t ns(ctx, key_nested, conv_p_);
        conv_ctx.set_scratchpad_grantor(ns.grantor());

        status_t status = conv_p_->execute(conv_ctx);
        if (status != status::success) return status;

        if (pd()->with_bias()) {
            // Launch the bias reduction kernel when bias is present
            auto &diff_bias = CTX_OUT_STORAGE(DNNL_ARG_DIFF_BIAS);
            auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);

            compute::kernel_arg_list_t arg_list;
            arg_list.set(0, diff_dst);
            arg_list.set(1, diff_bias);

            // Global work size is {OC, 1, 1}, with OC including groups
            auto nd_range = compute::nd_range_t({gws[0], gws[1], gws[2]});
            status = parallel_for(ctx, nd_range, bias_kernel_, arg_list);
        }
        // Return the bias kernel status instead of unconditional success
        return status;
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_;
    compute::kernel_t bias_kernel_;
    size_t gws[3];
    data_type_t dst_data_type = data_type::undef;
    data_type_t bias_data_type = data_type::undef;
    data_type_t accum_data_type = data_type::undef;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif