1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/gen9_convolution.hpp"
18
19#include "common/c_types_map.hpp"
20#include "common/dnnl_traits.hpp"
21#include "common/math_utils.hpp"
22#include "common/reorder.hpp"
23#include "common/type_helpers.hpp"
24#include "gpu/ocl/ocl_memory_storage.hpp"
25#include "gpu/primitive_conf.hpp"
26
27using namespace dnnl::impl::memory_tracking::names;
28
29namespace dnnl {
30namespace impl {
31namespace gpu {
32namespace ocl {
33
34using namespace dnnl::impl::data_type;
35using namespace dnnl::impl::format_tag;
36
37static void fwd_compute_block_sizes(
38 conv_conf_t &conf, const convolution_pd_t *pd) {
39
40 int max_ow_block = (conf.src_data_type == data_type::f16 ? 20 : 16);
41 if (conf.ver == ver_16mb16c || conf.ver == ver_32mb16c) {
42 max_ow_block = 1;
43 } else if (conf.is_depthwise || conf.ver == ver_1stconv) {
44 max_ow_block = 8;
45 }
46 max_ow_block = nstl::min(conf.ow, max_ow_block);
47
48 if (conf.ver == ver_16mb16c) {
49 conf.mb_block
50 = (conf.src_data_type == data_type::f16 && !conf.is_depthwise)
51 ? (conf.mb % 32 == 0 ? 32 : 16)
52 : 16;
53 } else if (conf.ver == ver_32mb16c) {
54 conf.mb_block = 32;
55 } else {
56 conf.mb_block = 1;
57 }
58
59 conf.ow_block = utils::max_div(conf.ow, max_ow_block);
60
61 if (conf.ow_block < max_ow_block / 2) {
62 float min_tail_ratio = 1;
63 int best_ow_block = -1;
64 for (int ow_block = 8; ow_block <= max_ow_block; ow_block++) {
65 float tail_ratio
66 = (ow_block - (conf.ow % ow_block)) / (float)conf.ow;
67 if (tail_ratio <= min_tail_ratio) {
68 min_tail_ratio = tail_ratio;
69 best_ow_block = ow_block;
70 }
71 }
72 assert(best_ow_block > 0);
73 conf.ow_block = best_ow_block;
74 }
75
76 if (conf.is_depthwise) {
77 conf.oc_block = 16;
78 conf.ic_block = 16;
79 conf.omb = conf.mb_block;
80 return;
81 }
82
83 if (conf.ver == ver_1stconv && conf.mb_block == 1 && conf.oc % 32 == 0) {
84 conf.oc_block = 32;
85 } else {
86 conf.oc_block = 16;
87 }
88 conf.ic_block = nstl::min(conf.ic, 16);
89
90 conf.omb = (conf.mb_block == 1 && conf.mb % 16 == 0) ? 16 : conf.mb_block;
91 conf.ocb = utils::max_div(conf.oc / 16, 8) * 16;
92}
93
// Fills `conf` for the forward convolution: default fields via
// set_default_conf(), channel/group padding, kernel version (`conf.ver`),
// block sizes, global/local work sizes and the src/weights/dst format tags.
// Returns status::unimplemented whenever no supported combination matches or
// a user-specified layout differs from the one required by the chosen
// version; returns status::success otherwise.
status_t gen9_convolution_fwd_t::pd_t::init_conf(engine_t *engine) {

    const convolution_desc_t &cd = *desc();
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper weights_mdw(weights_md());
    const memory_desc_wrapper dst_mdw(dst_md());
    // NOTE(review): bias_mdw is not referenced anywhere below.
    const memory_desc_wrapper bias_mdw(weights_md(1));

    set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(),
            *weights_md(1), *attr());

    const bool int8_dst = conf.dst_data_type == data_type::s8;
    // Channels-last detection is done separately for src and dst.
    const bool is_src_nhwc
            = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
    const bool is_dst_nhwc
            = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
    const bool is_nhwc = is_src_nhwc || is_dst_nhwc;

    // "First convolution" here means exactly 3 unpadded input channels.
    const bool is_1stconv = conf.ic_without_padding == 3;
    // Depthwise: grouped, with one input and one output channel per group.
    const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1)
            && (conf.oc_without_padding == 1);

    // For the first convolution only the dst layout decides the NHWC path.
    conf.is_nhwc = is_1stconv ? is_dst_nhwc : is_nhwc;
    conf.is_depthwise = is_depthwise;

    // Output-channel rounding granularity: 32 for s8 dst (except first
    // convolution), 16 otherwise.
    const int out_block = int8_dst && !is_1stconv ? 32 : 16;
    if (is_1stconv || (conf.with_groups && conf.ngroups > 1)) {
        // ic is left unpadded; oc is padded only for the first convolution.
        conf.ic = conf.ic_without_padding;
        conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, out_block)
                             : conf.oc_without_padding;
    } else {
        conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
        conf.oc = utils::rnd_up(conf.oc_without_padding, out_block);
    }

    // Remember the original group count before depthwise padding.
    conf.ngroups_without_padding = conf.ngroups;
    if (is_depthwise)
        conf.ngroups = utils::rnd_up(conf.ngroups, int8_dst ? 32 : 16);

    // These predicates are evaluated on the (possibly padded) values above.
    const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0);
    const bool is_16oc = conf.oc % out_block == 0;
    const bool is_16ic = conf.ic % 16 == 0;

    // Neutral defaults; overwritten below once the version is known.
    conf.mb_block = 1;
    conf.oc_block = 1;
    conf.ic_block = 1;
    conf.od_block = 1;
    conf.oh_block = 1;
    conf.ow_block = 1;
    conf.omb = 1;
    conf.ocb = 1;
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    const bool is_xe_hp_plus
            = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg();
    const bool has_non_uniform_wg
            = compute_engine->mayiuse_non_uniform_work_groups();

    // Kernel version selection.
    if (conf.is_nhwc) {
        // The NHWC path accepts only f32/f16 src and rejects s8 dst.
        if (!utils::one_of(src_mdw.data_type(), f32, f16))
            return status::unimplemented;
        // Depthwise NHWC requires the unpadded group count to be a multiple
        // of 16.
        if (conf.is_depthwise && conf.ngroups_without_padding % 16)
            return status::unimplemented;
        // TODO: Add group convolution support in NHWC kernel.
        if (!conf.is_depthwise && conf.ngroups > 1 && !(is_16oc && is_16ic)) {
            return status::unimplemented;
        }
        if (int8_dst) { return status::unimplemented; }
        conf.ver = ver_nhwc;
    } else if (is_1stconv) {
        if (!is_16oc) return status::unimplemented;
        conf.ver = ver_1stconv;
    } else if ((is_16oc && is_16ic) || is_dw_16g) {
        // ver_32mb16c is chosen only for depthwise bf16/f16 with mb % 32 == 0
        // on engines reporting is_xe_hp() or is_xe_hpg().
        if (conf.mb % 32 == 0 && conf.is_depthwise
                && utils::one_of(src_mdw.data_type(), bf16, f16)
                && is_xe_hp_plus) {
            conf.ver = ver_32mb16c;
        } else {
            conf.ver = (conf.mb % out_block == 0) ? ver_16mb16c : ver_8ow16c;
        }
    } else {
        return status::unimplemented;
    }

    const bool is_fp16 = src_mdw.data_type() == data_type::f16;

    // Block sizes and work sizes per version.
    switch (conf.ver) {
        case ver_nhwc: {
            conf.mb_block = 1;
            conf.oc_block = 16;
            conf.ic_block = is_1stconv ? 1 : 16;

            // Cap ow_block at 8 for kw > 1 or when both oc and ic are <= 64.
            int max_ow_block = (conf.kw > 1) ? 8 : 16;
            if (conf.oc <= 64 && conf.ic <= 64) max_ow_block = 8;

            conf.ow_block = utils::max_div(conf.ow, max_ow_block);

            // When the exact divisor is <= 8, take the candidate in
            // [8, max_ow_block) with the largest remainder ow % j.
            if (conf.ow_block <= 8) {
                int max_tail = 0;
                for (int j = 8; j < max_ow_block; j++) {
                    if (conf.ow % j > max_tail) {
                        max_tail = conf.ow % j;
                        conf.ow_block = j;
                    }
                }
            }
            // ow_block is never smaller than 8 on this path.
            if (conf.ow_block <= 8) conf.ow_block = 8;
            if (conf.ow <= 8 || conf.oc <= 32) conf.ow_block = 8;

            conf.oh_block = 1;
            conf.sub_group_size = 16;
            conf.lws_d[0] = 16;
            conf.lws_d[1] = 1;
            conf.lws_d[2] = 1;

            int max_oc_block = 8;
            if (conf.is_depthwise) {
                conf.ocb = conf.ngroups;
            } else {
                // ocb = oc_block times the largest divisor (<= 8) of the
                // number of oc blocks.
                conf.ocb = conf.oc_block
                        * utils::max_div(utils::div_up(conf.oc, conf.oc_block),
                                max_oc_block);
            }

            // gws[0]: channels handled per ocb chunk; gws[1]: spatial blocks;
            // gws[2]: minibatch (times ocb chunks and groups when not
            // depthwise).
            conf.gws_d[0] = conf.ocb;
            conf.gws_d[1] = utils::div_up(conf.oh, conf.oh_block)
                    * utils::div_up(conf.ow, conf.ow_block) * conf.od;
            if (conf.is_depthwise) {
                conf.gws_d[2] = conf.mb;
            } else {
                conf.gws_d[2] = conf.mb * utils::div_up(conf.oc, conf.ocb)
                        * conf.ngroups;
            }
        } break;
        case ver_1stconv:
        case ver_8ow16c:
        case ver_16mb16c:
        case ver_32mb16c: {
            // Blocked layouts share the block-size heuristic.
            fwd_compute_block_sizes(conf, this);
            conf.sub_group_size = 16;
            conf.gws_d[0] = conf.ngroups * conf.ocb / (conf.oc_block / 16);
            conf.gws_d[1]
                    = (conf.od * conf.oh * utils::div_up(conf.ow, conf.ow_block)
                            * (conf.omb / conf.mb_block));
            conf.gws_d[2] = (conf.oc / conf.ocb) * (conf.mb / conf.omb);
            conf.lws_d[0] = is_xe_hp_plus ? 32 : 16;
            conf.lws_d[1] = 1;
            conf.lws_d[2] = 1;
            break;
        }
        default: return status::unimplemented;
    }

    maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf);

    // Layouts required by the chosen version; every reachable case assigns
    // all three tags. utils::pick() is indexed by ndims - 3.
    format_tag_t src_tag, dst_tag, wei_tag;

    switch (conf.ver) {
        case ver_nhwc:
            src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
            dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
            if (is_1stconv) {
                wei_tag = conf.with_groups ? utils::pick(
                                  conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
                                           : utils::pick(conf.ndims - 3, Owi16o,
                                                   Ohwi16o, Odhwi16o);
            } else if (conf.is_depthwise) {
                wei_tag = utils::pick(
                        conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g);
            } else {
                wei_tag = conf.with_groups
                        ? utils::pick(conf.ndims - 3, gOIw16i16o, gOIhw16i16o,
                                gOIdhw16i16o)
                        : utils::pick(conf.ndims - 3, OIw16i16o, OIhw16i16o,
                                OIdhw16i16o);
            }
            break;
        case ver_1stconv:
            // src keeps whichever plain layout (channels-last or
            // channels-first) was detected.
            if (is_src_nhwc)
                src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
            else
                src_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw);

            // dst is minibatch-blocked when mb allows it (32 for f16 on
            // xe_hp/xe_hpg, 16 otherwise), else channel-blocked only.
            if (is_xe_hp_plus && is_fp16) {
                dst_tag = conf.mb % 32 == 0 ? utils::pick(conf.ndims - 3,
                                  NCw32n16c, NChw32n16c, NCdhw32n16c)
                                            : utils::pick(conf.ndims - 3,
                                                    nCw16c, nChw16c, nCdhw16c);
            } else {
                dst_tag = conf.mb % 16 == 0 ? utils::pick(conf.ndims - 3,
                                  NCw16n16c, NChw16n16c, NCdhw16n16c)
                                            : utils::pick(conf.ndims - 3,
                                                    nCw16c, nChw16c, nCdhw16c);
            }
            wei_tag = conf.with_groups
                    ? utils::pick(conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
                    : utils::pick(conf.ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
            break;
        case ver_16mb16c:
            src_tag = utils::pick(
                    conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
            dst_tag = utils::pick(
                    conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
            wei_tag = conf.is_depthwise
                    ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
                    : (conf.with_groups ? utils::pick(conf.ndims - 3,
                               gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
                                        : utils::pick(conf.ndims - 3, IOw16i16o,
                                                IOhw16i16o, IOdhw16i16o));
            break;
        case ver_32mb16c:
            src_tag = utils::pick(
                    conf.ndims - 3, NCw32n16c, NChw32n16c, NCdhw32n16c);
            dst_tag = utils::pick(
                    conf.ndims - 3, NCw32n16c, NChw32n16c, NCdhw32n16c);
            wei_tag = conf.is_depthwise
                    ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
                    : (conf.with_groups ? utils::pick(conf.ndims - 3,
                               gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
                                        : utils::pick(conf.ndims - 3, IOw16i16o,
                                                IOhw16i16o, IOdhw16i16o));
            break;
        case ver_8ow16c:
            src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
            dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
            wei_tag = conf.is_depthwise
                    ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
                    : (conf.with_groups ? utils::pick(conf.ndims - 3,
                               gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
                                        : utils::pick(conf.ndims - 3, OIw16i16o,
                                                OIhw16i16o, OIdhw16i16o));
            break;
        default: return status::unimplemented;
    }
    // An s8 destination overrides the dst tag selected above.
    if (int8_dst) {
        if (is_1stconv && conf.ic_without_padding < 4) {
            dst_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw);
        } else if (conf.ver == ver_16mb16c || conf.ver == ver_32mb16c) {
            dst_tag = utils::pick(
                    conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
        } else {
            dst_tag = utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c);
        }
    }

    // For each tensor: format_kind::any adopts the required tag; otherwise
    // the existing descriptor must already match it.
    if (src_mdw.format_kind() == format_kind::any) {
        conf.src_tag = src_tag;
    } else {
        conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
    }
    if (conf.src_tag != src_tag) return status::unimplemented;

    if (weights_mdw.format_kind() == format_kind::any) {
        conf.wei_tag = wei_tag;
    } else {
        conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
    }
    if (conf.wei_tag != wei_tag) return status::unimplemented;

    if (dst_mdw.format_kind() == format_kind::any) {
        conf.dst_tag = dst_tag;
    } else {
        conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
    }
    if (conf.dst_tag != dst_tag) return status::unimplemented;

    // Record whether the final src layout is plain channels-first or
    // channels-last.
    conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw);
    conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc);

    return status::success;
}
364
365status_t gen9_convolution_fwd_t::pd_t::init_kernel_ctx(
366 compute::kernel_ctx_t &kernel_ctx) const {
367 kernel_ctx.define_int("IS_DW", conf.is_depthwise);
368 kernel_ctx.define_int("G", conf.ngroups);
369 kernel_ctx.define_int("MB", conf.mb);
370 kernel_ctx.define_int("IC", conf.ic);
371 kernel_ctx.define_int("ID", conf.id);
372 kernel_ctx.define_int("IH", conf.ih);
373 kernel_ctx.define_int("IW", conf.iw);
374 kernel_ctx.define_int("OC", conf.oc);
375 kernel_ctx.define_int("OD", conf.od);
376 kernel_ctx.define_int("OH", conf.oh);
377 kernel_ctx.define_int("OW", conf.ow);
378 kernel_ctx.define_int("KD", conf.kd);
379 kernel_ctx.define_int("KH", conf.kh);
380 kernel_ctx.define_int("KW", conf.kw);
381 kernel_ctx.define_int("SD", conf.stride_d);
382 kernel_ctx.define_int("SH", conf.stride_h);
383 kernel_ctx.define_int("SW", conf.stride_w);
384 kernel_ctx.define_int("PD", conf.f_pad);
385 kernel_ctx.define_int("PH", conf.t_pad);
386 kernel_ctx.define_int("PW", conf.l_pad);
387 kernel_ctx.define_int("PD_R", conf.back_pad);
388 kernel_ctx.define_int("PH_R", conf.b_pad);
389 kernel_ctx.define_int("PW_R", conf.r_pad);
390 kernel_ctx.define_int("DD", conf.dilate_d);
391 kernel_ctx.define_int("DH", conf.dilate_h);
392 kernel_ctx.define_int("DW", conf.dilate_w);
393 kernel_ctx.define_int("OW_PADDED", utils::rnd_up(conf.ow, 4));
394 kernel_ctx.define_int("OC_PADDED", conf.oc);
395 kernel_ctx.define_int("OMB", conf.omb);
396 kernel_ctx.define_int("OCB", conf.ocb);
397 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
398 kernel_ctx.define_int("OH_BLOCK", conf.oh_block);
399 kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
400 kernel_ctx.define_int("OW_LAST", utils::rnd_dn(conf.ow, conf.ow_block));
401 kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));
402 kernel_ctx.define_int("OHB", utils::div_up(conf.oh, conf.oh_block));
403 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
404 kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
405 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
406 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
407 kernel_ctx.define_int("G_WO_PADDING", conf.ngroups_without_padding);
408 kernel_ctx.define_int("IC_WO_PADDING", conf.ic_without_padding);
409 kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
410 kernel_ctx.define_int("OC_GROUP", conf.lws_d[0] / 8);
411 kernel_ctx.define_int("MB_GROUP", 1);
412 kernel_ctx.define_int("SP_GROUP", conf.lws_d[1]);
413 if (conf.kw == 1)
414 kernel_ctx.define_int("SRC_SP_GROUP", conf.lws_d[1] + conf.kw - 1);
415 else
416 kernel_ctx.define_int(
417 "SRC_SP_GROUP", conf.stride_w * (conf.lws_d[1] - 1) + conf.kw);
418
419 kernel_ctx.set_data_type(conf.src_data_type);
420 def_data_type(kernel_ctx, conf.dst_data_type, "DST");
421
422 kernel_ctx.define_int("VER_1STCONV", conf.ver == ver_1stconv);
423 kernel_ctx.define_int("VER_8OW16C", conf.ver == ver_8ow16c);
424 kernel_ctx.define_int("VER_16MB16C", conf.ver == ver_16mb16c);
425 kernel_ctx.define_int("VER_32MB16C", conf.ver == ver_32mb16c);
426
427 kernel_ctx.define_int("SRC_NCHW", conf.is_src_nchw);
428 kernel_ctx.define_int("SRC_NHWC", conf.is_src_nhwc);
429 kernel_ctx.define_int("SRC_16N16C",
430 utils::one_of(conf.src_tag, NCw16n16c, NChw16n16c, NCdhw16n16c));
431 kernel_ctx.define_int(
432 "SRC_W16C", utils::one_of(conf.src_tag, nCw16c, nChw16c, nCdhw16c));
433
434 kernel_ctx.define_int("WEI_I16O",
435 utils::one_of(conf.wei_tag, gOwi16o, gOhwi16o, gOdhwi16o, Owi16o,
436 Ohwi16o, Odhwi16o));
437 kernel_ctx.define_int("WEI_16I16O",
438 utils::one_of(conf.wei_tag, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o,
439 OIw16i16o, OIhw16i16o, OIdhw16i16o));
440 kernel_ctx.define_int("WEI_16I16O_FLIPPED",
441 utils::one_of(conf.wei_tag, gIOw16i16o, gIOhw16i16o, gIOdhw16i16o,
442 IOw16i16o, IOhw16i16o, IOdhw16i16o));
443
444 kernel_ctx.define_int(
445 "DST_W16C", utils::one_of(conf.dst_tag, nCw16c, nChw16c, nCdhw16c));
446 kernel_ctx.define_int("DST_16N16C",
447 utils::one_of(conf.dst_tag, NCw16n16c, NChw16n16c, NCdhw16n16c));
448 kernel_ctx.define_int("DST_32N16C",
449 utils::one_of(conf.dst_tag, NCw32n16c, NChw32n16c, NCdhw32n16c));
450 kernel_ctx.define_int("DST_32N32C",
451 utils::one_of(conf.dst_tag, NCw32n32c, NChw32n32c, NCdhw32n32c));
452 kernel_ctx.define_int(
453 "DST_W32C", utils::one_of(conf.dst_tag, nCw32c, nChw32c, nCdhw32c));
454 kernel_ctx.define_int(
455 "DST_NCHW", utils::one_of(conf.dst_tag, ncw, nchw, ncdhw));
456
457 kernel_ctx.define_int("GWS_0", conf.gws_d[0]);
458 kernel_ctx.define_int("GWS_1", conf.gws_d[1]);
459 kernel_ctx.define_int("GWS_2", conf.gws_d[2]);
460
461 kernel_ctx.define_int("GWS_ORIG_0", conf.gws_orig_d[0]);
462 kernel_ctx.define_int("GWS_ORIG_1", conf.gws_orig_d[1]);
463 kernel_ctx.define_int("GWS_ORIG_2", conf.gws_orig_d[2]);
464
465 kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
466 kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
467 kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
468
469 dnnl_dims_t dst_dims;
470 dst_dims[0] = conf.mb;
471 dst_dims[1] = conf.ngroups_without_padding * conf.oc_without_padding;
472 dst_dims[2] = conf.ndims > 4 ? conf.od : conf.oh;
473 dst_dims[3] = conf.ndims > 4 ? conf.oh : conf.ow;
474 dst_dims[4] = conf.ow;
475 kernel_ctx.add_option("-cl-std=CL2.0");
476 def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_, &dst_dims);
477
478 kernel_ctx.print_options();
479 return status::success;
480}
481
// Fills `conf` for the backward-by-data convolution: default fields via
// set_default_conf(), channel/group padding, kernel version (`conf.ver`),
// block sizes, global/local work sizes and the diff_src/weights/diff_dst
// format tags. Returns status::unimplemented when no supported combination
// matches or a user-specified layout differs from the required one; returns
// status::success otherwise.
status_t gen9_convolution_bwd_data_t::pd_t::init_conf(engine_t *engine) {
    using namespace dnnl::impl::format_tag;
    using namespace data_type;

    const convolution_desc_t &cd = *desc();
    // Here "src" wraps diff_src and "dst" wraps diff_dst.
    const memory_desc_wrapper src_mdw(diff_src_md());
    const memory_desc_wrapper weights_mdw(weights_md());
    const memory_desc_wrapper dst_mdw(diff_dst_md());
    // NOTE(review): bias_mdw is not referenced anywhere below.
    const memory_desc_wrapper bias_mdw(weights_md(1));

    set_default_conf(conf, cd, *diff_src_md(), *weights_md(), *diff_dst_md(),
            *weights_md(1), *attr());
    // Channels-last if either diff_src or diff_dst matches an n*wc tag.
    const bool is_nhwc
            = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef
            || dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc)
                    != format_tag::undef;
    // "First convolution" here means exactly 3 unpadded input channels.
    const bool is_1stconv = conf.ic_without_padding == 3;
    // Depthwise: grouped, with one input and one output channel per group.
    const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1)
            && (conf.oc_without_padding == 1);
    conf.is_nhwc = is_nhwc;
    conf.is_depthwise = is_depthwise;

    // The NHWC path handles neither depthwise nor first convolution.
    if (is_nhwc && (is_depthwise || is_1stconv)) return status::unimplemented;

    if (is_1stconv || (conf.with_groups && conf.ngroups > 1) || is_depthwise) {
        // ic is left unpadded; oc is padded only for the first convolution.
        conf.ic = conf.ic_without_padding;
        conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, 16)
                             : conf.oc_without_padding;
    } else {
        conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
        conf.oc = utils::rnd_up(conf.oc_without_padding, 16);
    }
    // Remember the original group count before depthwise padding.
    conf.ngroups_without_padding = conf.ngroups;
    if (is_depthwise) conf.ngroups = utils::rnd_up(conf.ngroups, 16);
    const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0);

    const bool is_16ic = conf.ic % 16 == 0;
    const bool is_16oc = conf.oc % 16 == 0;
    // 16-minibatch unrolling needs mb != 1, mb % 16 == 0, a non-NHWC and
    // non-first convolution, and 16-aligned channels (or depthwise groups).
    const bool use_16mb_unroll = !is_nhwc
            && !(conf.mb == 1 || conf.mb % 16 != 0) && !is_1stconv
            && ((is_16ic && is_16oc) || is_dw_16g);
    // Neutral defaults; overwritten below once the version is known.
    conf.mb_block = 1;
    conf.oc_block = 1;
    conf.ic_block = 1;
    conf.od_block = 1;
    conf.oh_block = 1;
    conf.ow_block = 1;
    conf.icb = 1;
    // Kernel version selection.
    if (is_nhwc)
        conf.ver = ver_nhwc;
    else if (use_16mb_unroll)
        conf.ver = ver_16mb16c;
    else if (conf.mb % 16 != 0 && ((is_16oc && is_16ic) || is_dw_16g))
        conf.ver = ver_8ow16c;
    else
        return status::unimplemented;

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    //TODO: Fix Gtests and reenable
    const bool is_xe_hp_plus
            = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg();
    const bool has_non_uniform_wg
            = compute_engine->mayiuse_non_uniform_work_groups();

    // `status` is set to unimplemented by the default cases of the two
    // switches below and checked after the second one.
    status_t status = status::success;
    switch (conf.ver) {
        case ver_16mb16c:
            conf.mb_block = 16;
            conf.oc_block = 16;
            conf.ic_block = 16;
            conf.od_block = 1;
            conf.ih_block = 1;
            conf.iw_block = 1;
            conf.sub_group_size = 16;
            if (conf.is_depthwise) {
                // Depthwise: spatial points in dim 0, channels in dim 1.
                conf.icb = conf.ngroups;
                conf.lws_d[0] = 1;
                conf.lws_d[1] = is_xe_hp_plus ? 32 : 16;
                conf.lws_d[2] = 1;
                conf.gws_d[0] = conf.ih * conf.iw * conf.id;
                conf.gws_d[1] = conf.ic * conf.ngroups;
                conf.gws_d[2] = conf.mb / 16;
            } else {
                // Largest of 64/32/16 that divides ic (16 if none larger
                // does).
                conf.icb = 64;
                while (conf.icb > 16) {
                    if (conf.ic % conf.icb == 0) break;
                    conf.icb /= 2;
                }
                conf.lws_d[0] = is_xe_hp_plus ? 32 : 16;
                conf.lws_d[1] = 1;
                conf.lws_d[2] = 1;
                conf.gws_d[0] = conf.icb;
                conf.gws_d[1] = conf.ih * conf.iw * conf.id;
                conf.gws_d[2]
                        = conf.mb / 16 * (conf.ic / conf.icb) * conf.ngroups;
            }
            break;
        case ver_8ow16c:
        case ver_nhwc: {
            conf.mb_block = 1;
            conf.oc_block = 16;
            conf.ic_block = 16;
            conf.od_block = 1;
            conf.ih_block = 1;
            // iw_block is the largest divisor of iw up to the cap, but never
            // below 8; the cap drops to 8 for NHWC with kw > 1.
            int max_iw_block = 16;
            if (conf.ver == ver_nhwc) { max_iw_block = (conf.kw > 1) ? 8 : 16; }
            conf.iw_block = nstl::max(8, utils::max_div(conf.iw, max_iw_block));
            conf.sub_group_size = 16;
            if (conf.is_depthwise) {
                conf.icb = conf.ngroups;
                conf.lws_d[0] = 1;
                conf.lws_d[1] = conf.ic_block;
                conf.lws_d[2] = 1;
                conf.gws_d[0] = conf.ih * utils::div_up(conf.iw, conf.iw_block)
                        * conf.id;
                conf.gws_d[1] = conf.ic * conf.ngroups;
                conf.gws_d[2] = conf.mb;
            } else {
                // Same halving search as above, applied to ic rounded up to
                // ic_block.
                conf.icb = 64;
                while (conf.icb > conf.ic_block) {
                    if (utils::rnd_up(conf.ic, conf.ic_block) % conf.icb == 0)
                        break;
                    conf.icb /= 2;
                }
                conf.lws_d[0] = conf.ic_block;
                conf.lws_d[1] = 1;
                conf.lws_d[2] = 1;
                conf.gws_d[0] = conf.icb;
                conf.gws_d[1] = conf.ih * utils::div_up(conf.iw, conf.iw_block)
                        * conf.id;
                conf.gws_d[2] = conf.mb
                        * (utils::rnd_up(conf.ic, conf.ic_block) / conf.icb)
                        * conf.ngroups;
            }
            break;
        }
        default: status = status::unimplemented;
    }

    maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf);

    // Layouts required by the chosen version. utils::pick() is indexed by
    // ndims - 3.
    format_tag_t src_tag, dst_tag, wei_tag;

    switch (conf.ver) {
        case ver_nhwc:
            src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
            dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
            wei_tag = conf.with_groups ? utils::pick(conf.ndims - 3, gOIw16o16i,
                              gOIhw16o16i, gOIdhw16o16i)
                                       : utils::pick(conf.ndims - 3, OIw16o16i,
                                               OIhw16o16i, OIdhw16o16i);
            break;
        case ver_16mb16c:
            src_tag = utils::pick(
                    conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
            dst_tag = utils::pick(
                    conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
            wei_tag = conf.is_depthwise
                    ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
                    : (conf.with_groups ? utils::pick(conf.ndims - 3,
                               gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
                                        : utils::pick(conf.ndims - 3, OIw16o16i,
                                                OIhw16o16i, OIdhw16o16i));
            break;
        case ver_8ow16c:
            src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
            dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
            wei_tag = conf.is_depthwise
                    ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
                    : (conf.with_groups ? utils::pick(conf.ndims - 3,
                               gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
                                        : utils::pick(conf.ndims - 3, OIw16o16i,
                                                OIhw16o16i, OIdhw16o16i));
            break;
        default: status = status::unimplemented;
    }
    // Bail out before the tags are read if either switch hit its default.
    if (status != status::success) return status;

    // For each tensor: format_kind::any adopts the required tag; otherwise
    // the existing descriptor must already match it.
    if (src_mdw.format_kind() == format_kind::any) {
        conf.src_tag = src_tag;
    } else {
        conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
    }
    if (conf.src_tag != src_tag) return status::unimplemented;

    if (weights_mdw.format_kind() == format_kind::any) {
        conf.wei_tag = wei_tag;
    } else {
        conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
    }
    if (conf.wei_tag != wei_tag) return status::unimplemented;

    if (dst_mdw.format_kind() == format_kind::any) {
        conf.dst_tag = dst_tag;
    } else {
        conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
    }
    if (conf.dst_tag != dst_tag) return status::unimplemented;

    // Record whether the final diff_src layout is plain channels-first or
    // channels-last.
    conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw);
    conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc);

    return status::success;
}
686
687status_t gen9_convolution_bwd_data_t::pd_t::init_kernel_ctx(
688 compute::kernel_ctx_t &kernel_ctx) const {
689 kernel_ctx.define_int("IS_DW", conf.is_depthwise);
690 kernel_ctx.define_int("BWD_DATA", 1);
691 kernel_ctx.define_int("G", conf.ngroups);
692 kernel_ctx.define_int("MB", conf.mb);
693 kernel_ctx.define_int("IC", conf.ic);
694 kernel_ctx.define_int("ICB", conf.icb);
695 kernel_ctx.define_int("ID", conf.id);
696 kernel_ctx.define_int("IH", conf.ih);
697 kernel_ctx.define_int("IW", conf.iw);
698 kernel_ctx.define_int("OC", conf.oc);
699 kernel_ctx.define_int("OD", conf.od);
700 kernel_ctx.define_int("OH", conf.oh);
701 kernel_ctx.define_int("OW", conf.ow);
702 kernel_ctx.define_int("KD", conf.kd);
703 kernel_ctx.define_int("KH", conf.kh);
704 kernel_ctx.define_int("KW", conf.kw);
705 kernel_ctx.define_int("SD", conf.stride_d);
706 kernel_ctx.define_int("SH", conf.stride_h);
707 kernel_ctx.define_int("SW", conf.stride_w);
708 kernel_ctx.define_int("PD", conf.f_pad);
709 kernel_ctx.define_int("PH", conf.t_pad);
710 kernel_ctx.define_int("PW", conf.l_pad);
711 kernel_ctx.define_int("PD_R", conf.back_pad);
712 kernel_ctx.define_int("PH_R", conf.b_pad);
713 kernel_ctx.define_int("PW_R", conf.r_pad);
714 kernel_ctx.define_int("DD", conf.dilate_d);
715 kernel_ctx.define_int("DH", conf.dilate_h);
716 kernel_ctx.define_int("DW", conf.dilate_w);
717 kernel_ctx.define_int("OC_PADDED", utils::rnd_up(conf.oc, conf.oc_block));
718 kernel_ctx.define_int("IC_PADDED", utils::rnd_up(conf.ic, conf.ic_block));
719 kernel_ctx.define_int("G_WO_PADDING", conf.ngroups_without_padding);
720 kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
721 kernel_ctx.define_int("IC_WO_PADDING", conf.ic_without_padding);
722 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
723 kernel_ctx.define_int("IH_BLOCK", conf.ih_block);
724 kernel_ctx.define_int("IW_BLOCK", conf.iw_block);
725 kernel_ctx.define_int("IWB", utils::div_up(conf.iw, conf.iw_block));
726 kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
727 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
728 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
729 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
730
731 kernel_ctx.define_int("GWS_0", conf.gws_d[0]);
732 kernel_ctx.define_int("GWS_1", conf.gws_d[1]);
733 kernel_ctx.define_int("GWS_2", conf.gws_d[2]);
734
735 kernel_ctx.define_int("GWS_ORIG_0", conf.gws_orig_d[0]);
736 kernel_ctx.define_int("GWS_ORIG_1", conf.gws_orig_d[1]);
737 kernel_ctx.define_int("GWS_ORIG_2", conf.gws_orig_d[2]);
738
739 kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
740 kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
741 kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
742
743 kernel_ctx.set_data_type(conf.src_data_type);
744
745 switch (conf.ver) {
746 case ver_16mb16c: kernel_ctx.define_int("VER_16MB16C", 1); break;
747 case ver_8ow16c: kernel_ctx.define_int("VER_8OW16C", 1); break;
748 default: break;
749 }
750
751 kernel_ctx.add_option("-cl-std=CL2.0");
752
753 return status::success;
754}
755
756status_t gen9_convolution_bwd_data_t::execute_backward_data(
757 const exec_ctx_t &ctx) const {
758
759 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
760 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
761 auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
762 auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
763
764 const auto &conf = pd()->conf;
765
766 compute::kernel_arg_list_t arg_list;
767 arg_list.set(0, diff_src);
768 arg_list.set(1, weights);
769 arg_list.set(2, diff_dst);
770 arg_list.set(3, bias);
771
772 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
773
774 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
775
776 return status;
777}
778
779status_t gen9_convolution_bwd_weights_t::pd_t::init_conf(engine_t *engine) {
780 using namespace dnnl::impl::format_tag;
781 using namespace data_type;
782
783 const convolution_desc_t &cd = *desc();
784 const memory_desc_wrapper src_mdw(src_md());
785 const memory_desc_wrapper weights_mdw(diff_weights_md());
786 const memory_desc_wrapper dst_mdw(diff_dst_md());
787 const memory_desc_wrapper bias_mdw(diff_weights_md(1));
788
789 set_default_conf(conf, cd, *src_md(), *diff_weights_md(), *diff_dst_md(),
790 *diff_weights_md(1), *attr());
791
792 const bool is_nhwc
793 = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef
794 || dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc)
795 != format_tag::undef;
796
797 const bool is_1stconv = conf.ic_without_padding == 3;
798 const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1)
799 && (conf.oc_without_padding == 1);
800
801 conf.is_nhwc = is_nhwc;
802 conf.is_depthwise = is_depthwise;
803
804 if (is_1stconv || (conf.with_groups && conf.ngroups > 1) || is_depthwise
805 || is_nhwc) {
806 conf.ic = conf.ic_without_padding;
807 conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, 16)
808 : conf.oc_without_padding;
809 } else {
810 conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
811 conf.oc = utils::rnd_up(conf.oc_without_padding, 16);
812 }
813
814 conf.ngroups_without_padding = conf.ngroups;
815 if (is_depthwise && !is_nhwc)
816 conf.ngroups = utils::rnd_up(conf.ngroups, 16);
817 const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0);
818
819 const bool is_16ic = conf.ic % 16 == 0;
820 const bool is_16oc = conf.oc % 16 == 0;
821 const bool use_16mb_unroll = !is_nhwc
822 && !(conf.mb == 1 || conf.mb % 16 != 0) && !is_1stconv
823 && ((is_16ic && is_16oc) || is_dw_16g);
824
825 conf.mb_block = 1;
826 conf.oc_block = 1;
827 conf.ic_block = 1;
828 conf.od_block = 1;
829 conf.oh_block = 1;
830 conf.ow_block = 1;
831 conf.osp_chunk = 1;
832 conf.mb_chunk = 1;
833 if (is_nhwc)
834 conf.ver = ver_nhwc;
835 else if (use_16mb_unroll)
836 conf.ver = ver_16mb16c;
837 else if (conf.mb % 16 != 0 && ((is_16oc && is_16ic) || is_dw_16g))
838 conf.ver = ver_8ow16c;
839 else if (is_1stconv && is_16oc)
840 conf.ver = ver_1stconv;
841 else
842 return status::unimplemented;
843
844 switch (conf.ver) {
845 case ver_1stconv:
846 case ver_8ow16c:
847 case ver_nhwc:
848 conf.mb_block = 1;
849 conf.oc_block = 16;
850 conf.ic_block = is_1stconv ? 1 : 16;
851 conf.ow_block = 8;
852 break;
853 case ver_16mb16c:
854 conf.mb_block = 16;
855 conf.oc_block = 16;
856 conf.ic_block = 16;
857 conf.ow_block = 1;
858 break;
859 }
860
861 bwd_w_compute_block_sizes(conf, engine);
862
863 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
864
865 //TODO: Fix Gtests and reenable
866 const bool is_xe_hp_plus
867 = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg();
868 const bool has_non_uniform_wg
869 = compute_engine->mayiuse_non_uniform_work_groups();
870
871 conf.sub_group_size = 16;
872 conf.lws_d[0] = is_xe_hp_plus ? 32 : 16;
873 conf.lws_d[1] = 1;
874 conf.lws_d[2] = 1;
875
876 if (conf.is_depthwise) {
877 conf.gws_d[0] = utils::rnd_up(conf.ngroups, 16);
878 } else {
879 conf.gws_d[0] = is_1stconv ? conf.ocb * conf.ngroups
880 : conf.ocb * (conf.icb / 16) * conf.ngroups;
881 }
882 conf.gws_d[1] = is_1stconv && !is_nhwc
883 ? utils::div_up(conf.kh * conf.kw * conf.kd * conf.ic, 16)
884 : conf.kh * conf.kw * conf.kd;
885 conf.gws_d[2] = conf.nchunk * utils::div_up(conf.ic, conf.icb)
886 * utils::div_up(conf.oc, conf.ocb);
887
888 maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf);
889
890 format_tag_t src_tag, dst_tag, wei_tag;
891
892 switch (conf.ver) {
893 case ver_nhwc:
894 src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
895 dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
896 if (is_1stconv) {
897 wei_tag = conf.with_groups ? utils::pick(
898 conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
899 : utils::pick(conf.ndims - 3, Owi16o,
900 Ohwi16o, Odhwi16o);
901 } else if (conf.is_depthwise) {
902 wei_tag = utils::pick(
903 conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g);
904 } else {
905 wei_tag = conf.with_groups
906 ? utils::pick(conf.ndims - 3, gIOw16i16o, gIOhw16i16o,
907 gIOdhw16i16o)
908 : utils::pick(conf.ndims - 3, IOw16i16o, IOhw16i16o,
909 IOdhw16i16o);
910 }
911 break;
912 case ver_1stconv:
913 assert(!conf.is_depthwise);
914 src_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw);
915 dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
916 wei_tag = conf.with_groups
917 ? utils::pick(conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
918 : utils::pick(conf.ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
919 break;
920 case ver_16mb16c:
921 src_tag = utils::pick(
922 conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
923 dst_tag = utils::pick(
924 conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
925 wei_tag = conf.is_depthwise
926 ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
927 : (conf.with_groups ? utils::pick(conf.ndims - 3,
928 gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
929 : utils::pick(conf.ndims - 3, IOw16i16o,
930 IOhw16i16o, IOdhw16i16o));
931 break;
932 case ver_8ow16c:
933 src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
934 dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
935 wei_tag = conf.is_depthwise
936 ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
937 : (conf.with_groups ? utils::pick(conf.ndims - 3,
938 gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
939 : utils::pick(conf.ndims - 3, IOw16i16o,
940 IOhw16i16o, IOdhw16i16o));
941 break;
942 default: return status::unimplemented;
943 }
944
945 if (src_mdw.format_kind() == format_kind::any) {
946 conf.src_tag = src_tag;
947 } else {
948 conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
949 }
950 if (conf.src_tag != src_tag) return status::unimplemented;
951
952 if (weights_mdw.format_kind() == format_kind::any) {
953 conf.wei_tag = wei_tag;
954 } else {
955 conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
956 }
957 if (conf.wei_tag != wei_tag) return status::unimplemented;
958
959 if (dst_mdw.format_kind() == format_kind::any) {
960 conf.dst_tag = dst_tag;
961 } else {
962 conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
963 }
964 if (conf.dst_tag != dst_tag) return status::unimplemented;
965
966 conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw);
967 conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc);
968
969 bool ok = set_default_formats_common(
970 conf.src_tag, conf.wei_tag, conf.dst_tag);
971 if (!ok) return status::unimplemented;
972 if (is_1stconv && !is_nhwc) {
973 if (data_type::bf16 == conf.weights_data_type) {
974 conf.reorder_wei = true;
975 auto temp_wei_md = *diff_weights_md();
976 temp_wei_md.data_type = data_type::f32;
977
978 primitive_attr_t r_attr(default_attr());
979 if (!r_attr.is_initialized()) return status::out_of_memory;
980
981 CHECK(reorder_primitive_desc_create(rpd_wei_, engine, &temp_wei_md,
982 diff_weights_md(), &r_attr));
983 }
984
985 if (conf.with_bias && data_type::bf16 == conf.bias_data_type) {
986 conf.reorder_bias = true;
987 auto temp_bias_md = *diff_weights_md(1);
988 temp_bias_md.data_type = data_type::f32;
989 primitive_attr_t r_attr(default_attr());
990 if (!r_attr.is_initialized()) return status::out_of_memory;
991
992 CHECK(reorder_primitive_desc_create(rpd_bia_, engine, &temp_bias_md,
993 diff_weights_md(1), &r_attr));
994 }
995 }
996
997 return status::success;
998}
999
1000status_t gen9_convolution_bwd_weights_t::pd_t::init_kernel_ctx(
1001 compute::kernel_ctx_t &kernel_ctx) const {
1002 kernel_ctx.define_int("IS_DW", conf.is_depthwise);
1003 kernel_ctx.define_int("BWD_WEIGHTS", 1);
1004 kernel_ctx.define_int("G", conf.ngroups);
1005 kernel_ctx.define_int("MB", conf.mb);
1006 kernel_ctx.define_int("IC", conf.ic);
1007 kernel_ctx.define_int("ICB", conf.icb);
1008 kernel_ctx.define_int("ID", conf.id);
1009 kernel_ctx.define_int("IH", conf.ih);
1010 kernel_ctx.define_int("IW", conf.iw);
1011 kernel_ctx.define_int("OC", conf.oc);
1012 kernel_ctx.define_int("OCB", conf.ocb);
1013 kernel_ctx.define_int("OD", conf.od);
1014 kernel_ctx.define_int("OH", conf.oh);
1015 kernel_ctx.define_int("OW", conf.ow);
1016 kernel_ctx.define_int("KD", conf.kd);
1017 kernel_ctx.define_int("KH", conf.kh);
1018 kernel_ctx.define_int("KW", conf.kw);
1019 kernel_ctx.define_int("SD", conf.stride_d);
1020 kernel_ctx.define_int("SH", conf.stride_h);
1021 kernel_ctx.define_int("SW", conf.stride_w);
1022 kernel_ctx.define_int("PD", conf.f_pad);
1023 kernel_ctx.define_int("PH", conf.t_pad);
1024 kernel_ctx.define_int("PW", conf.l_pad);
1025 kernel_ctx.define_int("PD_R", conf.back_pad);
1026 kernel_ctx.define_int("PH_R", conf.b_pad);
1027 kernel_ctx.define_int("PW_R", conf.r_pad);
1028 kernel_ctx.define_int("DD", conf.dilate_d);
1029 kernel_ctx.define_int("DH", conf.dilate_h);
1030 kernel_ctx.define_int("DW", conf.dilate_w);
1031 kernel_ctx.define_int("OC_PADDED", conf.oc);
1032 kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
1033 kernel_ctx.define_int("G_WO_PADDING", conf.ngroups_without_padding);
1034
1035 kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
1036 kernel_ctx.define_int("ODB", conf.odb);
1037 kernel_ctx.define_int("OHB", conf.ohb);
1038 kernel_ctx.define_int("OWB", conf.owb);
1039
1040 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
1041 kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
1042 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
1043 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
1044 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
1045 kernel_ctx.define_int("NCHUNK", conf.nchunk);
1046 kernel_ctx.define_int("OSP_CHUNK", conf.osp_chunk);
1047 kernel_ctx.define_int("MB_CHUNK", conf.mb_chunk);
1048 kernel_ctx.define_int(
1049 "MB_CHUNK_SIZE", utils::div_up(conf.mb, conf.mb_chunk));
1050 kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
1051
1052 kernel_ctx.define_int("GWS_0", conf.gws_d[0]);
1053 kernel_ctx.define_int("GWS_1", conf.gws_d[1]);
1054 kernel_ctx.define_int("GWS_2", conf.gws_d[2]);
1055
1056 kernel_ctx.define_int("GWS_ORIG_0", conf.gws_orig_d[0]);
1057 kernel_ctx.define_int("GWS_ORIG_1", conf.gws_orig_d[1]);
1058 kernel_ctx.define_int("GWS_ORIG_2", conf.gws_orig_d[2]);
1059
1060 kernel_ctx.define_int("LWS_0", conf.lws_d[0]);
1061 kernel_ctx.define_int("LWS_1", conf.lws_d[1]);
1062 kernel_ctx.define_int("LWS_2", conf.lws_d[2]);
1063
1064 kernel_ctx.add_option("-cl-std=CL2.0");
1065
1066 kernel_ctx.set_data_type(data_type::f32);
1067 def_data_type(kernel_ctx, src_md()->data_type, "SRC");
1068
1069 def_data_type(kernel_ctx, diff_dst_md()->data_type, "DST");
1070
1071 def_data_type(kernel_ctx,
1072 diff_weights_md(conf.with_bias ? 1 : 0)->data_type, "BIA");
1073
1074 def_data_type(kernel_ctx, data_type::f32, "WEI");
1075
1076 switch (conf.ver) {
1077 case ver_16mb16c: kernel_ctx.define_int("VER_16MB16C", 1); break;
1078 case ver_1stconv:
1079 case ver_8ow16c: kernel_ctx.define_int("VER_8OW16C", 1); break;
1080 default: break;
1081 }
1082
1083 return status::success;
1084}
1085
1086status_t gen9_convolution_bwd_weights_t::pd_t::init_scratchpad() {
1087 auto scratchpad = scratchpad_registry().registrar();
1088 if (!conf.reorder_wei && !conf.reorder_bias) return status::success;
1089 if (conf.reorder_wei) {
1090 auto temp_wei_md = *diff_weights_md();
1091 temp_wei_md.data_type = data_type::f32;
1092 memory_desc_wrapper wei_md_d(temp_wei_md);
1093 scratchpad.book(memory_tracking::names::key_conv_bwd_w_1st_wei_reorder,
1094 wei_md_d.size(), 1, OCL_BUFFER_ALIGNMENT);
1095 scratchpad.book(memory_tracking::names::key_nested_multiple,
1096 rpd_wei_->scratchpad_registry());
1097 }
1098 if (!conf.reorder_bias) return status::success;
1099 auto temp_bias_md = *diff_weights_md(1);
1100 temp_bias_md.data_type = data_type::f32;
1101 memory_desc_wrapper bia_md_d(temp_bias_md);
1102 scratchpad.book(memory_tracking::names::key_conv_bwd_w_1st_bia_reorder,
1103 bia_md_d.size(), 1, OCL_BUFFER_ALIGNMENT);
1104 scratchpad.book(memory_tracking::names::key_nested_multiple + 1,
1105 rpd_bia_->scratchpad_registry());
1106
1107 return status::success;
1108}
1109
1110status_t gen9_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
1111
1112 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
1113 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
1114 auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
1115 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
1116
1117 const auto &conf = pd()->conf;
1118
1119 compute::kernel_arg_list_t arg_list;
1120 arg_list.set(0, src);
1121 arg_list.set(1, weights);
1122 arg_list.set(2, bias);
1123 arg_list.set(3, dst);
1124 append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_);
1125
1126 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
1127
1128 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
1129
1130 if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) {
1131 ctx.zero_pad_output(DNNL_ARG_DST);
1132 }
1133 return status;
1134}
1135
1136status_t gen9_convolution_bwd_weights_t::execute_backward_weights(
1137 const exec_ctx_t &ctx) const {
1138 auto *compute_stream
1139 = utils::downcast<compute::compute_stream_t *>(ctx.stream());
1140
1141 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
1142 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
1143 auto &diff_weights = CTX_OUT_STORAGE(DNNL_ARG_DIFF_WEIGHTS);
1144 auto &diff_bias = CTX_OUT_STORAGE(DNNL_ARG_DIFF_BIAS);
1145
1146 const auto &conf = pd()->conf;
1147
1148 const uint8_t zero = 0;
1149 std::unique_ptr<memory_t> wspace_wei;
1150 std::unique_ptr<memory_t> wspace_bia;
1151 auto temp_wei_md = *pd()->diff_weights_md();
1152 auto temp_bia_md = *pd()->diff_weights_md(1);
1153 std::unique_ptr<memory_storage_t> wspace_ptr_wei;
1154 std::unique_ptr<memory_storage_t> wspace_ptr_bia;
1155 if (conf.reorder_wei) {
1156 wspace_ptr_wei = ctx.get_scratchpad_grantor().get_memory_storage(
1157 memory_tracking::names::key_conv_bwd_w_1st_wei_reorder);
1158
1159 temp_wei_md.data_type = data_type::f32;
1160 }
1161 if (conf.reorder_bias) {
1162 wspace_ptr_bia = ctx.get_scratchpad_grantor().get_memory_storage(
1163 memory_tracking::names::key_conv_bwd_w_1st_bia_reorder);
1164
1165 temp_bia_md.data_type = data_type::f32;
1166 }
1167
1168 memory_desc_wrapper wei_mdw(temp_wei_md);
1169 CHECK(compute_stream->fill(
1170 conf.reorder_wei ? *wspace_ptr_wei : diff_weights, zero,
1171 wei_mdw.size()));
1172 if (conf.with_bias) {
1173 memory_desc_wrapper bia_mdw(temp_bia_md);
1174 CHECK(compute_stream->fill(
1175 conf.reorder_bias ? *wspace_ptr_bia : diff_bias, zero,
1176 bia_mdw.size()));
1177 }
1178
1179 compute::kernel_arg_list_t arg_list;
1180 arg_list.set(0, src);
1181 arg_list.set(1, conf.reorder_wei ? *wspace_ptr_wei : diff_weights);
1182 arg_list.set(2, conf.reorder_bias ? *wspace_ptr_bia : diff_bias);
1183 arg_list.set(3, diff_dst);
1184
1185 status_t status = parallel_for(ctx,
1186 compute::nd_range_t(conf.gws_d, conf.lws_d), kernel_, arg_list);
1187 if (status != status::success) return status;
1188 auto exec_reorder = [&](memory_t *in, memory_t *out,
1189 const std::shared_ptr<primitive_t> &prim,
1190 int r_num) -> status_t {
1191 exec_args_t r_args;
1192 r_args[DNNL_ARG_FROM] = memory_arg_t {in, true};
1193 r_args[DNNL_ARG_TO] = memory_arg_t {out, false};
1194 exec_ctx_t r_ctx(ctx, std::move(r_args));
1195 nested_scratchpad_t ns(
1196 ctx, memory_tracking::names::key_nested_multiple + r_num, prim);
1197 r_ctx.set_scratchpad_grantor(ns.grantor());
1198 return prim->execute(r_ctx);
1199 };
1200
1201 if (conf.reorder_wei) {
1202 CHECK(safe_ptr_assign(wspace_wei,
1203 new memory_t(ctx.stream()->engine(), &temp_wei_md,
1204 std::move(wspace_ptr_wei))));
1205 CHECK(exec_reorder(wspace_wei.get(), ctx.output(DNNL_ARG_DIFF_WEIGHTS),
1206 wei_reorder_, 0));
1207 }
1208 if (conf.reorder_bias) {
1209 CHECK(safe_ptr_assign(wspace_bia,
1210 new memory_t(ctx.stream()->engine(), &temp_bia_md,
1211 std::move(wspace_ptr_bia))));
1212 CHECK(exec_reorder(wspace_bia.get(), ctx.output(DNNL_ARG_DIFF_BIAS),
1213 bia_reorder_, 1));
1214 }
1215
1216 return status::success;
1217}
1218
1219} // namespace ocl
1220} // namespace gpu
1221} // namespace impl
1222} // namespace dnnl
1223
1224// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
1225