1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ref_inner_product.hpp"
18
19#include "common/c_types_map.hpp"
20#include "common/dnnl_traits.hpp"
21#include "common/math_utils.hpp"
22#include "common/type_helpers.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27namespace ocl {
28
29static status_t init_conf_common(inner_product_conf_t &conf, offsets_t &off,
30 const inner_product_pd_t *pd, engine_t *engine) {
31 const inner_product_desc_t &ipd = *pd->desc();
32 const memory_desc_wrapper src_d(pd->invariant_src_md());
33 const memory_desc_wrapper wei_d(pd->invariant_wei_md());
34 const memory_desc_wrapper dst_d(pd->invariant_dst_md());
35 data_type_t acc_data_type = pd->desc()->accum_data_type;
36
37 const int ndims = src_d.ndims();
38
39 conf.ndims = ndims;
40 conf.src_ndims = src_d.ndims();
41 conf.wei_ndims = wei_d.ndims();
42 conf.dst_ndims = dst_d.ndims();
43
44 conf.has_spatial = utils::one_of(conf.ndims, 3, 4, 5);
45
46 conf.mb = pd->MB();
47 conf.ic = pd->IC();
48
49 conf.id = pd->ID();
50 conf.ih = pd->IH();
51 conf.iw = pd->IW();
52
53 const auto &src_dims = src_d.padded_dims();
54 conf.ic_total = utils::array_product(&src_dims[1], conf.ndims - 1);
55
56 conf.oc = pd->OC();
57
58 conf.od = pd->OD();
59 conf.oh = pd->OH();
60 conf.ow = pd->OW();
61
62 conf.kd = pd->KD();
63 conf.kh = pd->KH();
64 conf.kw = pd->KW();
65
66 conf.src_dt = src_d.data_type();
67 conf.wei_dt = wei_d.data_type();
68 conf.dst_dt = dst_d.data_type();
69 conf.acc_dt = acc_data_type;
70
71 conf.is_forward = utils::one_of(
72 ipd.prop_kind, prop_kind::forward, prop_kind::forward_inference);
73 conf.is_backward_data = ipd.prop_kind == prop_kind::backward_data;
74 conf.is_backward_weights = ipd.prop_kind == prop_kind::backward_weights;
75
76 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
77 if (conf.is_forward) {
78 conf.with_bias = ipd.bias_desc.format_kind != format_kind::undef;
79 conf.bia_dt = conf.with_bias ? ipd.bias_desc.data_type : data_type::f32;
80 conf.dispatch = compute_engine->create_dispatch(dst_d.md_);
81 conf.dispatch.define_dim("MB", 0, conf.mb);
82 conf.dispatch.define_dim("OC", 1, conf.oc);
83 conf.dispatch.generate();
84 } else if (conf.is_backward_weights) {
85 conf.with_bias = ipd.diff_bias_desc.format_kind != format_kind::undef;
86 conf.bia_dt = conf.with_bias ? ipd.diff_bias_desc.data_type
87 : data_type::f32;
88 conf.dispatch = compute_engine->create_dispatch(wei_d.md_);
89 conf.dispatch.define_dim("OC", 0, conf.oc);
90 conf.dispatch.define_dim("IC", 1, conf.ic);
91 conf.dispatch.define_dim("KD", nstl::max(1, ndims - 3), conf.kd);
92 conf.dispatch.define_dim("KH", nstl::max(1, ndims - 2), conf.kh);
93 conf.dispatch.define_dim("KW", nstl::max(1, ndims - 1), conf.kw);
94 conf.dispatch.generate();
95 } else {
96 conf.with_bias = false;
97 conf.bia_dt = data_type::f32;
98 conf.dispatch = compute_engine->create_dispatch(src_d.md_);
99 conf.dispatch.define_dim("MB_IC", 0, conf.mb * conf.ic);
100 conf.dispatch.define_dim("KD", nstl::max(1, ndims - 3), conf.kd);
101 conf.dispatch.define_dim("KH", nstl::max(1, ndims - 2), conf.kh);
102 conf.dispatch.define_dim("KW", nstl::max(1, ndims - 1), conf.kw);
103 conf.dispatch.generate();
104 }
105
106 set_offsets(src_d, off.src_off);
107 set_offsets(wei_d, off.wei_off);
108 set_offsets(dst_d, off.dst_off);
109
110 conf.attr_info = attr_info_t::create(pd->attr());
111
112 return status::success;
113}
114
115static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
116 const inner_product_conf_t &conf, const offsets_t &off,
117 const post_ops_t &post_ops) {
118 kernel_ctx.define_int("NDIMS", conf.ndims);
119 kernel_ctx.define_int("MB", conf.mb);
120 kernel_ctx.define_int("OC", conf.oc);
121 kernel_ctx.define_int("IC", conf.ic);
122 kernel_ctx.define_int("IC_TOTAL", conf.ic_total);
123 kernel_ctx.define_int("ID", conf.id);
124 kernel_ctx.define_int("IH", conf.ih);
125 kernel_ctx.define_int("IW", conf.iw);
126 kernel_ctx.define_int("OD", conf.od);
127 kernel_ctx.define_int("OH", conf.oh);
128 kernel_ctx.define_int("OW", conf.ow);
129 kernel_ctx.define_int("KD", conf.kd);
130 kernel_ctx.define_int("KH", conf.kh);
131 kernel_ctx.define_int("KW", conf.kw);
132 if (conf.with_bias) kernel_ctx.define_int("WITH_BIAS", 1);
133 if (conf.has_spatial) kernel_ctx.define_int("HAS_SPATIAL", 1);
134
135 if (conf.is_forward)
136 kernel_ctx.define_int("IS_FWD", 1);
137 else if (conf.is_backward_data)
138 kernel_ctx.define_int("IS_BWD_D", 1);
139 else if (conf.is_backward_weights)
140 kernel_ctx.define_int("IS_BWD_W", 1);
141
142 def_attr_info(kernel_ctx, conf.attr_info, post_ops);
143
144 def_offsets(off.src_off, kernel_ctx, "SRC", conf.src_ndims);
145 def_offsets(off.wei_off, kernel_ctx, "WEI", conf.wei_ndims);
146 def_offsets(off.dst_off, kernel_ctx, "DST", conf.dst_ndims);
147
148 if (conf.src_dt == data_type::f16)
149 kernel_ctx.set_data_type(data_type::f16);
150 else
151 kernel_ctx.set_data_type(data_type::f32);
152
153 def_data_type(kernel_ctx, conf.src_dt, "SRC");
154 def_data_type(kernel_ctx, conf.wei_dt, "WEI");
155 def_data_type(kernel_ctx, conf.bia_dt, "BIA");
156 def_data_type(kernel_ctx, conf.dst_dt, "DST");
157 def_data_type(kernel_ctx, conf.acc_dt, "ACC");
158
159 def_dispatch(kernel_ctx, conf.dispatch);
160
161 return status::success;
162}
163
164status_t ref_inner_product_fwd_t::pd_t::init_conf(engine_t *engine) {
165 return init_conf_common(conf, off, this, engine);
166}
167
168status_t ref_inner_product_fwd_t::pd_t::init_kernel_ctx(
169 compute::kernel_ctx_t &kernel_ctx) const {
170 return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
171}
172
173status_t ref_inner_product_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
174 status_t status = status::success;
175
176 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
177 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
178 auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
179 auto &dst = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status);
180 CHECK(status);
181
182 const auto &conf = pd()->conf;
183
184 compute::kernel_arg_list_t arg_list;
185 arg_list.set(0, src);
186 arg_list.set(1, weights);
187 arg_list.set(2, bias);
188 arg_list.set(3, dst);
189
190 unsigned arg_idx = append_post_ops_to_arg_list(
191 ctx, arg_list, 4, pd()->attr()->post_ops_);
192
193 auto &src_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
194 auto &wei_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
195 auto &dst_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
196
197 arg_list.set(arg_idx++, src_scales);
198 arg_list.set(arg_idx++, wei_scales);
199 arg_list.set(arg_idx++, dst_scales);
200
201 auto nd_range = conf.dispatch.nd_range();
202
203 status = parallel_for(ctx, nd_range, kernel_, arg_list);
204
205 return status;
206}
207
208status_t ref_inner_product_bwd_data_t::pd_t::init_conf(engine_t *engine) {
209 return init_conf_common(conf, off, this, engine);
210}
211
212status_t ref_inner_product_bwd_data_t::pd_t::init_kernel_ctx(
213 compute::kernel_ctx_t &kernel_ctx) const {
214 return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
215}
216
217status_t ref_inner_product_bwd_data_t::execute_backward_data(
218 const exec_ctx_t &ctx) const {
219 status_t status = status::success;
220
221 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
222 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
223 auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
224 CHECK(status);
225
226 const auto &conf = pd()->conf;
227
228 compute::kernel_arg_list_t arg_list;
229 arg_list.set(0, diff_src);
230 arg_list.set(1, weights);
231 arg_list.set(2, diff_dst);
232
233 auto nd_range = conf.dispatch.nd_range();
234
235 status = parallel_for(ctx, nd_range, kernel_, arg_list);
236
237 return status;
238}
239
240status_t ref_inner_product_bwd_weights_t::pd_t::init_conf(engine_t *engine) {
241 return init_conf_common(conf, off, this, engine);
242}
243
244status_t ref_inner_product_bwd_weights_t::pd_t::init_kernel_ctx(
245 compute::kernel_ctx_t &kernel_ctx) const {
246 return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
247}
248
249status_t ref_inner_product_bwd_weights_t::execute_backward_weights(
250 const exec_ctx_t &ctx) const {
251 status_t status = status::success;
252
253 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
254 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
255 auto &diff_weights = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_WEIGHTS, status);
256 CHECK(status);
257 auto &diff_bias = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_BIAS, status);
258 CHECK(status);
259
260 const auto &conf = pd()->conf;
261
262 compute::kernel_arg_list_t arg_list;
263 arg_list.set(0, src);
264 arg_list.set(1, diff_weights);
265 arg_list.set(2, diff_bias);
266 arg_list.set(3, diff_dst);
267
268 auto nd_range = conf.dispatch.nd_range();
269
270 status = parallel_for(ctx, nd_range, kernel_, arg_list);
271
272 return status;
273}
274
275} // namespace ocl
276} // namespace gpu
277} // namespace impl
278} // namespace dnnl
279