1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_inner_product.hpp" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_traits.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace ocl { |
28 | |
29 | static status_t init_conf_common(inner_product_conf_t &conf, offsets_t &off, |
30 | const inner_product_pd_t *pd, engine_t *engine) { |
31 | const inner_product_desc_t &ipd = *pd->desc(); |
32 | const memory_desc_wrapper src_d(pd->invariant_src_md()); |
33 | const memory_desc_wrapper wei_d(pd->invariant_wei_md()); |
34 | const memory_desc_wrapper dst_d(pd->invariant_dst_md()); |
35 | data_type_t acc_data_type = pd->desc()->accum_data_type; |
36 | |
37 | const int ndims = src_d.ndims(); |
38 | |
39 | conf.ndims = ndims; |
40 | conf.src_ndims = src_d.ndims(); |
41 | conf.wei_ndims = wei_d.ndims(); |
42 | conf.dst_ndims = dst_d.ndims(); |
43 | |
44 | conf.has_spatial = utils::one_of(conf.ndims, 3, 4, 5); |
45 | |
46 | conf.mb = pd->MB(); |
47 | conf.ic = pd->IC(); |
48 | |
49 | conf.id = pd->ID(); |
50 | conf.ih = pd->IH(); |
51 | conf.iw = pd->IW(); |
52 | |
53 | const auto &src_dims = src_d.padded_dims(); |
54 | conf.ic_total = utils::array_product(&src_dims[1], conf.ndims - 1); |
55 | |
56 | conf.oc = pd->OC(); |
57 | |
58 | conf.od = pd->OD(); |
59 | conf.oh = pd->OH(); |
60 | conf.ow = pd->OW(); |
61 | |
62 | conf.kd = pd->KD(); |
63 | conf.kh = pd->KH(); |
64 | conf.kw = pd->KW(); |
65 | |
66 | conf.src_dt = src_d.data_type(); |
67 | conf.wei_dt = wei_d.data_type(); |
68 | conf.dst_dt = dst_d.data_type(); |
69 | conf.acc_dt = acc_data_type; |
70 | |
71 | conf.is_forward = utils::one_of( |
72 | ipd.prop_kind, prop_kind::forward, prop_kind::forward_inference); |
73 | conf.is_backward_data = ipd.prop_kind == prop_kind::backward_data; |
74 | conf.is_backward_weights = ipd.prop_kind == prop_kind::backward_weights; |
75 | |
76 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
77 | if (conf.is_forward) { |
78 | conf.with_bias = ipd.bias_desc.format_kind != format_kind::undef; |
79 | conf.bia_dt = conf.with_bias ? ipd.bias_desc.data_type : data_type::f32; |
80 | conf.dispatch = compute_engine->create_dispatch(dst_d.md_); |
81 | conf.dispatch.define_dim("MB" , 0, conf.mb); |
82 | conf.dispatch.define_dim("OC" , 1, conf.oc); |
83 | conf.dispatch.generate(); |
84 | } else if (conf.is_backward_weights) { |
85 | conf.with_bias = ipd.diff_bias_desc.format_kind != format_kind::undef; |
86 | conf.bia_dt = conf.with_bias ? ipd.diff_bias_desc.data_type |
87 | : data_type::f32; |
88 | conf.dispatch = compute_engine->create_dispatch(wei_d.md_); |
89 | conf.dispatch.define_dim("OC" , 0, conf.oc); |
90 | conf.dispatch.define_dim("IC" , 1, conf.ic); |
91 | conf.dispatch.define_dim("KD" , nstl::max(1, ndims - 3), conf.kd); |
92 | conf.dispatch.define_dim("KH" , nstl::max(1, ndims - 2), conf.kh); |
93 | conf.dispatch.define_dim("KW" , nstl::max(1, ndims - 1), conf.kw); |
94 | conf.dispatch.generate(); |
95 | } else { |
96 | conf.with_bias = false; |
97 | conf.bia_dt = data_type::f32; |
98 | conf.dispatch = compute_engine->create_dispatch(src_d.md_); |
99 | conf.dispatch.define_dim("MB_IC" , 0, conf.mb * conf.ic); |
100 | conf.dispatch.define_dim("KD" , nstl::max(1, ndims - 3), conf.kd); |
101 | conf.dispatch.define_dim("KH" , nstl::max(1, ndims - 2), conf.kh); |
102 | conf.dispatch.define_dim("KW" , nstl::max(1, ndims - 1), conf.kw); |
103 | conf.dispatch.generate(); |
104 | } |
105 | |
106 | set_offsets(src_d, off.src_off); |
107 | set_offsets(wei_d, off.wei_off); |
108 | set_offsets(dst_d, off.dst_off); |
109 | |
110 | conf.attr_info = attr_info_t::create(pd->attr()); |
111 | |
112 | return status::success; |
113 | } |
114 | |
115 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
116 | const inner_product_conf_t &conf, const offsets_t &off, |
117 | const post_ops_t &post_ops) { |
118 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
119 | kernel_ctx.define_int("MB" , conf.mb); |
120 | kernel_ctx.define_int("OC" , conf.oc); |
121 | kernel_ctx.define_int("IC" , conf.ic); |
122 | kernel_ctx.define_int("IC_TOTAL" , conf.ic_total); |
123 | kernel_ctx.define_int("ID" , conf.id); |
124 | kernel_ctx.define_int("IH" , conf.ih); |
125 | kernel_ctx.define_int("IW" , conf.iw); |
126 | kernel_ctx.define_int("OD" , conf.od); |
127 | kernel_ctx.define_int("OH" , conf.oh); |
128 | kernel_ctx.define_int("OW" , conf.ow); |
129 | kernel_ctx.define_int("KD" , conf.kd); |
130 | kernel_ctx.define_int("KH" , conf.kh); |
131 | kernel_ctx.define_int("KW" , conf.kw); |
132 | if (conf.with_bias) kernel_ctx.define_int("WITH_BIAS" , 1); |
133 | if (conf.has_spatial) kernel_ctx.define_int("HAS_SPATIAL" , 1); |
134 | |
135 | if (conf.is_forward) |
136 | kernel_ctx.define_int("IS_FWD" , 1); |
137 | else if (conf.is_backward_data) |
138 | kernel_ctx.define_int("IS_BWD_D" , 1); |
139 | else if (conf.is_backward_weights) |
140 | kernel_ctx.define_int("IS_BWD_W" , 1); |
141 | |
142 | def_attr_info(kernel_ctx, conf.attr_info, post_ops); |
143 | |
144 | def_offsets(off.src_off, kernel_ctx, "SRC" , conf.src_ndims); |
145 | def_offsets(off.wei_off, kernel_ctx, "WEI" , conf.wei_ndims); |
146 | def_offsets(off.dst_off, kernel_ctx, "DST" , conf.dst_ndims); |
147 | |
148 | if (conf.src_dt == data_type::f16) |
149 | kernel_ctx.set_data_type(data_type::f16); |
150 | else |
151 | kernel_ctx.set_data_type(data_type::f32); |
152 | |
153 | def_data_type(kernel_ctx, conf.src_dt, "SRC" ); |
154 | def_data_type(kernel_ctx, conf.wei_dt, "WEI" ); |
155 | def_data_type(kernel_ctx, conf.bia_dt, "BIA" ); |
156 | def_data_type(kernel_ctx, conf.dst_dt, "DST" ); |
157 | def_data_type(kernel_ctx, conf.acc_dt, "ACC" ); |
158 | |
159 | def_dispatch(kernel_ctx, conf.dispatch); |
160 | |
161 | return status::success; |
162 | } |
163 | |
164 | status_t ref_inner_product_fwd_t::pd_t::init_conf(engine_t *engine) { |
165 | return init_conf_common(conf, off, this, engine); |
166 | } |
167 | |
168 | status_t ref_inner_product_fwd_t::pd_t::init_kernel_ctx( |
169 | compute::kernel_ctx_t &kernel_ctx) const { |
170 | return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_); |
171 | } |
172 | |
173 | status_t ref_inner_product_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
174 | status_t status = status::success; |
175 | |
176 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
177 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
178 | auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
179 | auto &dst = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status); |
180 | CHECK(status); |
181 | |
182 | const auto &conf = pd()->conf; |
183 | |
184 | compute::kernel_arg_list_t arg_list; |
185 | arg_list.set(0, src); |
186 | arg_list.set(1, weights); |
187 | arg_list.set(2, bias); |
188 | arg_list.set(3, dst); |
189 | |
190 | unsigned arg_idx = append_post_ops_to_arg_list( |
191 | ctx, arg_list, 4, pd()->attr()->post_ops_); |
192 | |
193 | auto &src_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
194 | auto &wei_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); |
195 | auto &dst_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
196 | |
197 | arg_list.set(arg_idx++, src_scales); |
198 | arg_list.set(arg_idx++, wei_scales); |
199 | arg_list.set(arg_idx++, dst_scales); |
200 | |
201 | auto nd_range = conf.dispatch.nd_range(); |
202 | |
203 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
204 | |
205 | return status; |
206 | } |
207 | |
208 | status_t ref_inner_product_bwd_data_t::pd_t::init_conf(engine_t *engine) { |
209 | return init_conf_common(conf, off, this, engine); |
210 | } |
211 | |
212 | status_t ref_inner_product_bwd_data_t::pd_t::init_kernel_ctx( |
213 | compute::kernel_ctx_t &kernel_ctx) const { |
214 | return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_); |
215 | } |
216 | |
217 | status_t ref_inner_product_bwd_data_t::execute_backward_data( |
218 | const exec_ctx_t &ctx) const { |
219 | status_t status = status::success; |
220 | |
221 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
222 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
223 | auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status); |
224 | CHECK(status); |
225 | |
226 | const auto &conf = pd()->conf; |
227 | |
228 | compute::kernel_arg_list_t arg_list; |
229 | arg_list.set(0, diff_src); |
230 | arg_list.set(1, weights); |
231 | arg_list.set(2, diff_dst); |
232 | |
233 | auto nd_range = conf.dispatch.nd_range(); |
234 | |
235 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
236 | |
237 | return status; |
238 | } |
239 | |
240 | status_t ref_inner_product_bwd_weights_t::pd_t::init_conf(engine_t *engine) { |
241 | return init_conf_common(conf, off, this, engine); |
242 | } |
243 | |
244 | status_t ref_inner_product_bwd_weights_t::pd_t::init_kernel_ctx( |
245 | compute::kernel_ctx_t &kernel_ctx) const { |
246 | return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_); |
247 | } |
248 | |
249 | status_t ref_inner_product_bwd_weights_t::execute_backward_weights( |
250 | const exec_ctx_t &ctx) const { |
251 | status_t status = status::success; |
252 | |
253 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
254 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
255 | auto &diff_weights = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_WEIGHTS, status); |
256 | CHECK(status); |
257 | auto &diff_bias = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_BIAS, status); |
258 | CHECK(status); |
259 | |
260 | const auto &conf = pd()->conf; |
261 | |
262 | compute::kernel_arg_list_t arg_list; |
263 | arg_list.set(0, src); |
264 | arg_list.set(1, diff_weights); |
265 | arg_list.set(2, diff_bias); |
266 | arg_list.set(3, diff_dst); |
267 | |
268 | auto nd_range = conf.dispatch.nd_range(); |
269 | |
270 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
271 | |
272 | return status; |
273 | } |
274 | |
275 | } // namespace ocl |
276 | } // namespace gpu |
277 | } // namespace impl |
278 | } // namespace dnnl |
279 | |