1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <algorithm> |
17 | #include "gpu/ocl/ocl_stream.hpp" |
18 | #include "gpu/ocl/xe_lp_x8s8x_1x1_convolution.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace gpu { |
23 | namespace ocl { |
24 | |
25 | status_t xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_conf(engine_t *engine) { |
26 | using namespace format_tag; |
27 | |
28 | const convolution_desc_t &cd = *desc(); |
29 | const memory_desc_wrapper src_mdw(src_md()); |
30 | const memory_desc_wrapper weights_mdw(weights_md()); |
31 | const memory_desc_wrapper dst_mdw(dst_md()); |
32 | const memory_desc_wrapper bias_mdw(weights_md(1)); |
33 | auto dev_info = utils::downcast<compute::compute_engine_t *>(engine) |
34 | ->device_info(); |
35 | |
36 | set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(), |
37 | *weights_md(1), *attr()); |
38 | |
39 | conf.is_nhwc |
40 | = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef |
41 | || dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) |
42 | != format_tag::undef; |
43 | |
44 | if (conf.is_depthwise || conf.kw != 1 || conf.kh != 1 || conf.kd != 1 |
45 | || (conf.with_groups && conf.ngroups > 1 |
46 | && (conf.oc % 32 != 0 || conf.ic % 32 != 0))) |
47 | return status::unimplemented; |
48 | |
49 | conf.src_data_type = src_mdw.data_type(); |
50 | conf.dst_data_type = dst_mdw.data_type(); |
51 | |
52 | conf.mb_block = 32; |
53 | conf.oc_block = 32; |
54 | conf.ic_block = 32; |
55 | conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block); |
56 | int ow = conf.ow; |
57 | int oh = conf.oh; |
58 | int od = conf.od; |
59 | const bool is_stride1 |
60 | = conf.stride_d == 1 && conf.stride_h == 1 && conf.stride_w == 1; |
61 | const bool is_padded = conf.l_pad > 0 || conf.t_pad > 0 || conf.f_pad > 0; |
62 | |
63 | // TODO: fix r-padded shapes issue in 1x1 kernel |
64 | if (conf.r_pad > 0) return status::unimplemented; |
65 | |
66 | if (is_stride1 || is_padded) { |
67 | // reshape to nCx32c |
68 | ow = ow * oh * od; |
69 | oh = od = 1; |
70 | } |
71 | conf.sp = ow; |
72 | |
73 | if ((conf.mb % 16 == 0) && !conf.is_nhwc && !is_padded) { |
74 | conf.mb_block = 32; |
75 | conf.sp_block = 1; |
76 | } else { |
77 | conf.mb_block = 1; |
78 | conf.sp_block = 4; |
79 | auto approx_clocks = [&](const int block) { |
80 | int ic_chunks = utils::div_up(conf.ic, conf.ic_block); |
81 | bool use_slm = utils::div_up(conf.ow, block) % 8 == 0; |
82 | int mem_clocks = ic_chunks * (16 - use_slm * 6) |
83 | + block / 2 * (ic_chunks + 1); |
84 | int compute_clocks = 32 * block * ic_chunks; |
85 | int num_threads = conf.nchunk * conf.mb * od * oh |
86 | * utils::div_up(ow, block); |
87 | return utils::div_up(num_threads, dev_info->hw_threads()) |
88 | * (compute_clocks + mem_clocks); |
89 | }; |
90 | auto clock_compare = [&](const int &block1, const int &block2) { |
91 | return approx_clocks(block1) < approx_clocks(block2); |
92 | }; |
93 | std::vector<int> sorted_blocks = {4, 8, 12, 16}; |
94 | std::sort(sorted_blocks.begin(), sorted_blocks.end(), clock_compare); |
95 | conf.sp_block = sorted_blocks[0]; |
96 | } |
97 | conf.src_data_type = src_mdw.data_type(); |
98 | conf.dst_data_type = dst_mdw.data_type(); |
99 | |
100 | const int ow_group = (utils::div_up(ow, conf.sp_block) % 8) ? 1 : 8; |
101 | |
102 | conf.sub_group_size = 8; |
103 | conf.lws_d[0] = conf.sub_group_size; |
104 | conf.lws_d[1] = ow_group; |
105 | conf.lws_d[2] = 1; |
106 | |
107 | const int num_sp_threads = utils::div_up(ow, conf.sp_block) * oh * od; |
108 | |
109 | conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]); |
110 | conf.gws_d[1] = utils::rnd_up(num_sp_threads, conf.lws_d[1]); |
111 | conf.gws_d[2] = (conf.mb_block == 32 ? 2 : 1) |
112 | * utils::div_up(conf.mb, conf.mb_block); |
113 | |
114 | conf.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
115 | |
116 | format_tag_t src_tag, dst_tag, wei_tag; |
117 | |
118 | if (conf.is_nhwc) { |
119 | src_tag = utils::pick(conf.ndims - 3, nwc, nhwc); |
120 | dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc); |
121 | } else { |
122 | if (conf.mb_block == 32) { |
123 | src_tag = utils::pick(conf.ndims - 3, NCw32n32c, NChw32n32c); |
124 | dst_tag = utils::pick(conf.ndims - 3, NCw32n32c, NChw32n32c); |
125 | } else { |
126 | src_tag = utils::pick(conf.ndims - 3, nCw32c, nChw32c); |
127 | dst_tag = utils::pick(conf.ndims - 3, nCw32c, nChw32c); |
128 | } |
129 | } |
130 | |
131 | wei_tag = conf.with_groups |
132 | ? utils::pick(conf.ndims - 3, gOIw4o8i8o4i, gOIhw4o8i8o4i) |
133 | : utils::pick(conf.ndims - 3, OIw4o8i8o4i, OIhw4o8i8o4i); |
134 | |
135 | conf.src_tag = src_mdw.format_kind() == format_kind::any |
136 | ? src_tag |
137 | : src_mdw.matches_one_of_tag(src_tag); |
138 | conf.wei_tag = weights_mdw.format_kind() == format_kind::any |
139 | ? wei_tag |
140 | : weights_mdw.matches_one_of_tag(wei_tag); |
141 | conf.dst_tag = dst_mdw.format_kind() == format_kind::any |
142 | ? dst_tag |
143 | : dst_mdw.matches_one_of_tag(dst_tag); |
144 | |
145 | if (conf.src_tag != src_tag || conf.wei_tag != wei_tag |
146 | || conf.dst_tag != dst_tag) |
147 | return status::unimplemented; |
148 | |
149 | return status::success; |
150 | } |
151 | |
152 | status_t xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_kernel_ctx( |
153 | compute::kernel_ctx_t &kernel_ctx) const { |
154 | kernel_ctx.define_int("G" , conf.ngroups); |
155 | kernel_ctx.define_int("MB" , conf.mb); |
156 | kernel_ctx.define_int("IC" , conf.ic_without_padding); |
157 | kernel_ctx.define_int("ID" , conf.id); |
158 | kernel_ctx.define_int("IH" , conf.ih); |
159 | kernel_ctx.define_int("IW" , conf.iw); |
160 | kernel_ctx.define_int("OC" , conf.oc_without_padding); |
161 | kernel_ctx.define_int("OD" , conf.od); |
162 | kernel_ctx.define_int("OH" , conf.oh); |
163 | kernel_ctx.define_int("OW" , conf.ow); |
164 | kernel_ctx.define_int("KD" , conf.kd); |
165 | kernel_ctx.define_int("KH" , conf.kh); |
166 | kernel_ctx.define_int("KW" , conf.kw); |
167 | kernel_ctx.define_int("SD" , conf.stride_d); |
168 | kernel_ctx.define_int("SH" , conf.stride_h); |
169 | kernel_ctx.define_int("SW" , conf.stride_w); |
170 | kernel_ctx.define_int("PD" , conf.f_pad); |
171 | kernel_ctx.define_int("PH" , conf.t_pad); |
172 | kernel_ctx.define_int("PW" , conf.l_pad); |
173 | |
174 | kernel_ctx.define_int("SP_BLOCK" , conf.sp_block); |
175 | kernel_ctx.define_int("MB_BLOCK" , conf.mb_block); |
176 | kernel_ctx.define_int("OC_BLOCK" , conf.oc_block); |
177 | kernel_ctx.define_int("IC_BLOCK" , conf.ic_block); |
178 | |
179 | kernel_ctx.define_int("WITH_BIAS" , conf.with_bias); |
180 | |
181 | def_attr_info( |
182 | kernel_ctx, conf.attr_info, attr()->post_ops_, &(dst_md()->dims)); |
183 | |
184 | kernel_ctx.define_int("SUB_GROUP_SIZE" , conf.sub_group_size); |
185 | |
186 | kernel_ctx.define_int("LWS_0" , conf.lws_d[0]); |
187 | kernel_ctx.define_int("LWS_1" , conf.lws_d[1]); |
188 | kernel_ctx.define_int("LWS_2" , conf.lws_d[2]); |
189 | |
190 | kernel_ctx.define_int("OC_NCHUNK" , utils::div_up(conf.oc, conf.oc_block)); |
191 | kernel_ctx.define_int("IC_NCHUNK" , utils::div_up(conf.ic, conf.ic_block)); |
192 | |
193 | kernel_ctx.define_int( |
194 | "INT8_WEI_SLM" , utils::div_up(conf.sp, conf.sp_block) % 8 == 0); |
195 | kernel_ctx.define_int("SP_TAIL" , |
196 | utils::div_up(conf.sp, conf.sp_block) % conf.lws_d[1] == 0); |
197 | kernel_ctx.define_int("OUT_SP_TAIL" , conf.sp % conf.sp_block); |
198 | |
199 | kernel_ctx.define_int("WEI_4O8I8O4I" , 1); |
200 | |
201 | kernel_ctx.set_data_type(conf.dst_data_type); |
202 | def_data_type(kernel_ctx, conf.src_data_type, "SRC" ); |
203 | def_data_type(kernel_ctx, conf.dst_data_type, "DST" ); |
204 | def_data_type(kernel_ctx, |
205 | conf.attr_info.sum_data_type == dnnl_data_type_undef |
206 | ? conf.dst_data_type |
207 | : conf.attr_info.sum_data_type, |
208 | "SUM" ); |
209 | kernel_ctx.add_option("-Dcl_intel_subgroups_char" ); |
210 | |
211 | return status::success; |
212 | } |
213 | |
214 | void xe_lp_x8s8x_1x1_convolution_fwd_t::pd_t::init_scratchpad() { |
215 | if (conf.attr_info.with_src_zpoints) { |
216 | size_t size = conf.ngroups * utils::rnd_up(conf.oc, 32); |
217 | |
218 | auto scratchpad = scratchpad_registry().registrar(); |
219 | scratchpad.book(memory_tracking::names::key_conv_wei_reduction, size, |
220 | types::data_type_size(data_type::s32), OCL_BUFFER_ALIGNMENT); |
221 | } |
222 | } |
223 | |
224 | status_t xe_lp_x8s8x_1x1_convolution_fwd_t::execute_forward( |
225 | const exec_ctx_t &ctx) const { |
226 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
227 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
228 | auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
229 | auto &oscales = CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES); |
230 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
231 | auto &src_zpoints |
232 | = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); |
233 | auto &dst_zpoints |
234 | = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); |
235 | |
236 | const auto &conf = pd()->conf; |
237 | |
238 | std::unique_ptr<memory_storage_t> temp_src_compensation; |
239 | if (conf.attr_info.with_src_zpoints) { |
240 | temp_src_compensation = ctx.get_scratchpad_grantor().get_memory_storage( |
241 | memory_tracking::names::key_conv_wei_reduction); |
242 | |
243 | compute::kernel_arg_list_t arg_list; |
244 | arg_list.set(0, src_zpoints); |
245 | arg_list.set(1, weights); |
246 | arg_list.set(2, *temp_src_compensation); |
247 | |
248 | auto nd_range = compute::nd_range_t( |
249 | {8, utils::div_up(conf.oc, 32), conf.ngroups}, {8, 1, 1}); |
250 | status_t status = parallel_for( |
251 | ctx, nd_range, src_compensation_kernel_, arg_list); |
252 | if (status != status::success) return status::runtime_error; |
253 | } |
254 | |
255 | compute::kernel_arg_list_t arg_list; |
256 | arg_list.set(0, src); |
257 | arg_list.set(1, weights); |
258 | arg_list.set(2, bias); |
259 | arg_list.set(3, dst); |
260 | |
261 | unsigned arg_idx = append_post_ops_to_arg_list( |
262 | ctx, arg_list, 4, pd()->attr()->post_ops_); |
263 | |
264 | if (conf.attr_info.with_common_oscales |
265 | || conf.attr_info.with_per_oc_oscales) { |
266 | arg_list.set(arg_idx++, oscales); |
267 | } else { |
268 | arg_list.set(arg_idx++, memory_storage_t::empty_storage()); |
269 | } |
270 | |
271 | if (conf.attr_info.with_src_zpoints) |
272 | arg_list.set(arg_idx++, *temp_src_compensation); |
273 | else |
274 | arg_list.set(arg_idx++, memory_storage_t::empty_storage()); |
275 | |
276 | if (conf.attr_info.with_dst_zpoints) |
277 | arg_list.set(arg_idx++, dst_zpoints); |
278 | else |
279 | arg_list.set(arg_idx++, memory_storage_t::empty_storage()); |
280 | |
281 | auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d); |
282 | status_t status = parallel_for(ctx, nd_range, kernel_, arg_list); |
283 | |
284 | if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) { |
285 | ctx.zero_pad_output(DNNL_ARG_DST); |
286 | } |
287 | |
288 | return status; |
289 | } |
290 | |
291 | } // namespace ocl |
292 | } // namespace gpu |
293 | } // namespace impl |
294 | } // namespace dnnl |
295 | |