1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_prelu.hpp" |
18 | #include "gpu/ocl/ocl_utils.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace gpu { |
23 | namespace ocl { |
24 | |
25 | static status_t init_conf_common( |
26 | prelu_conf_t &conf, const prelu_pd_t *pd, engine_t *engine) { |
27 | |
28 | conf.is_forward = pd->is_fwd(); |
29 | |
30 | const memory_desc_wrapper src_mdw(pd->src_md(0)); |
31 | const memory_desc_wrapper wei_mdw(pd->weights_md(0)); |
32 | const memory_desc_wrapper dst_mdw( |
33 | conf.is_forward ? pd->dst_md(0) : pd->diff_dst_md(0)); |
34 | |
35 | conf.src_md_info = memory_desc_info_t::create(src_mdw); |
36 | conf.wei_md_info = memory_desc_info_t::create(wei_mdw); |
37 | conf.dst_md_info = memory_desc_info_t::create(dst_mdw); |
38 | if (!conf.is_forward) { |
39 | const memory_desc_wrapper diff_src_mdw(pd->diff_src_md(0)); |
40 | const memory_desc_wrapper diff_weights_mdw(pd->diff_weights_md(0)); |
41 | conf.reduce_diff_weights |
42 | = src_mdw.nelems() != diff_weights_mdw.nelems(); |
43 | |
44 | conf.diff_src_md_info = memory_desc_info_t::create(diff_src_mdw); |
45 | |
46 | if (conf.reduce_diff_weights) { |
47 | memory_desc_t red_diff_mem_desc(*pd->src_md(0)); |
48 | red_diff_mem_desc.data_type = dnnl_f32; |
49 | const memory_desc_wrapper red_diff_mdw(red_diff_mem_desc); |
50 | conf.diff_wei_md_info = memory_desc_info_t::create(red_diff_mdw); |
51 | } else { |
52 | conf.diff_wei_md_info |
53 | = memory_desc_info_t::create(diff_weights_mdw); |
54 | } |
55 | } |
56 | |
57 | const auto &ndims = dst_mdw.ndims(); |
58 | |
59 | const auto *compute_engine |
60 | = utils::downcast<compute::compute_engine_t *>(engine); |
61 | conf.dispatch = compute_engine->create_dispatch(src_mdw.md_); |
62 | |
63 | for (int i = 0; i < MAX_NDIMS; ++i) { |
64 | if (i < ndims) { |
65 | const dnnl_dim_t diff_wei_dim = conf.is_forward |
66 | ? 1 |
67 | : static_cast<dnnl_dim_t>( |
68 | conf.diff_wei_md_info.padded_dims[i]); |
69 | dnnl_dim_t dim2dispatch |
70 | = nstl::max(dst_mdw.padded_dims()[i], diff_wei_dim); |
71 | conf.dispatch.define_dim(utils::format("D%d" , i), i, dim2dispatch); |
72 | } else |
73 | conf.dispatch.define_dim(utils::format("D%d" , i), 1); |
74 | } |
75 | conf.dispatch.generate(false); |
76 | |
77 | return status::success; |
78 | }; |
79 | |
// Translates `conf` into preprocessor definitions for the OpenCL kernel:
// data type, algorithm kinds, fwd/bwd selector, per-tensor memory layout
// macros, and the dispatch parameters computed in init_conf_common().
static status_t init_kernel_ctx_common(
        compute::kernel_ctx_t &kernel_ctx, const prelu_conf_t &conf) {

    // Kernel computes in the destination's data type.
    kernel_ctx.set_data_type(conf.dst_md_info.data_type);
    def_eltwise_alg_kinds(kernel_ctx);
    kernel_ctx.define_int("WITH_ELTWISE", 1);

    kernel_ctx.define_int("IS_FWD", conf.is_forward);

    // Layout/stride macros per tensor, consumed by the .cl source.
    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.wei_md_info, "WEI");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
    if (!conf.is_forward) {
        def_memory_desc_info(kernel_ctx, conf.diff_src_md_info, "DIFF_SRC");
        def_memory_desc_info(kernel_ctx, conf.diff_wei_md_info, "DIFF_WEI");
    }

    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}
101 | |
// Forward pd: delegate configuration to the shared fwd/bwd helper.
status_t ref_prelu_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}
105 | |
// Forward pd: delegate kernel-context setup to the shared fwd/bwd helper.
status_t ref_prelu_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}
110 | |
111 | status_t ref_prelu_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
112 | |
113 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
114 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
115 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
116 | |
117 | compute::kernel_arg_list_t arg_list; |
118 | arg_list.set(0, src); |
119 | arg_list.set(1, weights); |
120 | arg_list.set(2, dst); |
121 | |
122 | auto nd_range = pd()->conf.dispatch.nd_range(); |
123 | |
124 | status_t status = parallel_for(ctx, nd_range, kernel_, arg_list); |
125 | |
126 | return status; |
127 | } |
128 | |
// Backward pd: delegate configuration to the shared fwd/bwd helper.
status_t ref_prelu_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}
132 | |
// Backward pd: delegate kernel-context setup to the shared fwd/bwd helper.
status_t ref_prelu_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}
137 | |
138 | void ref_prelu_bwd_t::pd_t::init_scratchpad() { |
139 | if (conf.reduce_diff_weights) { |
140 | auto scratchpad = scratchpad_registry().registrar(); |
141 | size_t size = utils::array_product( |
142 | conf.dst_md_info.padded_dims, conf.dst_md_info.ndims); |
143 | scratchpad.book(memory_tracking::names::key_prelu_reduction, size, |
144 | types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT); |
145 | |
146 | scratchpad.book(memory_tracking::names::key_nested, |
147 | reduction_pd_->scratchpad_registry()); |
148 | } |
149 | } |
150 | |
// Backward PReLU: one kernel computes diff_src and (possibly partial)
// diff_weights; if the weights are broadcast, a nested reduction primitive
// then collapses the partial diff_weights into the user's buffer.
status_t ref_prelu_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);

    auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
    auto &diff_weights = CTX_OUT_STORAGE(DNNL_ARG_DIFF_WEIGHTS);

    const auto &conf = pd()->conf;

    // Temporary memory object wrapping the scratchpad buffer that receives
    // the unreduced per-element diff-weights (only in the reduction case).
    std::unique_ptr<memory_t> diff_weights_to_reduce;
    if (conf.reduce_diff_weights) {
        auto scratchpad = ctx.get_scratchpad_grantor().get_memory_storage(
                memory_tracking::names::key_prelu_reduction);
        CHECK(safe_ptr_assign(diff_weights_to_reduce,
                new memory_t(ctx.stream()->engine(), pd()->dst_md(0),
                        std::move(scratchpad))));
    }

    // Kernel writes either into the scratchpad (to be reduced later) or
    // directly into the user's diff_weights buffer.
    const auto &diff_weight_arg = conf.reduce_diff_weights
            ? *diff_weights_to_reduce->memory_storage()
            : diff_weights;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, weights);
    arg_list.set(2, diff_dst);
    arg_list.set(3, diff_src);
    arg_list.set(4, diff_weight_arg);

    auto nd_range = pd()->conf.dispatch.nd_range();

    status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);

    if (conf.reduce_diff_weights) {
        // Reduce scratchpad (src) into the user's diff_weights (dst).
        exec_args_t reduction_args;
        reduction_args[DNNL_ARG_SRC]
                = memory_arg_t {diff_weights_to_reduce.get(), true};
        reduction_args[DNNL_ARG_DST] = ctx.args().at(DNNL_ARG_DIFF_WEIGHTS);
        exec_ctx_t reduction_ctx(ctx, std::move(reduction_args));

        // Hand the reduction its own booked nested scratchpad.
        nested_scratchpad_t ns(
                ctx, memory_tracking::names::key_nested, reduction_p_);
        reduction_ctx.set_scratchpad_grantor(ns.grantor());
        // Executing the reduction kernel
        return reduction_p_->execute(reduction_ctx);
    }
    return status;
}
200 | |
201 | } // namespace ocl |
202 | } // namespace gpu |
203 | } // namespace impl |
204 | } // namespace dnnl |
205 | |