1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/xe_lp_x8s8x_convolution.hpp"
18
19#include "common/c_types_map.hpp"
20#include "common/dnnl_traits.hpp"
21#include "common/type_helpers.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace gpu {
26namespace ocl {
27
28bool is_nhwc(const memory_desc_wrapper &src_mdw,
29 const memory_desc_wrapper &dst_mdw) {
30 using namespace format_tag;
31 const bool is_src_nhwc
32 = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
33 const bool is_dst_nhwc
34 = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
35 const bool is_nhwc = is_src_nhwc || is_dst_nhwc;
36 return is_nhwc;
37}
38
39status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_conf() {
40 using namespace format_tag;
41
42 const memory_desc_t *src = src_md();
43 const memory_desc_t *dst = dst_md();
44 const memory_desc_t *wei = weights_md();
45 const memory_desc_t *bia = weights_md(1);
46
47 memory_desc_t r_src, r_wei, r_dst;
48
49 int ndims = src_md()->ndims;
50
51 // XXX: try reduce number of spatial dims when iw/ow/kw=1,
52 // memory tags will be selected based on the number of input dimensions
53 bool use_reshaped_mem = ndims > 3;
54 if (memory_desc_reshape(r_src, *src, src->ndims - 1, src->dims)
55 != status::success)
56 use_reshaped_mem = false;
57 if (memory_desc_reshape(r_dst, *dst, dst->ndims - 1, dst->dims)
58 != status::success)
59 use_reshaped_mem = false;
60 if (memory_desc_reshape(r_wei, *wei, wei->ndims - 1, wei->dims)
61 != status::success)
62 use_reshaped_mem = false;
63
64 if (use_reshaped_mem) {
65 src = &r_src;
66 dst = &r_dst;
67 wei = &r_wei;
68 }
69
70 const convolution_desc_t &cd = *desc();
71 const memory_desc_wrapper src_mdw(src);
72 const memory_desc_wrapper weights_mdw(wei);
73 const memory_desc_wrapper dst_mdw(dst);
74
75 set_default_conf(conf, cd, *src, *wei, *dst, *bia, *attr());
76
77 const bool is_1stconv = conf.ic_without_padding <= 4 && !conf.is_depthwise;
78
79 conf.is_nhwc = is_nhwc(src_mdw, dst_mdw);
80 conf.is_dst_nhwc
81 = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
82 // TODO: Add group convolution support in NHWC kernel.
83 if (!conf.is_depthwise && conf.with_groups && conf.ngroups > 1
84 && (conf.oc % 32 != 0 || conf.ic % 32 != 0))
85 return status::unimplemented;
86
87 conf.dst_data_type = dst_mdw.data_type();
88 conf.src_data_type = src_mdw.data_type();
89
90 conf.oc_block = 32;
91 conf.ic_block = 32;
92 conf.mb_block = 1;
93 conf.ow_block = 1;
94
95 if (conf.is_nhwc) {
96 conf.ver = ver_nhwc;
97 if (conf.is_depthwise) {
98 if (!(conf.kw <= 4 && conf.stride_w <= 2 && conf.dilate_w == 0
99 && conf.l_pad < 4)) {
100 conf.mb_block = 32;
101 } else {
102 int off = conf.kw == 4 ? 1 : 0;
103 if (conf.ow < 15 - off) {
104 conf.ow_block = conf.ow;
105 } else {
106 for (int i = 0; i < 7; ++i) {
107 conf.ow_block = utils::max_div(conf.ow + i, 14 - off);
108 if (conf.ow_block > 4) break;
109 }
110 }
111 }
112
113 int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
114
115 conf.sub_group_size = 16;
116
117 conf.lws_d[0] = 16;
118 conf.lws_d[1] = 1;
119 conf.lws_d[2] = 1;
120
121 conf.gws_d[0] = utils::div_up(conf.ngroups, 32) * conf.lws_d[0];
122 conf.gws_d[1] = conf.od * conf.oh * ow_nchunk;
123 conf.gws_d[2] = utils::div_up(conf.mb,
124 utils::div_up(conf.mb_block, conf.mb_block == 32 ? 4 : 1));
125 } else {
126 if (!is_1stconv) {
127 conf.ow_block
128 = (conf.mb * conf.oc * conf.oh * conf.ow < 49 * 1024)
129 ? 4
130 : 8;
131 } else { // 1st conv
132 conf.ic_block = 4;
133 conf.ow_block = (conf.kw * conf.kh <= 49 && conf.ow % 16 < 8)
134 ? 16
135 : 12;
136 if (conf.mb == 8 || conf.mb % 16 == 0) { conf.mb_block = 32; }
137 }
138
139 int max_oc = 4;
140 int oc_group = utils::max_div(
141 utils::div_up(conf.oc, conf.oc_block), max_oc);
142 int max_subgroups = 32;
143 int max_ow_group = max_subgroups / oc_group;
144 int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
145 int ow_group = utils::max_div(ow_nchunk, max_ow_group);
146
147 conf.sub_group_size = 8;
148 conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);
149 conf.src_slm_size = conf.ic_block / 4
150 * (ow_group * conf.stride_w * conf.ow_block
151 + (conf.kw - 1) * (1 + conf.dilate_w));
152
153 conf.lws_d[0] = 8 * oc_group;
154 conf.lws_d[1] = ow_group;
155 conf.lws_d[2] = 1;
156
157 conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
158 conf.gws_d[1]
159 = conf.od * conf.oh * utils::rnd_up(ow_nchunk, ow_group);
160 conf.gws_d[2] = is_1stconv
161 ? (conf.mb <= 16 && dst_mdw.is_plain())
162 ? conf.mb
163 : utils::rnd_up(conf.mb, conf.mb_block)
164 : utils::div_up(conf.mb,
165 utils::div_up(conf.mb_block,
166 conf.mb_block == 32 ? 2 : 1));
167 }
168
169 } else if (conf.is_depthwise) {
170 if (conf.mb == 8 || conf.mb % 16 == 0
171 || !(conf.kw <= 4 && conf.stride_w <= 2 && conf.dilate_w == 0
172 && conf.l_pad < 4)) {
173 conf.ver = ver_mb_block;
174 conf.mb_block = 32;
175 } else {
176 conf.ver = ver_ow_block;
177 int off = conf.kw == 4 ? 1 : 0;
178 // Try to do not use ow blocks of size > 10 as there is
179 // a lot of GRF memory used what leads to spills
180 if (conf.ow < 10 - off) {
181 conf.ow_block = conf.ow;
182 } else {
183 for (int i = 0; i < 7; ++i) {
184 conf.ow_block = utils::max_div(conf.ow + i, 10 - off);
185 if (conf.ow_block > 4) break;
186 }
187 }
188 }
189
190 conf.sub_group_size = 16;
191 const int spatial_global_size
192 = conf.od * conf.oh * utils::div_up(conf.ow, conf.ow_block);
193
194 conf.lws_d[0] = 16;
195 conf.lws_d[1] = 1;
196 if (conf.ver == ver_mb_block) {
197 // Try to increase WG size in order to improve caching
198 for (const int pixels_per_wg : {2, 3, 5}) {
199 if (spatial_global_size % pixels_per_wg == 0) {
200 conf.lws_d[1] = pixels_per_wg;
201 break;
202 }
203 }
204 }
205 conf.lws_d[2] = 1;
206
207 conf.gws_d[0] = utils::div_up(conf.ngroups, 32) * conf.lws_d[0];
208 conf.gws_d[1] = spatial_global_size;
209 conf.gws_d[2] = (conf.mb_block == 32 ? 4 : 1)
210 * utils::div_up(conf.mb, conf.mb_block);
211
212 } else {
213 if (conf.mb % 16 == 0) {
214 conf.ver = ver_mb_block;
215 conf.mb_block = 32;
216 } else {
217 conf.ver = ver_ow_block;
218 }
219 if (conf.ic <= 4) conf.ver = ver_1stconv;
220
221 int max_oc = 4;
222 int oc_group
223 = utils::max_div(utils::div_up(conf.oc, conf.oc_block), max_oc);
224 int max_subgroups = 32;
225 int max_ow_group = max_subgroups / oc_group;
226 int ow_group = 1;
227 int ow_nchunk = 1;
228
229 conf.sub_group_size = 8;
230 conf.nchunk = utils::div_up(conf.oc * conf.ngroups, conf.oc_block);
231
232 switch (conf.ver) {
233 case ver_mb_block:
234 oc_group = 1;
235 conf.ow_block = 1;
236 ow_group = 1;
237 break;
238 case ver_ow_block:
239 conf.ow_block
240 = (conf.mb * conf.oc * conf.oh * conf.ow < 49 * 1024)
241 ? 4
242 : 8;
243 ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
244 ow_group = utils::max_div(ow_nchunk, max_ow_group);
245 break;
246 case ver_1stconv:
247 conf.ic_block = 4;
248 conf.ow_block = (conf.kw * conf.kh <= 49 && conf.ow % 16 < 8)
249 ? 16
250 : 12;
251 ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
252 ow_group = utils::max_div(ow_nchunk, max_ow_group);
253 if (ow_group == 1)
254 ow_group = utils::max_div(ow_nchunk + 1, max_ow_group);
255 break;
256 }
257
258 conf.src_slm_size = conf.ic_block / 4
259 * (ow_group * conf.stride_w * conf.ow_block
260 + (conf.kw - 1) * (1 + conf.dilate_w));
261
262 conf.lws_d[0] = 8 * oc_group;
263 conf.lws_d[1] = ow_group;
264 conf.lws_d[2] = 1;
265
266 conf.src_slm_size = conf.ic_block / 4
267 * (conf.lws_d[1] * conf.stride_w * conf.ow_block
268 + (conf.kw - 1) * (1 + conf.dilate_w) + conf.l_pad);
269
270 conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
271 conf.gws_d[1] = conf.od * conf.oh
272 * utils::rnd_up(
273 utils::div_up(conf.ow, conf.ow_block), ow_group);
274 conf.gws_d[2] = (conf.mb_block == 32 ? 2 : 1)
275 * utils::div_up(conf.mb, conf.mb_block);
276
277 if (conf.ver == ver_1stconv) {
278 conf.gws_d[2] = utils::rnd_up(conf.mb, conf.mb_block);
279 // Save opportunity to use this implementation with nchw formats,
280 // which will result in worse performance, but prevent us using reorder.
281 // That can be efficient in some cases.
282 conf.is_nchw = src_mdw.matches_one_of_tag(ncw, nchw, ncdhw)
283 || src_mdw.format_kind() == format_kind::any;
284 // decrease src ic_block in case of input nchw
285 if (conf.is_nchw) conf.ic_block = 1;
286 }
287 }
288
289 // TODO: add support for nhwc and dw ow_block
290 const bool has_compensation = conf.attr_info.with_src_zpoints
291 || conf.attr_info.with_dst_zpoints;
292 if (has_compensation)
293 if (conf.is_nhwc || (conf.is_depthwise && conf.mb_block != 32))
294 return status::unimplemented;
295
296 conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;
297
298 format_tag_t src_tag, dst_tag, wei_tag;
299
300 if (conf.is_nhwc) {
301 src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
302 dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
303
304 if (is_1stconv) {
305 wei_tag = conf.with_groups
306 ? utils::pick(ndims - 3, gOIw8o4i, gOIhw8o4i, gOIdhw8o4i)
307 : utils::pick(ndims - 3, OIw8o4i, OIhw8o4i, OIdhw8o4i);
308 if (!conf.is_dst_nhwc) {
309 if (conf.mb_block == 32) {
310 dst_tag = utils::pick(
311 ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
312 } else {
313 dst_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
314 }
315 }
316 } else if (conf.is_depthwise) {
317 wei_tag = utils::pick(ndims - 3, Goiw32g, Goihw32g, Goidhw32g);
318 } else {
319 wei_tag = conf.with_groups ? utils::pick(ndims - 3, gOIw4o8i8o4i,
320 gOIhw4o8i8o4i, gOIdhw4o8i8o4i)
321 : utils::pick(ndims - 3, OIw4o8i8o4i,
322 OIhw4o8i8o4i, OIdhw4o8i8o4i);
323 }
324
325 } else {
326 if (conf.mb_block == 32) {
327 src_tag = utils::pick(
328 ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
329 dst_tag = utils::pick(
330 ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
331 } else {
332 src_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
333 dst_tag = utils::pick(ndims - 3, nCw32c, nChw32c, nCdhw32c);
334 }
335
336 if (!conf.is_depthwise && conf.ver == ver_1stconv) {
337 src_tag = (conf.is_nchw)
338 ? utils::pick(ndims - 3, ncw, nchw, ncdhw)
339 : utils::pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
340 }
341
342 if (conf.is_depthwise) {
343 wei_tag = utils::pick(ndims - 3, Goiw32g, Goihw32g, Goidhw32g);
344 } else {
345 if (conf.ver == ver_1stconv) {
346 wei_tag = conf.with_groups
347 ? utils::pick(
348 ndims - 3, gOIw8o4i, gOIhw8o4i, gOIdhw8o4i)
349 : utils::pick(ndims - 3, OIw8o4i, OIhw8o4i, OIdhw8o4i);
350 } else {
351 wei_tag = conf.with_groups ? utils::pick(ndims - 3,
352 gOIw4o8i8o4i, gOIhw4o8i8o4i, gOIdhw4o8i8o4i)
353 : utils::pick(ndims - 3, OIw4o8i8o4i,
354 OIhw4o8i8o4i, OIdhw4o8i8o4i);
355 }
356 }
357 }
358
359 conf.src_tag = src_mdw.format_kind() == format_kind::any
360 ? src_tag
361 : src_mdw.matches_one_of_tag(src_tag);
362 conf.wei_tag = weights_mdw.format_kind() == format_kind::any
363 ? wei_tag
364 : weights_mdw.matches_one_of_tag(wei_tag);
365 conf.dst_tag = dst_mdw.format_kind() == format_kind::any
366 ? dst_tag
367 : dst_mdw.matches_one_of_tag(dst_tag);
368
369 if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
370 || conf.dst_tag != dst_tag)
371 return status::unimplemented;
372
373 return status::success;
374}
375
376status_t xe_lp_x8s8x_convolution_fwd_t::pd_t::init_kernel_ctx(
377 compute::kernel_ctx_t &kernel_ctx) const {
378 int owx = nstl::max(
379 1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w));
380 int ow_block_with_stride = conf.stride_w * conf.ow_block;
381 int iw_with_l_pad = conf.iw + conf.l_pad;
382 int iw_len = iw_with_l_pad < ow_block_with_stride + conf.kw - 1
383 ? iw_with_l_pad - ow_block_with_stride
384 : iw_with_l_pad % ow_block_with_stride;
385 int iw_tail
386 = iw_len < (conf.kw - 1) ? ow_block_with_stride + iw_len : iw_len;
387 int ow_tail = conf.ow % conf.ow_block;
388 int iw_nchunk = utils::div_up(conf.iw, ow_block_with_stride);
389 int ow_nchunk = utils::div_up(conf.ow, conf.ow_block);
390 int min_w_nchunk = nstl::min(ow_nchunk, iw_nchunk);
391 int slm_tail
392 = conf.iw - (conf.stride_w * conf.ow_block * (min_w_nchunk - 1));
393 int zero_tail = utils::rnd_up(conf.ow, conf.ow_block) * conf.stride_w
394 - conf.iw + (conf.kw - 1) * (1 + conf.dilate_w) - conf.l_pad;
395
396 kernel_ctx.define_int("NCHW", conf.is_nchw);
397 kernel_ctx.define_int("DST_NHWC", conf.is_dst_nhwc);
398 kernel_ctx.define_int("G", conf.ngroups);
399 kernel_ctx.define_int("MB", conf.mb);
400 kernel_ctx.define_int("IC", conf.ic);
401 kernel_ctx.define_int("ID", conf.id);
402 kernel_ctx.define_int("IH", conf.ih);
403 kernel_ctx.define_int("IW", conf.iw);
404 kernel_ctx.define_int("OC", conf.oc);
405 kernel_ctx.define_int("OD", conf.od);
406 kernel_ctx.define_int("OH", conf.oh);
407 kernel_ctx.define_int("OW", conf.ow);
408 kernel_ctx.define_int("KD", conf.kd);
409 kernel_ctx.define_int("KH", conf.kh);
410 kernel_ctx.define_int("KW", conf.kw);
411 kernel_ctx.define_int("SD", conf.stride_d);
412 kernel_ctx.define_int("SH", conf.stride_h);
413 kernel_ctx.define_int("SW", conf.stride_w);
414 kernel_ctx.define_int("PD", conf.f_pad);
415 kernel_ctx.define_int("PH", conf.t_pad);
416 kernel_ctx.define_int("PW", conf.l_pad);
417 kernel_ctx.define_int("DD", conf.dilate_d);
418 kernel_ctx.define_int("DH", conf.dilate_h);
419 kernel_ctx.define_int("DW", conf.dilate_w);
420
421 kernel_ctx.define_int("OW_PADDED",
422 utils::rnd_up(
423 utils::div_up(conf.ow, conf.ow_block), conf.lws_d[1]));
424 int ow = nstl::max(
425 1, utils::div_up(conf.iw + 2 * conf.l_pad, conf.stride_w));
426 kernel_ctx.define_int("OWX", ow);
427 kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));
428
429 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
430 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
431 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
432 kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
433 kernel_ctx.define_int("SRC_SLM_SIZE", conf.src_slm_size);
434 kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
435 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
436 kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
437 kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
438 kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
439
440 kernel_ctx.define_int("OW_TAIL", ow_tail);
441 kernel_ctx.define_int("IW_TAIL", iw_tail);
442 kernel_ctx.define_int("SLM_TAIL", slm_tail);
443 kernel_ctx.define_int("ZERO_TAIL", zero_tail);
444
445 kernel_ctx.define_int("OW_PADDED", utils::rnd_up(ow_nchunk, conf.lws_d[1]));
446 kernel_ctx.define_int("G_PADDED",
447 utils::div_up(conf.ngroups, conf.oc_block) * conf.oc_block);
448
449 kernel_ctx.define_int("MB_GROUP", 1);
450 kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);
451 kernel_ctx.define_int("OC_GROUP", utils::div_up(conf.lws_d[0], 8));
452
453 kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
454 kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));
455 kernel_ctx.define_int("OW_NCHUNK", ow_nchunk);
456 kernel_ctx.define_int("SLM_NCHUNK", min_w_nchunk);
457 kernel_ctx.define_int("OWB", ow_nchunk);
458 kernel_ctx.define_int("OWX", owx);
459
460 if (conf.is_depthwise)
461 kernel_ctx.define_int("WEI_32G", 1);
462 else
463 kernel_ctx.define_int("WEI_4O8I8O4I", 1);
464
465 kernel_ctx.set_data_type(conf.dst_data_type);
466
467 def_data_type(kernel_ctx, conf.src_data_type, "SRC");
468 def_data_type(kernel_ctx, conf.dst_data_type, "DST");
469 def_data_type(kernel_ctx,
470 conf.attr_info.sum_data_type == dnnl_data_type_undef
471 ? conf.dst_data_type
472 : conf.attr_info.sum_data_type,
473 "SUM");
474
475 def_attr_info(
476 kernel_ctx, conf.attr_info, attr()->post_ops_, &(dst_md()->dims));
477
478 kernel_ctx.add_option("-Dcl_intel_subgroups_char");
479 kernel_ctx.add_option("-Dcl_intel_subgroups_long");
480
481 return status::success;
482}
483
484void xe_lp_x8s8x_convolution_fwd_t::pd_t::init_scratchpad() {
485 if (conf.attr_info.with_src_zpoints) {
486 size_t size = conf.is_depthwise
487 ? utils::rnd_up(conf.ngroups, 32)
488 : conf.ngroups * utils::rnd_up(conf.oc, 32);
489
490 auto scratchpad = scratchpad_registry().registrar();
491 scratchpad.book(memory_tracking::names::key_conv_wei_reduction, size,
492 types::data_type_size(data_type::s32), OCL_BUFFER_ALIGNMENT);
493 }
494}
495
// Runs the forward int8 convolution: optionally precomputes source
// zero-point compensation into scratchpad, then launches the main kernel.
// The kernel argument positions below must match the OpenCL kernel
// signature exactly, so unused slots are filled with empty storage.
status_t xe_lp_x8s8x_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
    auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
    auto &oscales = CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
    auto &src_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
    auto &dst_zpoints
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);

    const auto &conf = pd()->conf;

    // XXX: first convolution calculates compensation in-place
    const bool precompute_compensation = conf.is_depthwise || conf.ic > 4;

    std::unique_ptr<memory_storage_t> temp_src_compensation;
    if (conf.attr_info.with_src_zpoints && precompute_compensation) {
        // Reduce weights against the source zero points into the scratchpad
        // buffer booked by init_scratchpad().
        temp_src_compensation = ctx.get_scratchpad_grantor().get_memory_storage(
                memory_tracking::names::key_conv_wei_reduction);

        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, src_zpoints);
        arg_list.set(1, weights);
        arg_list.set(2, *temp_src_compensation);

        // Depthwise reduces per group (32 at a time); regular convolution
        // reduces per (oc block, group).
        auto nd_range = conf.is_depthwise
                ? compute::nd_range_t(
                        {16, utils::div_up(conf.ngroups, 32), 1}, {16, 1, 1})
                : compute::nd_range_t(
                        {8, utils::div_up(conf.oc, 32), conf.ngroups},
                        {8, 1, 1});
        status_t status = parallel_for(
                ctx, nd_range, src_compensation_kernel_, arg_list);
        if (status != status::success) return status::runtime_error;
    }

    // Fixed leading arguments of the convolution kernel.
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, weights);
    arg_list.set(2, bias);
    arg_list.set(3, dst);

    // Post-op arguments start at index 4; arg_idx is the next free slot.
    unsigned arg_idx = append_post_ops_to_arg_list(
            ctx, arg_list, 4, pd()->attr()->post_ops_);

    if (conf.attr_info.with_common_oscales
            || conf.attr_info.with_per_oc_oscales) {
        arg_list.set(arg_idx++, oscales);
    } else {
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());
    }

    // Source zero points occupy two slots: the compensation buffer and the
    // zero-point values themselves.
    if (conf.attr_info.with_src_zpoints) {
        if (precompute_compensation)
            arg_list.set(arg_idx++, *temp_src_compensation);
        else
            arg_list.set(arg_idx++, memory_storage_t::empty_storage());
        arg_list.set(arg_idx++, src_zpoints);
    } else {
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());
    }

    if (conf.attr_info.with_dst_zpoints)
        arg_list.set(arg_idx++, dst_zpoints);
    else
        arg_list.set(arg_idx++, memory_storage_t::empty_storage());

    auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
    status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);

    // Re-zero dst padding if any post-op could have written non-zero values
    // into it.
    if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) {
        ctx.zero_pad_output(DNNL_ARG_DST);
    }
    return status;
}
574
// Fills `conf` with blocking, work-group and memory-format decisions for the
// backward-data int8 convolution kernel. Returns status::unimplemented when
// the requested shapes or formats are unsupported.
status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_conf() {
    using namespace format_tag;

    const convolution_desc_t &cd = *desc();
    const memory_desc_wrapper src_mdw(diff_src_md());
    const memory_desc_wrapper weights_mdw(weights_md());
    const memory_desc_wrapper dst_mdw(diff_dst_md());
    const memory_desc_wrapper bias_mdw(weights_md(1));

    set_default_conf(conf, cd, *diff_src_md(), *weights_md(), *diff_dst_md(),
            *weights_md(1), *attr());

    conf.is_nhwc = is_nhwc(src_mdw, dst_mdw);

    // Grouped convolution is only supported when both channel counts are
    // multiples of the 32-wide channel block.
    if (conf.with_groups && conf.ngroups > 1
            && (conf.oc % 32 != 0 || conf.ic % 32 != 0))
        return status::unimplemented;

    // Kernel version only matters for the blocked (non-NHWC) path:
    // mb-blocked for batches divisible by 16, ow-blocked otherwise.
    if (!conf.is_nhwc) {
        if (conf.mb % 16 == 0) {
            conf.ver = ver_mb_block;
        } else {
            conf.ver = ver_ow_block;
        }
    }

    conf.oc_block = 32;
    conf.ic_block = 32;
    conf.iw_block = 1;

    conf.sub_group_size = 8;
    // Number of 32-channel input chunks across all groups.
    conf.nchunk = utils::div_up(conf.ic * conf.ngroups, conf.ic_block);
    int ic_group = nstl::min(conf.nchunk, 2);

    if (conf.ver == ver_ow_block || conf.is_nhwc) {
        conf.mb_block = 1;
        // Split the (at most 32) subgroups of a work-group between the
        // ic and iw dimensions.
        int max_ic = 4;
        ic_group
                = utils::max_div(utils::div_up(conf.ic, conf.ic_block), max_ic);
        int max_subgroups = 32;
        int max_iw_group = max_subgroups / ic_group;
        // Smaller iw blocks for small problems, bigger for large.
        conf.iw_block
                = (conf.mb * conf.ic * conf.ih * conf.iw < 49 * 1024) ? 4 : 8;
        int iw_nchunk = utils::div_up(conf.iw, conf.iw_block);
        int iw_group = utils::max_div(iw_nchunk, max_iw_group);

        //an upper bound on the number of elems per subgroup
        conf.dst_slm_size = (conf.oc_block / 4)
                * ((iw_group * conf.iw_block)
                        + (conf.kw - 1) * (1 + conf.dilate_w));
        conf.iw_tail = conf.iw % conf.iw_block;

        conf.lws_d[0] = 8 * ic_group;
        conf.lws_d[1] = iw_group;
        conf.lws_d[2] = 1;

        conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
        conf.gws_d[1] = conf.id * conf.ih * iw_nchunk;
        conf.gws_d[2] = utils::div_up(conf.mb, utils::div_up(conf.mb_block, 2));
    } else { //ver_mb_block
        conf.mb_block = 32;
        conf.lws_d[0] = 8 * ic_group;
        conf.lws_d[1] = 8;
        conf.lws_d[2] = 1;

        conf.gws_d[0] = utils::rnd_up(conf.nchunk * 8, conf.lws_d[0]);
        conf.gws_d[1]
                = conf.id * conf.ih * utils::rnd_up(conf.iw, conf.lws_d[1]);
        conf.gws_d[2] = 2 * utils::div_up(conf.mb, conf.mb_block);
    }
    conf.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    format_tag_t src_tag, dst_tag, wei_tag;

    // Memory tags matching the blocking decisions above.
    if (conf.is_nhwc) {
        src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
        dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
    } else {
        src_tag = (conf.ver == ver_ow_block)
                ? utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c)
                : utils::pick(
                        conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
        dst_tag = (conf.ver == ver_ow_block)
                ? utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c)
                : utils::pick(
                        conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
    }

    // Backward-data weights use the IO-swapped blocked layout.
    wei_tag = conf.with_groups ? utils::pick(conf.ndims - 3, gIOw4i8o8i4o,
                      gIOhw4i8o8i4o, gIOdhw4i8o8i4o)
                               : utils::pick(conf.ndims - 3, IOw4i8o8i4o,
                                       IOhw4i8o8i4o, IOdhw4i8o8i4o);

    conf.dst_data_type = dst_mdw.data_type();
    conf.src_data_type = src_mdw.data_type();

    // For "any" format kinds, impose the chosen tag; otherwise the
    // user-provided memory format must match it exactly.
    conf.src_tag = src_mdw.format_kind() == format_kind::any
            ? src_tag
            : src_mdw.matches_one_of_tag(src_tag);
    conf.wei_tag = weights_mdw.format_kind() == format_kind::any
            ? wei_tag
            : weights_mdw.matches_one_of_tag(wei_tag);
    conf.dst_tag = dst_mdw.format_kind() == format_kind::any
            ? dst_tag
            : dst_mdw.matches_one_of_tag(dst_tag);

    if (conf.src_tag != src_tag || conf.wei_tag != wei_tag
            || conf.dst_tag != dst_tag)
        return status::unimplemented;

    return status::success;
}
687
// Translates the configuration computed in init_conf() into OpenCL
// preprocessor macros for the backward-data kernel.
status_t xe_lp_x8s8x_convolution_bwd_data_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    // Problem geometry.
    kernel_ctx.define_int("G", conf.ngroups);
    kernel_ctx.define_int("MB", conf.mb);
    kernel_ctx.define_int("IC", conf.ic);
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("OC", conf.oc);
    kernel_ctx.define_int("OD", conf.od);
    kernel_ctx.define_int("OH", conf.oh);
    kernel_ctx.define_int("OW", conf.ow);
    kernel_ctx.define_int("KD", conf.kd);
    kernel_ctx.define_int("KH", conf.kh);
    kernel_ctx.define_int("KW", conf.kw);
    kernel_ctx.define_int("SD", conf.stride_d);
    kernel_ctx.define_int("SH", conf.stride_h);
    kernel_ctx.define_int("SW", conf.stride_w);
    kernel_ctx.define_int("PD", conf.f_pad);
    kernel_ctx.define_int("PH", conf.t_pad);
    kernel_ctx.define_int("PW", conf.l_pad);
    kernel_ctx.define_int("DD", conf.dilate_d);
    kernel_ctx.define_int("DH", conf.dilate_h);
    kernel_ctx.define_int("DW", conf.dilate_w);

    // Input-width iteration bounds and tail.
    kernel_ctx.define_int("IW_PADDED", utils::rnd_up(conf.iw, conf.lws_d[1]));
    kernel_ctx.define_int("IW_TAIL", conf.iw_tail);

    // Blocking.
    kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
    kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
    kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
    kernel_ctx.define_int("IW_BLOCK", conf.iw_block);

    // Work-group decomposition.
    kernel_ctx.define_int("MB_GROUP", 1);
    kernel_ctx.define_int("IC_GROUP", utils::div_up(conf.lws_d[0], 8));
    kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);

    kernel_ctx.define_int("IW_NCHUNK", utils::div_up(conf.iw, conf.iw_block));
    kernel_ctx.define_int("OC_NCHUNK", utils::div_up(conf.oc, conf.oc_block));
    kernel_ctx.define_int("IC_NCHUNK", utils::div_up(conf.ic, conf.ic_block));

    kernel_ctx.define_int("DST_SLM_SIZE", conf.dst_slm_size);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);

    kernel_ctx.define_int("WITH_BIAS", conf.with_bias);

    kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
    kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
    kernel_ctx.define_int("LWS_2", conf.lws_d[2]);

    kernel_ctx.define_int("IS_NHWC", conf.is_nhwc);

    // Data types: kernel default is the diff_src (dst of this kernel) type.
    kernel_ctx.set_data_type(conf.dst_data_type);
    def_data_type(kernel_ctx, conf.src_data_type, "SRC");
    def_data_type(kernel_ctx, conf.dst_data_type, "DST");
    kernel_ctx.add_option("-Dcl_intel_subgroups_char");

    return status::success;
}
747
748status_t xe_lp_x8s8x_convolution_bwd_data_t::execute_backward_data(
749 const exec_ctx_t &ctx) const {
750
751 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
752 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
753 auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
754 auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
755
756 const auto &conf = pd()->conf;
757
758 compute::kernel_arg_list_t arg_list;
759 arg_list.set(0, diff_src);
760 arg_list.set(1, weights);
761 arg_list.set(2, bias);
762 arg_list.set(3, diff_dst);
763
764 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
765 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
766
767 return status;
768}
769
770} // namespace ocl
771} // namespace gpu
772} // namespace impl
773} // namespace dnnl
774