1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/xe_lp_x8s8x_convolution.hpp" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_traits.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | namespace ocl { |
27 | |
28 | bool is_nhwc(const memory_desc_wrapper &src_mdw, |
29 | const memory_desc_wrapper &dst_mdw) { |
30 | using namespace format_tag; |
31 | const bool is_src_nhwc |
32 | = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef; |
33 | const bool is_dst_nhwc |
34 | = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef; |
35 | const bool is_nhwc = is_src_nhwc || is_dst_nhwc; |
36 | return is_nhwc; |
37 | } |
38 | |
// Selects blocking sizes, kernel version, work-group geometry (gws/lws) and
// memory format tags for the forward int8 convolution kernel. Returns
// status::unimplemented when the requested problem/layout combination is not
// supported by this implementation.
status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_conf() {
    using namespace format_tag;

    const memory_desc_t *src = src_md();
    const memory_desc_t *dst = dst_md();
    const memory_desc_t *wei = weights_md();
    const memory_desc_t *bia = weights_md(1);

    memory_desc_t r_src, r_wei, r_dst;

    int ndims = src_md()->ndims;

    // XXX: try reduce number of spatial dims when iw/ow/kw=1,
    // memory tags will be selected based on the number of input dimensions
    bool use_reshaped_mem = ndims > 3;
    // Reshaping must succeed for all three descriptors; any failure falls
    // back to the original (non-reshaped) descriptors.
    if (memory_desc_reshape(r_src, *src, src->ndims - 1, src->dims)
            != status::success)
        use_reshaped_mem = false;
    if (memory_desc_reshape(r_dst, *dst, dst->ndims - 1, dst->dims)
            != status::success)
        use_reshaped_mem = false;
    if (memory_desc_reshape(r_wei, *wei, wei->ndims - 1, wei->dims)
            != status::success)
        use_reshaped_mem = false;

    if (use_reshaped_mem) {
        src = &r_src;
        dst = &r_dst;
        wei = &r_wei;
    }

    const convolution_desc_t &cd = *desc();
    const memory_desc_wrapper src_mdw(src);
    const memory_desc_wrapper weights_mdw(wei);
    const memory_desc_wrapper dst_mdw(dst);

    set_default_conf(conf, cd, *src, *wei, *dst, *bia, *attr());

    // "First convolution" here means few input channels (<= 4) and not
    // depthwise; it gets a dedicated blocking scheme below.
    const bool is_1stconv = conf.ic_without_padding <= 4 && !conf.is_depthwise;

    conf.is_nhwc = is_nhwc(src_mdw, dst_mdw);
    conf.is_dst_nhwc
            = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
    // TODO: Add group convolution support in NHWC kernel.
    if (!conf.is_depthwise && conf.with_groups && conf.ngroups > 1
            && (conf.oc % 32 != 0 || conf.ic % 32 != 0))
        return status::unimplemented;

    conf.dst_data_type = dst_mdw.data_type();
    conf.src_data_type = src_mdw.data_type();

    // Default blocking; refined per-version below.
    conf.oc_block = 32;
    conf.ic_block = 32;
    conf.mb_block = 1;
    conf.ow_block = 1;

    if (conf.is_nhwc) {
        conf.ver = ver_nhwc;
        if (conf.is_depthwise) {
            // Small-filter, small-stride, no-dilation cases use ow-blocking;
            // everything else switches to minibatch blocking.
            if (!(conf.kw <= 4 && conf.stride_w <= 2 && conf.dilate_w == 0
                        && conf.l_pad < 4)) {
                conf.mb_block = 32;
            } else {
                int off = conf.kw == 4 ? 1 : 0;
                if (conf.ow < 15 - off) {
                    conf.ow_block = conf.ow;
                } else {
                    // Search for a divisor of ow (allowing a little padding)
                    // that keeps the ow block reasonably large.
                    for (int i = 0; i < 7; ++i) {
                        conf.ow_block = utils::max_div(conf.ow + i, 14 - off);
                        if (conf.ow_block > 4) break;
                    }
                }
            }

            int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);

            conf.sub_group_size = 16;

            conf.lws_d[0] = 16;
            conf.lws_d[1] = 1;
            conf.lws_d[2] = 1;

            // gws[0]: group chunks (32 groups per chunk) x subgroup size;
            // gws[1]: spatial; gws[2]: minibatch chunks.
            conf.gws_d[0] = utils::div_up(conf.ngroups, 32) * conf.lws_d[0];
            conf.gws_d[1] = conf.od * conf.oh * ow_nchunk;
            conf.gws_d[2] = utils::div_up(conf.mb,
                    utils::div_up(conf.mb_block, conf.mb_block == 32 ? 4 : 1));
        } else {
            if (!is_1stconv) {
                // Use smaller ow blocks for small problems to expose more
                // parallelism; larger blocks otherwise.
                conf.ow_block
                        = (conf.mb * conf.oc * conf.oh * conf.ow < 49 * 1024)
                        ? 4
                        : 8;
            } else { // 1st conv
                conf.ic_block = 4;
                conf.ow_block = (conf.kw * conf.kh <= 49 && conf.ow % 16 < 8)
                        ? 16
                        : 12;
                if (conf.mb == 8 || conf.mb % 16 == 0) { conf.mb_block = 32; }
            }

            // Split up to max_subgroups subgroups between oc and ow chunks.
            int max_oc = 4;
            int oc_group = utils::max_div(
                    utils::div_up(conf.oc, conf.oc_block), max_oc);
            int max_subgroups = 32;
            int max_ow_group = max_subgroups / oc_group;
            int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
            int ow_group = utils::max_div(ow_nchunk, max_ow_group);

            conf.sub_group_size = 8;
            conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);
            // SLM elements needed to stage the src row consumed by one
            // work-group (strided ow span plus the filter overlap).
            conf.src_slm_size = conf.ic_block / 4
                    * (ow_group * conf.stride_w * conf.ow_block
                            + (conf.kw - 1) * (1 + conf.dilate_w));

            conf.lws_d[0] = 8 * oc_group;
            conf.lws_d[1] = ow_group;
            conf.lws_d[2] = 1;

            conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
            conf.gws_d[1]
                    = conf.od * conf.oh * utils::rnd_up(ow_nchunk, ow_group);
            conf.gws_d[2] = is_1stconv
                    ? (conf.mb <= 16 && dst_mdw.is_plain())
                            ? conf.mb
                            : utils::rnd_up(conf.mb, conf.mb_block)
                    : utils::div_up(conf.mb,
                            utils::div_up(conf.mb_block,
                                    conf.mb_block == 32 ? 2 : 1));
        }

    } else if (conf.is_depthwise) {
        // Blocked-layout depthwise: choose mb-block vs ow-block version.
        if (conf.mb == 8 || conf.mb % 16 == 0
                || !(conf.kw <= 4 && conf.stride_w <= 2 && conf.dilate_w == 0
                        && conf.l_pad < 4)) {
            conf.ver = ver_mb_block;
            conf.mb_block = 32;
        } else {
            conf.ver = ver_ow_block;
            int off = conf.kw == 4 ? 1 : 0;
            // Try to do not use ow blocks of size > 10 as there is
            // a lot of GRF memory used what leads to spills
            if (conf.ow < 10 - off) {
                conf.ow_block = conf.ow;
            } else {
                for (int i = 0; i < 7; ++i) {
                    conf.ow_block = utils::max_div(conf.ow + i, 10 - off);
                    if (conf.ow_block > 4) break;
                }
            }
        }

        conf.sub_group_size = 16;
        const int spatial_global_size
                = conf.od * conf.oh * utils::div_up(conf.ow, conf.ow_block);

        conf.lws_d[0] = 16;
        conf.lws_d[1] = 1;
        if (conf.ver == ver_mb_block) {
            // Try to increase WG size in order to improve caching
            for (const int pixels_per_wg : {2, 3, 5}) {
                if (spatial_global_size % pixels_per_wg == 0) {
                    conf.lws_d[1] = pixels_per_wg;
                    break;
                }
            }
        }
        conf.lws_d[2] = 1;

        conf.gws_d[0] = utils::div_up(conf.ngroups, 32) * conf.lws_d[0];
        conf.gws_d[1] = spatial_global_size;
        conf.gws_d[2] = (conf.mb_block == 32 ? 4 : 1)
                * utils::div_up(conf.mb, conf.mb_block);

    } else {
        // Blocked-layout regular convolution.
        if (conf.mb % 16 == 0) {
            conf.ver = ver_mb_block;
            conf.mb_block = 32;
        } else {
            conf.ver = ver_ow_block;
        }
        // Few input channels overrides the version choice above.
        if (conf.ic <= 4) conf.ver = ver_1stconv;

        int max_oc = 4;
        int oc_group
                = utils::max_div(utils::div_up(conf.oc, conf.oc_block), max_oc);
        int max_subgroups = 32;
        int max_ow_group = max_subgroups / oc_group;
        int ow_group = 1;
        int ow_nchunk = 1;

        conf.sub_group_size = 8;
        conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);

        switch (conf.ver) {
            case ver_mb_block:
                oc_group = 1;
                conf.ow_block = 1;
                ow_group = 1;
                break;
            case ver_ow_block:
                conf.ow_block
                        = (conf.mb * conf.oc * conf.oh * conf.ow < 49 * 1024)
                        ? 4
                        : 8;
                ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
                ow_group = utils::max_div(ow_nchunk, max_ow_group);
                break;
            case ver_1stconv:
                conf.ic_block = 4;
                conf.ow_block = (conf.kw * conf.kh <= 49 && conf.ow % 16 < 8)
                        ? 16
                        : 12;
                ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
                ow_group = utils::max_div(ow_nchunk, max_ow_group);
                // Allow one chunk of padding if ow_nchunk has no divisor.
                if (ow_group == 1)
                    ow_group = utils::max_div(ow_nchunk + 1, max_ow_group);
                break;
        }

        // NOTE(review): this src_slm_size value is dead — it is
        // unconditionally recomputed (with an extra l_pad term) a few lines
        // below, after lws_d is set. Consider removing this first assignment.
        conf.src_slm_size = conf.ic_block / 4
                * (ow_group * conf.stride_w * conf.ow_block
                        + (conf.kw - 1) * (1 + conf.dilate_w));

        conf.lws_d[0] = 8 * oc_group;
        conf.lws_d[1] = ow_group;
        conf.lws_d[2] = 1;

        // Effective SLM size: staged src span per work-group, including the
        // left padding region.
        conf.src_slm_size = conf.ic_block / 4
                * (conf.lws_d[1] * conf.stride_w * conf.ow_block
                        + (conf.kw - 1) * (1 + conf.dilate_w) + conf.l_pad);

        conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
        conf.gws_d[1] = conf.od * conf.oh
                * utils::rnd_up(
                        utils::div_up(conf.ow, conf.ow_block), ow_group);
        conf.gws_d[2] = (conf.mb_block == 32 ? 2 : 1)
                * utils::div_up(conf.mb, conf.mb_block);

        if (conf.ver == ver_1stconv) {
            conf.gws_d[2] = utils::rnd_up(conf.mb, conf.mb_block);
            // Save opportunity to use this implementation with nchw formats,
            // which will result in worse performance, but prevent us using reorder.
            // That can be efficient in some cases.
            conf.is_nchw = src_mdw.matches_one_of_tag(ncw, nchw, ncdhw)
                    || src_mdw.format_kind() == format_kind::any;
            // decrease src ic_block in case of input nchw
            if (conf.is_nchw) conf.ic_block = 1;
        }
    }

    // TODO: add support for nhwc and dw ow_block
    // Zero-point compensation is only implemented for the blocked layouts
    // and, for depthwise, only for the mb-block version.
    const bool has_compensation = conf.attr_info.with_src_zpoints
            || conf.attr_info.with_dst_zpoints;
    if (has_compensation)
        if (conf.is_nhwc || (conf.is_depthwise && conf.mb_block != 32))
            return status::unimplemented;

    conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    format_tag_t src_tag, dst_tag, wei_tag;

    // Pick the expected memory format tags for the chosen kernel version.
    if (conf.is_nhwc) {
        src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
        dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);

        if (is_1stconv) {
            wei_tag = conf.with_groups
                    ? utils::pick(ndims - 3, gOIw8o4i, gOIhw8o4i, gOIdhw8o4i)
                    : utils::pick(ndims - 3, OIw8o4i, OIhw8o4i, OIdhw8o4i);
            // First convolution may write a blocked dst even with nhwc src.
            if (!conf.is_dst_nhwc) {
                if (conf.mb_block == 32) {
                    dst_tag = utils::pick(
                            ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
                } else {
                    dst_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
                }
            }
        } else if (conf.is_depthwise) {
            wei_tag = utils::pick(ndims - 3, Goiw32g, Goihw32g, Goidhw32g);
        } else {
            wei_tag = conf.with_groups ? utils::pick(ndims - 3, gOIw4o8i8o4i,
                              gOIhw4o8i8o4i, gOIdhw4o8i8o4i)
                                       : utils::pick(ndims - 3, OIw4o8i8o4i,
                                               OIhw4o8i8o4i, OIdhw4o8i8o4i);
        }

    } else {
        if (conf.mb_block == 32) {
            src_tag = utils::pick(
                    ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
            dst_tag = utils::pick(
                    ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
        } else {
            src_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
            dst_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
        }

        if (!conf.is_depthwise && conf.ver == ver_1stconv) {
            src_tag = (conf.is_nchw)
                    ? utils::pick(ndims - 3, ncw, nchw, ncdhw)
                    : utils::pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
        }

        if (conf.is_depthwise) {
            wei_tag = utils::pick(ndims - 3, Goiw32g, Goihw32g, Goidhw32g);
        } else {
            if (conf.ver == ver_1stconv) {
                wei_tag = conf.with_groups
                        ? utils::pick(
                                ndims - 3, gOIw8o4i, gOIhw8o4i, gOIdhw8o4i)
                        : utils::pick(ndims - 3, OIw8o4i, OIhw8o4i, OIdhw8o4i);
            } else {
                wei_tag = conf.with_groups ? utils::pick(ndims - 3,
                                  gOIw4o8i8o4i, gOIhw4o8i8o4i, gOIdhw4o8i8o4i)
                                           : utils::pick(ndims - 3, OIw4o8i8o4i,
                                                   OIhw4o8i8o4i, OIdhw4o8i8o4i);
            }
        }
    }

    // For "any" formats, impose the chosen tag; otherwise require an exact
    // match with the user-provided layout.
    conf.src_tag = src_mdw.format_kind() == format_kind::any
            ? src_tag
            : src_mdw.matches_one_of_tag(src_tag);
    conf.wei_tag = weights_mdw.format_kind() == format_kind::any
            ? wei_tag
            : weights_mdw.matches_one_of_tag(wei_tag);
    conf.dst_tag = dst_mdw.format_kind() == format_kind::any
            ? dst_tag
            : dst_mdw.matches_one_of_tag(dst_tag);

    if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
            || conf.dst_tag != dst_tag)
        return status::unimplemented;

    return status::success;
}
375 | |
// Propagates the configuration chosen by init_conf() to the OpenCL kernel as
// preprocessor macros (problem sizes, blocking, work-group geometry, data
// types and attribute/post-op info).
status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    // Width of the "extended" output range covering input + both pads.
    int owx = nstl::max(
            1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w));
    int ow_block_with_stride = conf.stride_w * conf.ow_block;
    int iw_with_l_pad = conf.iw + conf.l_pad;
    // Remaining input width after full strided ow blocks are consumed.
    int iw_len = iw_with_l_pad < ow_block_with_stride + conf.kw - 1
            ? iw_with_l_pad - ow_block_with_stride
            : iw_with_l_pad % ow_block_with_stride;
    int iw_tail
            = iw_len < (conf.kw - 1) ? ow_block_with_stride + iw_len : iw_len;
    int ow_tail = conf.ow % conf.ow_block;
    int iw_nchunk = utils::div_up(conf.iw, ow_block_with_stride);
    int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
    int min_w_nchunk = nstl::min(ow_nchunk, iw_nchunk);
    // Input columns left for the last SLM-staged chunk.
    int slm_tail
            = conf.iw - (conf.stride_w * conf.ow_block * (min_w_nchunk - 1));
    // Trailing region (beyond iw) that must be zero-filled when staging.
    int zero_tail = utils::rnd_up(conf.ow, conf.ow_block) * conf.stride_w
            - conf.iw + (conf.kw - 1) * (1 + conf.dilate_w) - conf.l_pad;

    // Layout flags.
    kernel_ctx.define_int("NCHW", conf.is_nchw);
    kernel_ctx.define_int("DST_NHWC", conf.is_dst_nhwc);
    // Problem geometry: groups, minibatch, channels, spatial sizes, filter,
    // strides, paddings, dilations.
    kernel_ctx.define_int("G", conf.ngroups);
    kernel_ctx.define_int("MB", conf.mb);
    kernel_ctx.define_int("IC", conf.ic);
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("OC", conf.oc);
    kernel_ctx.define_int("OD", conf.od);
    kernel_ctx.define_int("OH", conf.oh);
    kernel_ctx.define_int("OW", conf.ow);
    kernel_ctx.define_int("KD", conf.kd);
    kernel_ctx.define_int("KH", conf.kh);
    kernel_ctx.define_int("KW", conf.kw);
    kernel_ctx.define_int("SD", conf.stride_d);
    kernel_ctx.define_int("SH", conf.stride_h);
    kernel_ctx.define_int("SW", conf.stride_w);
    kernel_ctx.define_int("PD", conf.f_pad);
    kernel_ctx.define_int("PH", conf.t_pad);
    kernel_ctx.define_int("PW", conf.l_pad);
    kernel_ctx.define_int("DD", conf.dilate_d);
    kernel_ctx.define_int("DH", conf.dilate_h);
    kernel_ctx.define_int("DW", conf.dilate_w);

    // NOTE(review): OW_PADDED, OWX and OWB are each defined twice in this
    // function (here and again below) with equal values; the duplicates are
    // redundant but harmless. Consider deduplicating.
    kernel_ctx.define_int("OW_PADDED",
            utils::rnd_up(
                    utils::div_up(conf.ow, conf.ow_block), conf.lws_d[1]));
    int ow = nstl::max(
            1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w));
    kernel_ctx.define_int("OWX", ow);
    kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));

    // Blocking / work-group configuration.
    kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
    kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
    kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
    kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
    kernel_ctx.define_int("SRC_SLM_SIZE", conf.src_slm_size);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
    kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
    kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
    kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
    kernel_ctx.define_int("LWS_2", conf.lws_d[2]);

    // Tail handling for partial blocks at the right edge.
    kernel_ctx.define_int("OW_TAIL", ow_tail);
    kernel_ctx.define_int("IW_TAIL", iw_tail);
    kernel_ctx.define_int("SLM_TAIL", slm_tail);
    kernel_ctx.define_int("ZERO_TAIL", zero_tail);

    kernel_ctx.define_int("OW_PADDED", utils::rnd_up(ow_nchunk, conf.lws_d[1]));
    kernel_ctx.define_int("G_PADDED",
            utils::div_up(conf.ngroups, conf.oc_block) * conf.oc_block);

    kernel_ctx.define_int("MB_GROUP", 1);
    kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);
    kernel_ctx.define_int("OC_GROUP", utils::div_up(conf.lws_d[0], 8));

    kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
    kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));
    kernel_ctx.define_int("OW_NCHUNK", ow_nchunk);
    kernel_ctx.define_int("SLM_NCHUNK", min_w_nchunk);
    kernel_ctx.define_int("OWB", ow_nchunk);
    kernel_ctx.define_int("OWX", owx);

    // Weights layout flag consumed by the kernel.
    if (conf.is_depthwise)
        kernel_ctx.define_int("WEI_32G", 1);
    else
        kernel_ctx.define_int("WEI_4O8I8O4I", 1);

    kernel_ctx.set_data_type(conf.dst_data_type);

    def_data_type(kernel_ctx, conf.src_data_type, "SRC");
    def_data_type(kernel_ctx, conf.dst_data_type, "DST");
    // The sum post-op accumulates in its own data type when one is given,
    // otherwise in the dst data type.
    def_data_type(kernel_ctx,
            conf.attr_info.sum_data_type == dnnl_data_type_undef
                    ? conf.dst_data_type
                    : conf.attr_info.sum_data_type,
            "SUM");

    def_attr_info(
            kernel_ctx, conf.attr_info, attr()->post_ops_, &(dst_md()->dims));

    // Subgroup extensions used by the kernel source.
    kernel_ctx.add_option("-Dcl_intel_subgroups_char");
    kernel_ctx.add_option("-Dcl_intel_subgroups_long");

    return status::success;
}
483 | |
484 | void xe_lp_x8s8x_convolution_fwd_t::pd_t::init_scratchpad() { |
485 | if (conf.attr_info.with_src_zpoints) { |
486 | size_t size = conf.is_depthwise |
487 | ? utils::rnd_up(conf.ngroups, 32) |
488 | : conf.ngroups * utils::rnd_up(conf.oc, 32); |
489 | |
490 | auto scratchpad = scratchpad_registry().registrar(); |
491 | scratchpad.book(memory_tracking::names::key_conv_wei_reduction, size, |
492 | types::data_type_size(data_type::s32), OCL_BUFFER_ALIGNMENT); |
493 | } |
494 | } |
495 | |
// Dispatches the forward convolution: optionally precomputes the src
// zero-point compensation into scratchpad, then launches the main kernel.
// The argument-list slot order must match the OpenCL kernel signature.
status_t xe_lp_x8s8x_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
    auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
    auto &oscales = CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
    auto &src_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
    auto &dst_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);

    const auto &conf = pd()->conf;

    // XXX: first convolution calculates compensation in-place
    const bool precompute_compensation = conf.is_depthwise || conf.ic > 4;

    std::unique_ptr<memory_storage_t> temp_src_compensation;
    if (conf.attr_info.with_src_zpoints && precompute_compensation) {
        // Reduce weights against the src zero-points into a scratchpad
        // buffer (booked in init_scratchpad) before the main kernel runs.
        temp_src_compensation = ctx.get_scratchpad_grantor().get_memory_storage(
                memory_tracking::names::key_conv_wei_reduction);

        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, src_zpoints);
        arg_list.set(1, weights);
        arg_list.set(2, *temp_src_compensation);

        // Depthwise: one subgroup per 32 groups; regular: one subgroup per
        // 32 output channels per group.
        auto nd_range = conf.is_depthwise
                ? compute::nd_range_t(
                        {16, utils::div_up(conf.ngroups, 32), 1}, {16, 1, 1})
                : compute::nd_range_t(
                        {8, utils::div_up(conf.oc, 32), conf.ngroups},
                        {8, 1, 1});
        status_t status = parallel_for(
                ctx, nd_range, src_compensation_kernel_, arg_list);
        if (status != status::success) return status::runtime_error;
    }

    // Fixed leading slots: src, weights, bias, dst.
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, weights);
    arg_list.set(2, bias);
    arg_list.set(3, dst);

    // Post-op arguments occupy slots starting at 4; arg_idx is the next
    // free slot afterwards.
    unsigned arg_idx = append_post_ops_to_arg_list(
            ctx, arg_list, 4, pd()->attr()->post_ops_);

    // Output scales slot (empty storage when not used).
    if (conf.attr_info.with_common_oscales
            || conf.attr_info.with_per_oc_oscales) {
        arg_list.set(arg_idx++, oscales);
    } else {
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());
    }

    // Two src zero-point slots: compensation buffer, then raw zero-points.
    if (conf.attr_info.with_src_zpoints) {
        if (precompute_compensation)
            arg_list.set(arg_idx++, *temp_src_compensation);
        else
            arg_list.set(arg_idx++, memory_storage_t::empty_storage());
        arg_list.set(arg_idx++, src_zpoints);
    } else {
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());
    }

    if (conf.attr_info.with_dst_zpoints)
        arg_list.set(arg_idx++, dst_zpoints);
    else
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());

    auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
    status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);

    // Post-ops that can produce non-zero values in padded areas require the
    // dst padding to be re-zeroed.
    if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) {
        ctx.zero_pad_output(DNNL_ARG_DST);
    }
    return status;
}
574 | |
// Selects blocking, kernel version, work-group geometry and memory format
// tags for the backward-data int8 convolution kernel. Returns
// status::unimplemented for unsupported layout/group combinations.
status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_conf() {
    using namespace format_tag;

    const convolution_desc_t &cd = *desc();
    const memory_desc_wrapper src_mdw(diff_src_md());
    const memory_desc_wrapper weights_mdw(weights_md());
    const memory_desc_wrapper dst_mdw(diff_dst_md());
    const memory_desc_wrapper bias_mdw(weights_md(1));

    set_default_conf(conf, cd, *diff_src_md(), *weights_md(), *diff_dst_md(),
            *weights_md(1), *attr());

    conf.is_nhwc = is_nhwc(src_mdw, dst_mdw);

    // Grouped convolution requires channels divisible by the 32-wide blocks.
    if (conf.with_groups && conf.ngroups > 1
            && (conf.oc % 32 != 0 || conf.ic % 32 != 0))
        return status::unimplemented;

    // NOTE(review): conf.ver is only assigned on the non-nhwc path; the nhwc
    // path relies on the "|| conf.is_nhwc" checks below and on conf's
    // default-initialized ver — confirm this is intentional.
    if (!conf.is_nhwc) {
        if (conf.mb % 16 == 0) {
            conf.ver = ver_mb_block;
        } else {
            conf.ver = ver_ow_block;
        }
    }

    conf.oc_block = 32;
    conf.ic_block = 32;
    conf.iw_block = 1;

    conf.sub_group_size = 8;
    // Number of 32-wide input-channel chunks across all groups.
    conf.nchunk = utils::div_up(conf.ic * conf.ngroups, conf.ic_block);
    int ic_group = nstl::min(conf.nchunk, 2);

    if (conf.ver == ver_ow_block || conf.is_nhwc) {
        conf.mb_block = 1;
        // Split up to max_subgroups subgroups between ic and iw chunks.
        int max_ic = 4;
        ic_group
                = utils::max_div(utils::div_up(conf.ic, conf.ic_block), max_ic);
        int max_subgroups = 32;
        int max_iw_group = max_subgroups / ic_group;
        // Smaller iw blocks for small problems to expose more parallelism.
        conf.iw_block
                = (conf.mb * conf.ic * conf.ih * conf.iw < 49 * 1024) ? 4 : 8;
        int iw_nchunk = utils::div_up(conf.iw, conf.iw_block);
        int iw_group = utils::max_div(iw_nchunk, max_iw_group);

        //an upper bound on the number of elems per subgroup
        conf.dst_slm_size = (conf.oc_block / 4)
                * ((iw_group * conf.iw_block)
                        + (conf.kw - 1) * (1 + conf.dilate_w));
        conf.iw_tail = conf.iw % conf.iw_block;

        conf.lws_d[0] = 8 * ic_group;
        conf.lws_d[1] = iw_group;
        conf.lws_d[2] = 1;

        conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
        conf.gws_d[1] = conf.id * conf.ih * iw_nchunk;
        conf.gws_d[2] = utils::div_up(conf.mb, utils::div_up(conf.mb_block, 2));
    } else { //ver_mb_block
        conf.mb_block = 32;
        conf.lws_d[0] = 8 * ic_group;
        conf.lws_d[1] = 8;
        conf.lws_d[2] = 1;

        conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
        conf.gws_d[1]
                = conf.id * conf.ih * utils::rnd_up(conf.iw, conf.lws_d[1]);
        conf.gws_d[2] = 2 * utils::div_up(conf.mb, conf.mb_block);
    }
    conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    format_tag_t src_tag, dst_tag, wei_tag;

    // Expected layouts: channels-last for nhwc, otherwise 32-blocked (with
    // additional 32-minibatch blocking for the mb-block version).
    if (conf.is_nhwc) {
        src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
        dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
    } else {
        src_tag = (conf.ver == ver_ow_block)
                ? utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c)
                : utils::pick(
                        conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
        dst_tag = (conf.ver == ver_ow_block)
                ? utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c)
                : utils::pick(
                        conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
    }

    // Backward-data uses IO-transposed weights blocking.
    wei_tag = conf.with_groups ? utils::pick(conf.ndims - 3, gIOw4i8o8i4o,
                      gIOhw4i8o8i4o, gIOdhw4i8o8i4o)
                               : utils::pick(conf.ndims - 3, IOw4i8o8i4o,
                                       IOhw4i8o8i4o, IOdhw4i8o8i4o);

    conf.dst_data_type = dst_mdw.data_type();
    conf.src_data_type = src_mdw.data_type();

    // For "any" formats, impose the chosen tag; otherwise require an exact
    // match with the user-provided layout.
    conf.src_tag = src_mdw.format_kind() == format_kind::any
            ? src_tag
            : src_mdw.matches_one_of_tag(src_tag);
    conf.wei_tag = weights_mdw.format_kind() == format_kind::any
            ? wei_tag
            : weights_mdw.matches_one_of_tag(wei_tag);
    conf.dst_tag = dst_mdw.format_kind() == format_kind::any
            ? dst_tag
            : dst_mdw.matches_one_of_tag(dst_tag);

    if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
            || conf.dst_tag != dst_tag)
        return status::unimplemented;

    return status::success;
}
687 | |
// Propagates the backward-data configuration chosen by init_conf() to the
// OpenCL kernel as preprocessor macros.
status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    // Problem geometry: groups, minibatch, channels, spatial sizes, filter,
    // strides, paddings, dilations.
    kernel_ctx.define_int("G", conf.ngroups);
    kernel_ctx.define_int("MB", conf.mb);
    kernel_ctx.define_int("IC", conf.ic);
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("OC", conf.oc);
    kernel_ctx.define_int("OD", conf.od);
    kernel_ctx.define_int("OH", conf.oh);
    kernel_ctx.define_int("OW", conf.ow);
    kernel_ctx.define_int("KD", conf.kd);
    kernel_ctx.define_int("KH", conf.kh);
    kernel_ctx.define_int("KW", conf.kw);
    kernel_ctx.define_int("SD", conf.stride_d);
    kernel_ctx.define_int("SH", conf.stride_h);
    kernel_ctx.define_int("SW", conf.stride_w);
    kernel_ctx.define_int("PD", conf.f_pad);
    kernel_ctx.define_int("PH", conf.t_pad);
    kernel_ctx.define_int("PW", conf.l_pad);
    kernel_ctx.define_int("DD", conf.dilate_d);
    kernel_ctx.define_int("DH", conf.dilate_h);
    kernel_ctx.define_int("DW", conf.dilate_w);

    // Input-width padding/tail for partial iw blocks.
    kernel_ctx.define_int("IW_PADDED", utils::rnd_up(conf.iw, conf.lws_d[1]));
    kernel_ctx.define_int("IW_TAIL", conf.iw_tail);

    // Blocking configuration.
    kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
    kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
    kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
    kernel_ctx.define_int("IW_BLOCK", conf.iw_block);

    // Work-group decomposition.
    kernel_ctx.define_int("MB_GROUP", 1);
    kernel_ctx.define_int("IC_GROUP", utils::div_up(conf.lws_d[0], 8));
    kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);

    kernel_ctx.define_int("IW_NCHUNK", utils::div_up(conf.iw, conf.iw_block));
    kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
    kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));

    kernel_ctx.define_int("DST_SLM_SIZE", conf.dst_slm_size);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);

    kernel_ctx.define_int("WITH_BIAS", conf.with_bias);

    kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
    kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
    kernel_ctx.define_int("LWS_2", conf.lws_d[2]);

    kernel_ctx.define_int("IS_NHWC", conf.is_nhwc);

    // Data types and required subgroup extension.
    kernel_ctx.set_data_type(conf.dst_data_type);
    def_data_type(kernel_ctx, conf.src_data_type, "SRC");
    def_data_type(kernel_ctx, conf.dst_data_type, "DST");
    kernel_ctx.add_option("-Dcl_intel_subgroups_char");

    return status::success;
}
747 | |
748 | status_t xe_lp_x8s8x_convolution_bwd_data_t::execute_backward_data( |
749 | const exec_ctx_t &ctx) const { |
750 | |
751 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
752 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
753 | auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
754 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
755 | |
756 | const auto &conf = pd()->conf; |
757 | |
758 | compute::kernel_arg_list_t arg_list; |
759 | arg_list.set(0, diff_src); |
760 | arg_list.set(1, weights); |
761 | arg_list.set(2, bias); |
762 | arg_list.set(3, diff_dst); |
763 | |
764 | auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d); |
765 | status_t status = parallel_for(ctx, nd_range, kernel_, arg_list); |
766 | |
767 | return status; |
768 | } |
769 | |
770 | } // namespace ocl |
771 | } // namespace gpu |
772 | } // namespace impl |
773 | } // namespace dnnl |
774 | |