1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ref_prelu.hpp"
18#include "gpu/ocl/ocl_utils.hpp"
19
20namespace dnnl {
21namespace impl {
22namespace gpu {
23namespace ocl {
24
25static status_t init_conf_common(
26 prelu_conf_t &conf, const prelu_pd_t *pd, engine_t *engine) {
27
28 conf.is_forward = pd->is_fwd();
29
30 const memory_desc_wrapper src_mdw(pd->src_md(0));
31 const memory_desc_wrapper wei_mdw(pd->weights_md(0));
32 const memory_desc_wrapper dst_mdw(
33 conf.is_forward ? pd->dst_md(0) : pd->diff_dst_md(0));
34
35 conf.src_md_info = memory_desc_info_t::create(src_mdw);
36 conf.wei_md_info = memory_desc_info_t::create(wei_mdw);
37 conf.dst_md_info = memory_desc_info_t::create(dst_mdw);
38 if (!conf.is_forward) {
39 const memory_desc_wrapper diff_src_mdw(pd->diff_src_md(0));
40 const memory_desc_wrapper diff_weights_mdw(pd->diff_weights_md(0));
41 conf.reduce_diff_weights
42 = src_mdw.nelems() != diff_weights_mdw.nelems();
43
44 conf.diff_src_md_info = memory_desc_info_t::create(diff_src_mdw);
45
46 if (conf.reduce_diff_weights) {
47 memory_desc_t red_diff_mem_desc(*pd->src_md(0));
48 red_diff_mem_desc.data_type = dnnl_f32;
49 const memory_desc_wrapper red_diff_mdw(red_diff_mem_desc);
50 conf.diff_wei_md_info = memory_desc_info_t::create(red_diff_mdw);
51 } else {
52 conf.diff_wei_md_info
53 = memory_desc_info_t::create(diff_weights_mdw);
54 }
55 }
56
57 const auto &ndims = dst_mdw.ndims();
58
59 const auto *compute_engine
60 = utils::downcast<compute::compute_engine_t *>(engine);
61 conf.dispatch = compute_engine->create_dispatch(src_mdw.md_);
62
63 for (int i = 0; i < MAX_NDIMS; ++i) {
64 if (i < ndims) {
65 const dnnl_dim_t diff_wei_dim = conf.is_forward
66 ? 1
67 : static_cast<dnnl_dim_t>(
68 conf.diff_wei_md_info.padded_dims[i]);
69 dnnl_dim_t dim2dispatch
70 = nstl::max(dst_mdw.padded_dims()[i], diff_wei_dim);
71 conf.dispatch.define_dim(utils::format("D%d", i), i, dim2dispatch);
72 } else
73 conf.dispatch.define_dim(utils::format("D%d", i), 1);
74 }
75 conf.dispatch.generate(false);
76
77 return status::success;
78};
79
80static status_t init_kernel_ctx_common(
81 compute::kernel_ctx_t &kernel_ctx, const prelu_conf_t &conf) {
82
83 kernel_ctx.set_data_type(conf.dst_md_info.data_type);
84 def_eltwise_alg_kinds(kernel_ctx);
85 kernel_ctx.define_int("WITH_ELTWISE", 1);
86
87 kernel_ctx.define_int("IS_FWD", conf.is_forward);
88
89 def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
90 def_memory_desc_info(kernel_ctx, conf.wei_md_info, "WEI");
91 def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
92 if (!conf.is_forward) {
93 def_memory_desc_info(kernel_ctx, conf.diff_src_md_info, "DIFF_SRC");
94 def_memory_desc_info(kernel_ctx, conf.diff_wei_md_info, "DIFF_WEI");
95 }
96
97 def_dispatch(kernel_ctx, conf.dispatch);
98
99 return status::success;
100}
101
// Forward pd: delegate configuration setup to the shared helper.
status_t ref_prelu_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}
105
// Forward pd: delegate kernel-context population to the shared helper.
status_t ref_prelu_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}
110
111status_t ref_prelu_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
112
113 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
114 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
115 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
116
117 compute::kernel_arg_list_t arg_list;
118 arg_list.set(0, src);
119 arg_list.set(1, weights);
120 arg_list.set(2, dst);
121
122 auto nd_range = pd()->conf.dispatch.nd_range();
123
124 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
125
126 return status;
127}
128
// Backward pd: delegate configuration setup to the shared helper.
status_t ref_prelu_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}
132
// Backward pd: delegate kernel-context population to the shared helper.
status_t ref_prelu_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}
137
138void ref_prelu_bwd_t::pd_t::init_scratchpad() {
139 if (conf.reduce_diff_weights) {
140 auto scratchpad = scratchpad_registry().registrar();
141 size_t size = utils::array_product(
142 conf.dst_md_info.padded_dims, conf.dst_md_info.ndims);
143 scratchpad.book(memory_tracking::names::key_prelu_reduction, size,
144 types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
145
146 scratchpad.book(memory_tracking::names::key_nested,
147 reduction_pd_->scratchpad_registry());
148 }
149}
150
151status_t ref_prelu_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
152 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
153 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
154 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
155
156 auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
157 auto &diff_weights = CTX_OUT_STORAGE(DNNL_ARG_DIFF_WEIGHTS);
158
159 const auto &conf = pd()->conf;
160
161 std::unique_ptr<memory_t> diff_weights_to_reduce;
162 if (conf.reduce_diff_weights) {
163 auto scratchpad = ctx.get_scratchpad_grantor().get_memory_storage(
164 memory_tracking::names::key_prelu_reduction);
165 CHECK(safe_ptr_assign(diff_weights_to_reduce,
166 new memory_t(ctx.stream()->engine(), pd()->dst_md(0),
167 std::move(scratchpad))));
168 }
169
170 const auto &diff_weight_arg = conf.reduce_diff_weights
171 ? *diff_weights_to_reduce->memory_storage()
172 : diff_weights;
173
174 compute::kernel_arg_list_t arg_list;
175 arg_list.set(0, src);
176 arg_list.set(1, weights);
177 arg_list.set(2, diff_dst);
178 arg_list.set(3, diff_src);
179 arg_list.set(4, diff_weight_arg);
180
181 auto nd_range = pd()->conf.dispatch.nd_range();
182
183 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
184
185 if (conf.reduce_diff_weights) {
186 exec_args_t reduction_args;
187 reduction_args[DNNL_ARG_SRC]
188 = memory_arg_t {diff_weights_to_reduce.get(), true};
189 reduction_args[DNNL_ARG_DST] = ctx.args().at(DNNL_ARG_DIFF_WEIGHTS);
190 exec_ctx_t reduction_ctx(ctx, std::move(reduction_args));
191
192 nested_scratchpad_t ns(
193 ctx, memory_tracking::names::key_nested, reduction_p_);
194 reduction_ctx.set_scratchpad_grantor(ns.grantor());
195 // Executing the reduction kernel
196 return reduction_p_->execute(reduction_ctx);
197 }
198 return status;
199}
200
201} // namespace ocl
202} // namespace gpu
203} // namespace impl
204} // namespace dnnl
205