1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <algorithm>
17#include "gpu/ocl/ocl_stream.hpp"
18#include "gpu/ocl/xe_lp_x8s8x_1x1_convolution.hpp"
19
20namespace dnnl {
21namespace impl {
22namespace gpu {
23namespace ocl {
24
25status_t xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_conf(engine_t *engine) {
26 using namespace format_tag;
27
28 const convolution_desc_t &cd = *desc();
29 const memory_desc_wrapper src_mdw(src_md());
30 const memory_desc_wrapper weights_mdw(weights_md());
31 const memory_desc_wrapper dst_mdw(dst_md());
32 const memory_desc_wrapper bias_mdw(weights_md(1));
33 auto dev_info = utils::downcast<compute::compute_engine_t *>(engine)
34 ->device_info();
35
36 set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(),
37 *weights_md(1), *attr());
38
39 conf.is_nhwc
40 = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef
41 || dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc)
42 != format_tag::undef;
43
44 if (conf.is_depthwise || conf.kw != 1 || conf.kh != 1 || conf.kd != 1
45 || (conf.with_groups && conf.ngroups > 1
46 && (conf.oc % 32 != 0 || conf.ic % 32 != 0)))
47 return status::unimplemented;
48
49 conf.src_data_type = src_mdw.data_type();
50 conf.dst_data_type = dst_mdw.data_type();
51
52 conf.mb_block = 32;
53 conf.oc_block = 32;
54 conf.ic_block = 32;
55 conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);
56 int ow = conf.ow;
57 int oh = conf.oh;
58 int od = conf.od;
59 const bool is_stride1
60 = conf.stride_d == 1 && conf.stride_h == 1 && conf.stride_w == 1;
61 const bool is_padded = conf.l_pad > 0 || conf.t_pad > 0 || conf.f_pad > 0;
62
63 // TODO: fix r-padded shapes issue in 1x1 kernel
64 if (conf.r_pad > 0) return status::unimplemented;
65
66 if (is_stride1 || is_padded) {
67 // reshape to nCx32c
68 ow = ow * oh * od;
69 oh = od = 1;
70 }
71 conf.sp = ow;
72
73 if ((conf.mb % 16 == 0) && !conf.is_nhwc && !is_padded) {
74 conf.mb_block = 32;
75 conf.sp_block = 1;
76 } else {
77 conf.mb_block = 1;
78 conf.sp_block = 4;
79 auto approx_clocks = [&](const int block) {
80 int ic_chunks = utils::div_up(conf.ic, conf.ic_block);
81 bool use_slm = utils::div_up(conf.ow, block) % 8 == 0;
82 int mem_clocks = ic_chunks * (16 - use_slm * 6)
83 + block / 2 * (ic_chunks + 1);
84 int compute_clocks = 32 * block * ic_chunks;
85 int num_threads = conf.nchunk * conf.mb * od * oh
86 * utils::div_up(ow, block);
87 return utils::div_up(num_threads, dev_info->hw_threads())
88 * (compute_clocks + mem_clocks);
89 };
90 auto clock_compare = [&](const int &block1, const int &block2) {
91 return approx_clocks(block1) < approx_clocks(block2);
92 };
93 std::vector<int> sorted_blocks = {4, 8, 12, 16};
94 std::sort(sorted_blocks.begin(), sorted_blocks.end(), clock_compare);
95 conf.sp_block = sorted_blocks[0];
96 }
97 conf.src_data_type = src_mdw.data_type();
98 conf.dst_data_type = dst_mdw.data_type();
99
100 const int ow_group = (utils::div_up(ow, conf.sp_block) % 8) ? 1 : 8;
101
102 conf.sub_group_size = 8;
103 conf.lws_d[0] = conf.sub_group_size;
104 conf.lws_d[1] = ow_group;
105 conf.lws_d[2] = 1;
106
107 const int num_sp_threads = utils::div_up(ow, conf.sp_block) * oh * od;
108
109 conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
110 conf.gws_d[1] = utils::rnd_up(num_sp_threads, conf.lws_d[1]);
111 conf.gws_d[2] = (conf.mb_block == 32 ? 2 : 1)
112 * utils::div_up(conf.mb, conf.mb_block);
113
114 conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;
115
116 format_tag_t src_tag, dst_tag, wei_tag;
117
118 if (conf.is_nhwc) {
119 src_tag = utils::pick(conf.ndims - 3, nwc, nhwc);
120 dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc);
121 } else {
122 if (conf.mb_block == 32) {
123 src_tag = utils::pick(conf.ndims - 3, NCw32n32c, NChw32n32c);
124 dst_tag = utils::pick(conf.ndims - 3, NCw32n32c, NChw32n32c);
125 } else {
126 src_tag = utils::pick(conf.ndims - 3, nCw32c, nChw32c);
127 dst_tag = utils::pick(conf.ndims - 3, nCw32c, nChw32c);
128 }
129 }
130
131 wei_tag = conf.with_groups
132 ? utils::pick(conf.ndims - 3, gOIw4o8i8o4i, gOIhw4o8i8o4i)
133 : utils::pick(conf.ndims - 3, OIw4o8i8o4i, OIhw4o8i8o4i);
134
135 conf.src_tag = src_mdw.format_kind() == format_kind::any
136 ? src_tag
137 : src_mdw.matches_one_of_tag(src_tag);
138 conf.wei_tag = weights_mdw.format_kind() == format_kind::any
139 ? wei_tag
140 : weights_mdw.matches_one_of_tag(wei_tag);
141 conf.dst_tag = dst_mdw.format_kind() == format_kind::any
142 ? dst_tag
143 : dst_mdw.matches_one_of_tag(dst_tag);
144
145 if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
146 || conf.dst_tag != dst_tag)
147 return status::unimplemented;
148
149 return status::success;
150}
151
152status_t xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_kernel_ctx(
153 compute::kernel_ctx_t &kernel_ctx) const {
154 kernel_ctx.define_int("G", conf.ngroups);
155 kernel_ctx.define_int("MB", conf.mb);
156 kernel_ctx.define_int("IC", conf.ic_without_padding);
157 kernel_ctx.define_int("ID", conf.id);
158 kernel_ctx.define_int("IH", conf.ih);
159 kernel_ctx.define_int("IW", conf.iw);
160 kernel_ctx.define_int("OC", conf.oc_without_padding);
161 kernel_ctx.define_int("OD", conf.od);
162 kernel_ctx.define_int("OH", conf.oh);
163 kernel_ctx.define_int("OW", conf.ow);
164 kernel_ctx.define_int("KD", conf.kd);
165 kernel_ctx.define_int("KH", conf.kh);
166 kernel_ctx.define_int("KW", conf.kw);
167 kernel_ctx.define_int("SD", conf.stride_d);
168 kernel_ctx.define_int("SH", conf.stride_h);
169 kernel_ctx.define_int("SW", conf.stride_w);
170 kernel_ctx.define_int("PD", conf.f_pad);
171 kernel_ctx.define_int("PH", conf.t_pad);
172 kernel_ctx.define_int("PW", conf.l_pad);
173
174 kernel_ctx.define_int("SP_BLOCK", conf.sp_block);
175 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
176 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
177 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
178
179 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
180
181 def_attr_info(
182 kernel_ctx, conf.attr_info, attr()->post_ops_, &(dst_md()->dims));
183
184 kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
185
186 kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
187 kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
188 kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
189
190 kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
191 kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));
192
193 kernel_ctx.define_int(
194 "INT8_WEI_SLM", utils::div_up(conf.sp, conf.sp_block) % 8 == 0);
195 kernel_ctx.define_int("SP_TAIL",
196 utils::div_up(conf.sp, conf.sp_block) % conf.lws_d[1] == 0);
197 kernel_ctx.define_int("OUT_SP_TAIL", conf.sp % conf.sp_block);
198
199 kernel_ctx.define_int("WEI_4O8I8O4I", 1);
200
201 kernel_ctx.set_data_type(conf.dst_data_type);
202 def_data_type(kernel_ctx, conf.src_data_type, "SRC");
203 def_data_type(kernel_ctx, conf.dst_data_type, "DST");
204 def_data_type(kernel_ctx,
205 conf.attr_info.sum_data_type == dnnl_data_type_undef
206 ? conf.dst_data_type
207 : conf.attr_info.sum_data_type,
208 "SUM");
209 kernel_ctx.add_option("-Dcl_intel_subgroups_char");
210
211 return status::success;
212}
213
214void xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_scratchpad() {
215 if (conf.attr_info.with_src_zpoints) {
216 size_t size = conf.ngroups * utils::rnd_up(conf.oc, 32);
217
218 auto scratchpad = scratchpad_registry().registrar();
219 scratchpad.book(memory_tracking::names::key_conv_wei_reduction, size,
220 types::data_type_size(data_type::s32), OCL_BUFFER_ALIGNMENT);
221 }
222}
223
status_t xe_lp_x8s8x_1x1_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
    auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
    auto &oscales = CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
    auto &src_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
    auto &dst_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);

    const auto &conf = pd()->conf;

    // When source zero-points are present, first launch the auxiliary
    // kernel that reduces the weights against the zero-points into a
    // compensation buffer (booked in init_scratchpad), consumed by the
    // main convolution kernel below.
    std::unique_ptr<memory_storage_t> temp_src_compensation;
    if (conf.attr_info.with_src_zpoints) {
        temp_src_compensation = ctx.get_scratchpad_grantor().get_memory_storage(
                memory_tracking::names::key_conv_wei_reduction);

        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, src_zpoints);
        arg_list.set(1, weights);
        arg_list.set(2, *temp_src_compensation);

        // One subgroup (8 work-items) per 32-wide oc chunk, per group.
        auto nd_range = compute::nd_range_t(
                {8, utils::div_up(conf.oc, 32), conf.ngroups}, {8, 1, 1});
        status_t status = parallel_for(
                ctx, nd_range, src_compensation_kernel_, arg_list);
        if (status != status::success) return status::runtime_error;
    }

    // Main kernel arguments. The index order (src, weights, bias, dst,
    // post-ops, oscales, src-compensation, dst-zpoints) must match the
    // OpenCL kernel's parameter list exactly; optional arguments are
    // passed as empty storage when the corresponding attribute is absent.
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, weights);
    arg_list.set(2, bias);
    arg_list.set(3, dst);

    unsigned arg_idx = append_post_ops_to_arg_list(
            ctx, arg_list, 4, pd()->attr()->post_ops_);

    if (conf.attr_info.with_common_oscales
            || conf.attr_info.with_per_oc_oscales) {
        arg_list.set(arg_idx++, oscales);
    } else {
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());
    }

    if (conf.attr_info.with_src_zpoints)
        arg_list.set(arg_idx++, *temp_src_compensation);
    else
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());

    if (conf.attr_info.with_dst_zpoints)
        arg_list.set(arg_idx++, dst_zpoints);
    else
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());

    auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
    status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);

    // Post-ops that can turn zeroes into non-zeroes would corrupt the
    // zero-padded area of a blocked dst layout; re-zero it in that case.
    if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) {
        ctx.zero_pad_output(DNNL_ARG_DST);
    }

    return status;
}
290
291} // namespace ocl
292} // namespace gpu
293} // namespace impl
294} // namespace dnnl
295