1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gen9_convolution.hpp" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_traits.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/reorder.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "gpu/ocl/ocl_memory_storage.hpp" |
25 | #include "gpu/primitive_conf.hpp" |
26 | |
27 | using namespace dnnl::impl::memory_tracking::names; |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace gpu { |
32 | namespace ocl { |
33 | |
34 | using namespace dnnl::impl::data_type; |
35 | using namespace dnnl::impl::format_tag; |
36 | |
37 | static void fwd_compute_block_sizes( |
38 | conv_conf_t &conf, const convolution_pd_t *pd) { |
39 | |
40 | int max_ow_block = (conf.src_data_type == data_type::f16 ? 20 : 16); |
41 | if (conf.ver == ver_16mb16c || conf.ver == ver_32mb16c) { |
42 | max_ow_block = 1; |
43 | } else if (conf.is_depthwise || conf.ver == ver_1stconv) { |
44 | max_ow_block = 8; |
45 | } |
46 | max_ow_block = nstl::min(conf.ow, max_ow_block); |
47 | |
48 | if (conf.ver == ver_16mb16c) { |
49 | conf.mb_block |
50 | = (conf.src_data_type == data_type::f16 && !conf.is_depthwise) |
51 | ? (conf.mb % 32 == 0 ? 32 : 16) |
52 | : 16; |
53 | } else if (conf.ver == ver_32mb16c) { |
54 | conf.mb_block = 32; |
55 | } else { |
56 | conf.mb_block = 1; |
57 | } |
58 | |
59 | conf.ow_block = utils::max_div(conf.ow, max_ow_block); |
60 | |
61 | if (conf.ow_block < max_ow_block / 2) { |
62 | float min_tail_ratio = 1; |
63 | int best_ow_block = -1; |
64 | for (int ow_block = 8; ow_block <= max_ow_block; ow_block++) { |
65 | float tail_ratio |
66 | = (ow_block - (conf.ow % ow_block)) / (float)conf.ow; |
67 | if (tail_ratio <= min_tail_ratio) { |
68 | min_tail_ratio = tail_ratio; |
69 | best_ow_block = ow_block; |
70 | } |
71 | } |
72 | assert(best_ow_block > 0); |
73 | conf.ow_block = best_ow_block; |
74 | } |
75 | |
76 | if (conf.is_depthwise) { |
77 | conf.oc_block = 16; |
78 | conf.ic_block = 16; |
79 | conf.omb = conf.mb_block; |
80 | return; |
81 | } |
82 | |
83 | if (conf.ver == ver_1stconv && conf.mb_block == 1 && conf.oc % 32 == 0) { |
84 | conf.oc_block = 32; |
85 | } else { |
86 | conf.oc_block = 16; |
87 | } |
88 | conf.ic_block = nstl::min(conf.ic, 16); |
89 | |
90 | conf.omb = (conf.mb_block == 1 && conf.mb % 16 == 0) ? 16 : conf.mb_block; |
91 | conf.ocb = utils::max_div(conf.oc / 16, 8) * 16; |
92 | } |
93 | |
// Fills `conf` for the forward convolution described by this primitive
// descriptor: selects the kernel version (conf.ver), the blocking factors,
// the sub-group size, the global/local work sizes and the expected
// source/weights/destination format tags.
//
// Returns status::unimplemented when the data type, channel/group counts or
// the user-provided memory formats do not fit any version handled here, and
// status::success otherwise. `conf` may be partially written on failure.
status_t gen9_convolution_fwd_t::pd_t::init_conf(engine_t *engine) {

    const convolution_desc_t &cd = *desc();
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper weights_mdw(weights_md());
    const memory_desc_wrapper dst_mdw(dst_md());
    // NOTE(review): bias_mdw is not referenced again in this function.
    const memory_desc_wrapper bias_mdw(weights_md(1));

    set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(),
            *weights_md(1), *attr());

    const bool int8_dst = conf.dst_data_type == data_type::s8;
    // True when the user already fixed a plain channels-last layout.
    const bool is_src_nhwc
            = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
    const bool is_dst_nhwc
            = dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef;
    const bool is_nhwc = is_src_nhwc || is_dst_nhwc;

    // "First convolution": exactly 3 (unpadded) input channels.
    const bool is_1stconv = conf.ic_without_padding == 3;
    // Depthwise: grouped, with unpadded IC and OC both equal to 1.
    const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1)
            && (conf.oc_without_padding == 1);

    // For a first convolution only the destination layout decides whether
    // the NHWC version is used; the source may still be NHWC (see the
    // ver_1stconv tag selection below).
    conf.is_nhwc = is_1stconv ? is_dst_nhwc : is_nhwc;
    conf.is_depthwise = is_depthwise;

    // Output-channel rounding granularity: 32 for s8 destination (except
    // first convolution), 16 otherwise.
    const int out_block = int8_dst && !is_1stconv ? 32 : 16;
    if (is_1stconv || (conf.with_groups && conf.ngroups > 1)) {
        // IC is never padded here; OC is padded only for first convolution.
        conf.ic = conf.ic_without_padding;
        conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, out_block)
                             : conf.oc_without_padding;
    } else {
        conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
        conf.oc = utils::rnd_up(conf.oc_without_padding, out_block);
    }

    // Remember the original group count before padding it for depthwise.
    conf.ngroups_without_padding = conf.ngroups;
    if (is_depthwise)
        conf.ngroups = utils::rnd_up(conf.ngroups, int8_dst ? 32 : 16);

    const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0);
    // Note: despite the name, this checks divisibility by out_block (16 or 32).
    const bool is_16oc = conf.oc % out_block == 0;
    const bool is_16ic = conf.ic % 16 == 0;

    // Neutral defaults; overwritten per version below.
    conf.mb_block = 1;
    conf.oc_block = 1;
    conf.ic_block = 1;
    conf.od_block = 1;
    conf.oh_block = 1;
    conf.ow_block = 1;
    conf.omb = 1;
    conf.ocb = 1;
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    const bool is_xe_hp_plus
            = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg();
    const bool has_non_uniform_wg
            = compute_engine->mayiuse_non_uniform_work_groups();

    // Kernel version selection.
    if (conf.is_nhwc) {
        // NHWC version: f32/f16 source only, no s8 destination.
        if (!utils::one_of(src_mdw.data_type(), f32, f16))
            return status::unimplemented;
        if (conf.is_depthwise && conf.ngroups_without_padding % 16)
            return status::unimplemented;
        // TODO: Add group convolution support in NHWC kernel.
        if (!conf.is_depthwise && conf.ngroups > 1 && !(is_16oc && is_16ic)) {
            return status::unimplemented;
        }
        if (int8_dst) { return status::unimplemented; }
        conf.ver = ver_nhwc;
    } else if (is_1stconv) {
        if (!is_16oc) return status::unimplemented;
        conf.ver = ver_1stconv;
    } else if ((is_16oc && is_16ic) || is_dw_16g) {
        // 32-minibatch blocking is used only for depthwise bf16/f16 on
        // XeHP/XeHPG engines when MB is a multiple of 32.
        if (conf.mb % 32 == 0 && conf.is_depthwise
                && utils::one_of(src_mdw.data_type(), bf16, f16)
                && is_xe_hp_plus) {
            conf.ver = ver_32mb16c;
        } else {
            conf.ver = (conf.mb % out_block == 0) ? ver_16mb16c : ver_8ow16c;
        }
    } else {
        return status::unimplemented;
    }

    const bool is_fp16 = src_mdw.data_type() == data_type::f16;

    // Blocking and ND-range per version.
    switch (conf.ver) {
        case ver_nhwc: {
            conf.mb_block = 1;
            conf.oc_block = 16;
            conf.ic_block = is_1stconv ? 1 : 16;

            // Cap for the output-width block: 8 when KW > 1 or when both
            // channel counts are small (<= 64), otherwise 16.
            int max_ow_block = (conf.kw > 1) ? 8 : 16;
            if (conf.oc <= 64 && conf.ic <= 64) max_ow_block = 8;

            conf.ow_block = utils::max_div(conf.ow, max_ow_block);

            if (conf.ow_block <= 8) {
                // Among block sizes 8..max_ow_block-1 pick the one that
                // leaves the largest remainder OW % j (first such j wins).
                int max_tail = 0;
                for (int j = 8; j < max_ow_block; j++) {
                    if (conf.ow % j > max_tail) {
                        max_tail = conf.ow % j;
                        conf.ow_block = j;
                    }
                }
            }
            // The width block is never smaller than 8 in this version.
            if (conf.ow_block <= 8) conf.ow_block = 8;
            if (conf.ow <= 8 || conf.oc <= 32) conf.ow_block = 8;

            conf.oh_block = 1;
            conf.sub_group_size = 16;
            conf.lws_d[0] = 16;
            conf.lws_d[1] = 1;
            conf.lws_d[2] = 1;

            int max_oc_block = 8;
            if (conf.is_depthwise) {
                conf.ocb = conf.ngroups;
            } else {
                // oc_block times the largest divisor (<= 8) of the number of
                // OC blocks.
                conf.ocb = conf.oc_block
                        * utils::max_div(utils::div_up(conf.oc, conf.oc_block),
                                max_oc_block);
            }

            // gws[0]: channels, gws[1]: spatial blocks, gws[2]: minibatch
            // (times OC chunks and groups when not depthwise).
            conf.gws_d[0] = conf.ocb;
            conf.gws_d[1] = utils::div_up(conf.oh, conf.oh_block)
                    * utils::div_up(conf.ow, conf.ow_block) * conf.od;
            if (conf.is_depthwise) {
                conf.gws_d[2] = conf.mb;
            } else {
                conf.gws_d[2] = conf.mb * utils::div_up(conf.oc, conf.ocb)
                        * conf.ngroups;
            }
        } break;
        case ver_1stconv:
        case ver_8ow16c:
        case ver_16mb16c:
        case ver_32mb16c: {
            // Blocked-layout versions share the block-size helper.
            fwd_compute_block_sizes(conf, this);
            conf.sub_group_size = 16;
            conf.gws_d[0] = conf.ngroups * conf.ocb / (conf.oc_block / 16);
            conf.gws_d[1]
                    = (conf.od * conf.oh * utils::div_up(conf.ow, conf.ow_block)
                            * (conf.omb / conf.mb_block));
            conf.gws_d[2] = (conf.oc / conf.ocb) * (conf.mb / conf.omb);
            conf.lws_d[0] = is_xe_hp_plus ? 32 : 16;
            conf.lws_d[1] = 1;
            conf.lws_d[2] = 1;
            break;
        }
        default: return status::unimplemented;
    }

    // Defined elsewhere; presumably adjusts gws/lws when the engine lacks
    // non-uniform work-group support — TODO confirm against its definition.
    maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf);

    // Format tags each version expects; utils::pick indexes by the number of
    // spatial dimensions minus one (ndims - 3).
    format_tag_t src_tag, dst_tag, wei_tag;

    switch (conf.ver) {
        case ver_nhwc:
            src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
            dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
            if (is_1stconv) {
                wei_tag = conf.with_groups ? utils::pick(
                                  conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
                                           : utils::pick(conf.ndims - 3, Owi16o,
                                                   Ohwi16o, Odhwi16o);
            } else if (conf.is_depthwise) {
                wei_tag = utils::pick(
                        conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g);
            } else {
                wei_tag = conf.with_groups
                        ? utils::pick(conf.ndims - 3, gOIw16i16o, gOIhw16i16o,
                                gOIdhw16i16o)
                        : utils::pick(conf.ndims - 3, OIw16i16o, OIhw16i16o,
                                OIdhw16i16o);
            }
            break;
        case ver_1stconv:
            // Source stays plain: channels-last if the user chose it,
            // channels-first otherwise.
            if (is_src_nhwc)
                src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc);
            else
                src_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw);

            // Destination is minibatch-blocked when MB divides evenly
            // (by 32 for f16 on XeHP+, by 16 otherwise).
            if (is_xe_hp_plus && is_fp16) {
                dst_tag = conf.mb % 32 == 0 ? utils::pick(conf.ndims - 3,
                                  NCw32n16c, NChw32n16c, NCdhw32n16c)
                                            : utils::pick(conf.ndims - 3,
                                                    nCw16c, nChw16c, nCdhw16c);
            } else {
                dst_tag = conf.mb % 16 == 0 ? utils::pick(conf.ndims - 3,
                                  NCw16n16c, NChw16n16c, NCdhw16n16c)
                                            : utils::pick(conf.ndims - 3,
                                                    nCw16c, nChw16c, nCdhw16c);
            }
            wei_tag = conf.with_groups
                    ? utils::pick(conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
                    : utils::pick(conf.ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
            break;
        case ver_16mb16c:
            src_tag = utils::pick(
                    conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
            dst_tag = utils::pick(
                    conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c);
            wei_tag = conf.is_depthwise
                    ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
                    : (conf.with_groups ? utils::pick(conf.ndims - 3,
                               gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
                                        : utils::pick(conf.ndims - 3, IOw16i16o,
                                                IOhw16i16o, IOdhw16i16o));
            break;
        case ver_32mb16c:
            src_tag = utils::pick(
                    conf.ndims - 3, NCw32n16c, NChw32n16c, NCdhw32n16c);
            dst_tag = utils::pick(
                    conf.ndims - 3, NCw32n16c, NChw32n16c, NCdhw32n16c);
            wei_tag = conf.is_depthwise
                    ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
                    : (conf.with_groups ? utils::pick(conf.ndims - 3,
                               gIOw16i16o, gIOhw16i16o, gIOdhw16i16o)
                                        : utils::pick(conf.ndims - 3, IOw16i16o,
                                                IOhw16i16o, IOdhw16i16o));
            break;
        case ver_8ow16c:
            src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
            dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c);
            wei_tag = conf.is_depthwise
                    ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g)
                    : (conf.with_groups ? utils::pick(conf.ndims - 3,
                               gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
                                        : utils::pick(conf.ndims - 3, OIw16i16o,
                                                OIhw16i16o, OIdhw16i16o));
            break;
        default: return status::unimplemented;
    }
    // An s8 destination overrides the destination tag chosen above.
    if (int8_dst) {
        // is_1stconv already implies ic_without_padding == 3, so the second
        // condition is always true when the first one is.
        if (is_1stconv && conf.ic_without_padding < 4) {
            dst_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw);
        } else if (conf.ver == ver_16mb16c || conf.ver == ver_32mb16c) {
            dst_tag = utils::pick(
                    conf.ndims - 3, NCw32n32c, NChw32n32c, NCdhw32n32c);
        } else {
            dst_tag = utils::pick(conf.ndims - 3, nCw32c, nChw32c, nCdhw32c);
        }
    }

    // For each tensor: accept the expected tag when the format is `any`,
    // otherwise require the user-provided layout to match it.
    if (src_mdw.format_kind() == format_kind::any) {
        conf.src_tag = src_tag;
    } else {
        conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
    }
    if (conf.src_tag != src_tag) return status::unimplemented;

    if (weights_mdw.format_kind() == format_kind::any) {
        conf.wei_tag = wei_tag;
    } else {
        conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
    }
    if (conf.wei_tag != wei_tag) return status::unimplemented;

    if (dst_mdw.format_kind() == format_kind::any) {
        conf.dst_tag = dst_tag;
    } else {
        conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
    }
    if (conf.dst_tag != dst_tag) return status::unimplemented;

    // Plain-source flags, later exported as SRC_NCHW / SRC_NHWC.
    conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw);
    conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc);

    return status::success;
}
364 | |
365 | status_t gen9_convolution_fwd_t::pd_t::init_kernel_ctx( |
366 | compute::kernel_ctx_t &kernel_ctx) const { |
367 | kernel_ctx.define_int("IS_DW" , conf.is_depthwise); |
368 | kernel_ctx.define_int("G" , conf.ngroups); |
369 | kernel_ctx.define_int("MB" , conf.mb); |
370 | kernel_ctx.define_int("IC" , conf.ic); |
371 | kernel_ctx.define_int("ID" , conf.id); |
372 | kernel_ctx.define_int("IH" , conf.ih); |
373 | kernel_ctx.define_int("IW" , conf.iw); |
374 | kernel_ctx.define_int("OC" , conf.oc); |
375 | kernel_ctx.define_int("OD" , conf.od); |
376 | kernel_ctx.define_int("OH" , conf.oh); |
377 | kernel_ctx.define_int("OW" , conf.ow); |
378 | kernel_ctx.define_int("KD" , conf.kd); |
379 | kernel_ctx.define_int("KH" , conf.kh); |
380 | kernel_ctx.define_int("KW" , conf.kw); |
381 | kernel_ctx.define_int("SD" , conf.stride_d); |
382 | kernel_ctx.define_int("SH" , conf.stride_h); |
383 | kernel_ctx.define_int("SW" , conf.stride_w); |
384 | kernel_ctx.define_int("PD" , conf.f_pad); |
385 | kernel_ctx.define_int("PH" , conf.t_pad); |
386 | kernel_ctx.define_int("PW" , conf.l_pad); |
387 | kernel_ctx.define_int("PD_R" , conf.back_pad); |
388 | kernel_ctx.define_int("PH_R" , conf.b_pad); |
389 | kernel_ctx.define_int("PW_R" , conf.r_pad); |
390 | kernel_ctx.define_int("DD" , conf.dilate_d); |
391 | kernel_ctx.define_int("DH" , conf.dilate_h); |
392 | kernel_ctx.define_int("DW" , conf.dilate_w); |
393 | kernel_ctx.define_int("OW_PADDED" , utils::rnd_up(conf.ow, 4)); |
394 | kernel_ctx.define_int("OC_PADDED" , conf.oc); |
395 | kernel_ctx.define_int("OMB" , conf.omb); |
396 | kernel_ctx.define_int("OCB" , conf.ocb); |
397 | kernel_ctx.define_int("MB_BLOCK" , conf.mb_block); |
398 | kernel_ctx.define_int("OH_BLOCK" , conf.oh_block); |
399 | kernel_ctx.define_int("OW_BLOCK" , conf.ow_block); |
400 | kernel_ctx.define_int("OW_LAST" , utils::rnd_dn(conf.ow, conf.ow_block)); |
401 | kernel_ctx.define_int("OWB" , utils::div_up(conf.ow, conf.ow_block)); |
402 | kernel_ctx.define_int("OHB" , utils::div_up(conf.oh, conf.oh_block)); |
403 | kernel_ctx.define_int("WITH_BIAS" , conf.with_bias); |
404 | kernel_ctx.define_int("SUB_GROUP_SIZE" , conf.sub_group_size); |
405 | kernel_ctx.define_int("OC_BLOCK" , conf.oc_block); |
406 | kernel_ctx.define_int("IC_BLOCK" , conf.ic_block); |
407 | kernel_ctx.define_int("G_WO_PADDING" , conf.ngroups_without_padding); |
408 | kernel_ctx.define_int("IC_WO_PADDING" , conf.ic_without_padding); |
409 | kernel_ctx.define_int("OC_WO_PADDING" , conf.oc_without_padding); |
410 | kernel_ctx.define_int("OC_GROUP" , conf.lws_d[0] / 8); |
411 | kernel_ctx.define_int("MB_GROUP" , 1); |
412 | kernel_ctx.define_int("SP_GROUP" , conf.lws_d[1]); |
413 | if (conf.kw == 1) |
414 | kernel_ctx.define_int("SRC_SP_GROUP" , conf.lws_d[1] + conf.kw - 1); |
415 | else |
416 | kernel_ctx.define_int( |
417 | "SRC_SP_GROUP" , conf.stride_w * (conf.lws_d[1] - 1) + conf.kw); |
418 | |
419 | kernel_ctx.set_data_type(conf.src_data_type); |
420 | def_data_type(kernel_ctx, conf.dst_data_type, "DST" ); |
421 | |
422 | kernel_ctx.define_int("VER_1STCONV" , conf.ver == ver_1stconv); |
423 | kernel_ctx.define_int("VER_8OW16C" , conf.ver == ver_8ow16c); |
424 | kernel_ctx.define_int("VER_16MB16C" , conf.ver == ver_16mb16c); |
425 | kernel_ctx.define_int("VER_32MB16C" , conf.ver == ver_32mb16c); |
426 | |
427 | kernel_ctx.define_int("SRC_NCHW" , conf.is_src_nchw); |
428 | kernel_ctx.define_int("SRC_NHWC" , conf.is_src_nhwc); |
429 | kernel_ctx.define_int("SRC_16N16C" , |
430 | utils::one_of(conf.src_tag, NCw16n16c, NChw16n16c, NCdhw16n16c)); |
431 | kernel_ctx.define_int( |
432 | "SRC_W16C" , utils::one_of(conf.src_tag, nCw16c, nChw16c, nCdhw16c)); |
433 | |
434 | kernel_ctx.define_int("WEI_I16O" , |
435 | utils::one_of(conf.wei_tag, gOwi16o, gOhwi16o, gOdhwi16o, Owi16o, |
436 | Ohwi16o, Odhwi16o)); |
437 | kernel_ctx.define_int("WEI_16I16O" , |
438 | utils::one_of(conf.wei_tag, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o, |
439 | OIw16i16o, OIhw16i16o, OIdhw16i16o)); |
440 | kernel_ctx.define_int("WEI_16I16O_FLIPPED" , |
441 | utils::one_of(conf.wei_tag, gIOw16i16o, gIOhw16i16o, gIOdhw16i16o, |
442 | IOw16i16o, IOhw16i16o, IOdhw16i16o)); |
443 | |
444 | kernel_ctx.define_int( |
445 | "DST_W16C" , utils::one_of(conf.dst_tag, nCw16c, nChw16c, nCdhw16c)); |
446 | kernel_ctx.define_int("DST_16N16C" , |
447 | utils::one_of(conf.dst_tag, NCw16n16c, NChw16n16c, NCdhw16n16c)); |
448 | kernel_ctx.define_int("DST_32N16C" , |
449 | utils::one_of(conf.dst_tag, NCw32n16c, NChw32n16c, NCdhw32n16c)); |
450 | kernel_ctx.define_int("DST_32N32C" , |
451 | utils::one_of(conf.dst_tag, NCw32n32c, NChw32n32c, NCdhw32n32c)); |
452 | kernel_ctx.define_int( |
453 | "DST_W32C" , utils::one_of(conf.dst_tag, nCw32c, nChw32c, nCdhw32c)); |
454 | kernel_ctx.define_int( |
455 | "DST_NCHW" , utils::one_of(conf.dst_tag, ncw, nchw, ncdhw)); |
456 | |
457 | kernel_ctx.define_int("GWS_0" , conf.gws_d[0]); |
458 | kernel_ctx.define_int("GWS_1" , conf.gws_d[1]); |
459 | kernel_ctx.define_int("GWS_2" , conf.gws_d[2]); |
460 | |
461 | kernel_ctx.define_int("GWS_ORIG_0" , conf.gws_orig_d[0]); |
462 | kernel_ctx.define_int("GWS_ORIG_1" , conf.gws_orig_d[1]); |
463 | kernel_ctx.define_int("GWS_ORIG_2" , conf.gws_orig_d[2]); |
464 | |
465 | kernel_ctx.define_int("LWS_0" , conf.lws_d[0]); |
466 | kernel_ctx.define_int("LWS_1" , conf.lws_d[1]); |
467 | kernel_ctx.define_int("LWS_2" , conf.lws_d[2]); |
468 | |
469 | dnnl_dims_t dst_dims; |
470 | dst_dims[0] = conf.mb; |
471 | dst_dims[1] = conf.ngroups_without_padding * conf.oc_without_padding; |
472 | dst_dims[2] = conf.ndims > 4 ? conf.od : conf.oh; |
473 | dst_dims[3] = conf.ndims > 4 ? conf.oh : conf.ow; |
474 | dst_dims[4] = conf.ow; |
475 | kernel_ctx.add_option("-cl-std=CL2.0" ); |
476 | def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_, &dst_dims); |
477 | |
478 | kernel_ctx.print_options(); |
479 | return status::success; |
480 | } |
481 | |
482 | status_t gen9_convolution_bwd_data_t::pd_t::init_conf(engine_t *engine) { |
483 | using namespace dnnl::impl::format_tag; |
484 | using namespace data_type; |
485 | |
486 | const convolution_desc_t &cd = *desc(); |
487 | const memory_desc_wrapper src_mdw(diff_src_md()); |
488 | const memory_desc_wrapper weights_mdw(weights_md()); |
489 | const memory_desc_wrapper dst_mdw(diff_dst_md()); |
490 | const memory_desc_wrapper bias_mdw(weights_md(1)); |
491 | |
492 | set_default_conf(conf, cd, *diff_src_md(), *weights_md(), *diff_dst_md(), |
493 | *weights_md(1), *attr()); |
494 | const bool is_nhwc |
495 | = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef |
496 | || dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) |
497 | != format_tag::undef; |
498 | const bool is_1stconv = conf.ic_without_padding == 3; |
499 | const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1) |
500 | && (conf.oc_without_padding == 1); |
501 | conf.is_nhwc = is_nhwc; |
502 | conf.is_depthwise = is_depthwise; |
503 | |
504 | if (is_nhwc && (is_depthwise || is_1stconv)) return status::unimplemented; |
505 | |
506 | if (is_1stconv || (conf.with_groups && conf.ngroups > 1) || is_depthwise) { |
507 | conf.ic = conf.ic_without_padding; |
508 | conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, 16) |
509 | : conf.oc_without_padding; |
510 | } else { |
511 | conf.ic = utils::rnd_up(conf.ic_without_padding, 16); |
512 | conf.oc = utils::rnd_up(conf.oc_without_padding, 16); |
513 | } |
514 | conf.ngroups_without_padding = conf.ngroups; |
515 | if (is_depthwise) conf.ngroups = utils::rnd_up(conf.ngroups, 16); |
516 | const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0); |
517 | |
518 | const bool is_16ic = conf.ic % 16 == 0; |
519 | const bool is_16oc = conf.oc % 16 == 0; |
520 | const bool use_16mb_unroll = !is_nhwc |
521 | && !(conf.mb == 1 || conf.mb % 16 != 0) && !is_1stconv |
522 | && ((is_16ic && is_16oc) || is_dw_16g); |
523 | conf.mb_block = 1; |
524 | conf.oc_block = 1; |
525 | conf.ic_block = 1; |
526 | conf.od_block = 1; |
527 | conf.oh_block = 1; |
528 | conf.ow_block = 1; |
529 | conf.icb = 1; |
530 | if (is_nhwc) |
531 | conf.ver = ver_nhwc; |
532 | else if (use_16mb_unroll) |
533 | conf.ver = ver_16mb16c; |
534 | else if (conf.mb % 16 != 0 && ((is_16oc && is_16ic) || is_dw_16g)) |
535 | conf.ver = ver_8ow16c; |
536 | else |
537 | return status::unimplemented; |
538 | |
539 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
540 | //TODO: Fix Gtests and reenable |
541 | const bool is_xe_hp_plus |
542 | = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg(); |
543 | const bool has_non_uniform_wg |
544 | = compute_engine->mayiuse_non_uniform_work_groups(); |
545 | |
546 | status_t status = status::success; |
547 | switch (conf.ver) { |
548 | case ver_16mb16c: |
549 | conf.mb_block = 16; |
550 | conf.oc_block = 16; |
551 | conf.ic_block = 16; |
552 | conf.od_block = 1; |
553 | conf.ih_block = 1; |
554 | conf.iw_block = 1; |
555 | conf.sub_group_size = 16; |
556 | if (conf.is_depthwise) { |
557 | conf.icb = conf.ngroups; |
558 | conf.lws_d[0] = 1; |
559 | conf.lws_d[1] = is_xe_hp_plus ? 32 : 16; |
560 | conf.lws_d[2] = 1; |
561 | conf.gws_d[0] = conf.ih * conf.iw * conf.id; |
562 | conf.gws_d[1] = conf.ic * conf.ngroups; |
563 | conf.gws_d[2] = conf.mb / 16; |
564 | } else { |
565 | conf.icb = 64; |
566 | while (conf.icb > 16) { |
567 | if (conf.ic % conf.icb == 0) break; |
568 | conf.icb /= 2; |
569 | } |
570 | conf.lws_d[0] = is_xe_hp_plus ? 32 : 16; |
571 | conf.lws_d[1] = 1; |
572 | conf.lws_d[2] = 1; |
573 | conf.gws_d[0] = conf.icb; |
574 | conf.gws_d[1] = conf.ih * conf.iw * conf.id; |
575 | conf.gws_d[2] |
576 | = conf.mb / 16 * (conf.ic / conf.icb) * conf.ngroups; |
577 | } |
578 | break; |
579 | case ver_8ow16c: |
580 | case ver_nhwc: { |
581 | conf.mb_block = 1; |
582 | conf.oc_block = 16; |
583 | conf.ic_block = 16; |
584 | conf.od_block = 1; |
585 | conf.ih_block = 1; |
586 | int max_iw_block = 16; |
587 | if (conf.ver == ver_nhwc) { max_iw_block = (conf.kw > 1) ? 8 : 16; } |
588 | conf.iw_block = nstl::max(8, utils::max_div(conf.iw, max_iw_block)); |
589 | conf.sub_group_size = 16; |
590 | if (conf.is_depthwise) { |
591 | conf.icb = conf.ngroups; |
592 | conf.lws_d[0] = 1; |
593 | conf.lws_d[1] = conf.ic_block; |
594 | conf.lws_d[2] = 1; |
595 | conf.gws_d[0] = conf.ih * utils::div_up(conf.iw, conf.iw_block) |
596 | * conf.id; |
597 | conf.gws_d[1] = conf.ic * conf.ngroups; |
598 | conf.gws_d[2] = conf.mb; |
599 | } else { |
600 | conf.icb = 64; |
601 | while (conf.icb > conf.ic_block) { |
602 | if (utils::rnd_up(conf.ic, conf.ic_block) % conf.icb == 0) |
603 | break; |
604 | conf.icb /= 2; |
605 | } |
606 | conf.lws_d[0] = conf.ic_block; |
607 | conf.lws_d[1] = 1; |
608 | conf.lws_d[2] = 1; |
609 | conf.gws_d[0] = conf.icb; |
610 | conf.gws_d[1] = conf.ih * utils::div_up(conf.iw, conf.iw_block) |
611 | * conf.id; |
612 | conf.gws_d[2] = conf.mb |
613 | * (utils::rnd_up(conf.ic, conf.ic_block) / conf.icb) |
614 | * conf.ngroups; |
615 | } |
616 | break; |
617 | } |
618 | default: status = status::unimplemented; |
619 | } |
620 | |
621 | maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf); |
622 | |
623 | format_tag_t src_tag, dst_tag, wei_tag; |
624 | |
625 | switch (conf.ver) { |
626 | case ver_nhwc: |
627 | src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc); |
628 | dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc); |
629 | wei_tag = conf.with_groups ? utils::pick(conf.ndims - 3, gOIw16o16i, |
630 | gOIhw16o16i, gOIdhw16o16i) |
631 | : utils::pick(conf.ndims - 3, OIw16o16i, |
632 | OIhw16o16i, OIdhw16o16i); |
633 | break; |
634 | case ver_16mb16c: |
635 | src_tag = utils::pick( |
636 | conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c); |
637 | dst_tag = utils::pick( |
638 | conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c); |
639 | wei_tag = conf.is_depthwise |
640 | ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g) |
641 | : (conf.with_groups ? utils::pick(conf.ndims - 3, |
642 | gOIw16o16i, gOIhw16o16i, gOIdhw16o16i) |
643 | : utils::pick(conf.ndims - 3, OIw16o16i, |
644 | OIhw16o16i, OIdhw16o16i)); |
645 | break; |
646 | case ver_8ow16c: |
647 | src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c); |
648 | dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c); |
649 | wei_tag = conf.is_depthwise |
650 | ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g) |
651 | : (conf.with_groups ? utils::pick(conf.ndims - 3, |
652 | gOIw16o16i, gOIhw16o16i, gOIdhw16o16i) |
653 | : utils::pick(conf.ndims - 3, OIw16o16i, |
654 | OIhw16o16i, OIdhw16o16i)); |
655 | break; |
656 | default: status = status::unimplemented; |
657 | } |
658 | if (status != status::success) return status; |
659 | |
660 | if (src_mdw.format_kind() == format_kind::any) { |
661 | conf.src_tag = src_tag; |
662 | } else { |
663 | conf.src_tag = src_mdw.matches_one_of_tag(src_tag); |
664 | } |
665 | if (conf.src_tag != src_tag) return status::unimplemented; |
666 | |
667 | if (weights_mdw.format_kind() == format_kind::any) { |
668 | conf.wei_tag = wei_tag; |
669 | } else { |
670 | conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag); |
671 | } |
672 | if (conf.wei_tag != wei_tag) return status::unimplemented; |
673 | |
674 | if (dst_mdw.format_kind() == format_kind::any) { |
675 | conf.dst_tag = dst_tag; |
676 | } else { |
677 | conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag); |
678 | } |
679 | if (conf.dst_tag != dst_tag) return status::unimplemented; |
680 | |
681 | conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw); |
682 | conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc); |
683 | |
684 | return status::success; |
685 | } |
686 | |
687 | status_t gen9_convolution_bwd_data_t::pd_t::init_kernel_ctx( |
688 | compute::kernel_ctx_t &kernel_ctx) const { |
689 | kernel_ctx.define_int("IS_DW" , conf.is_depthwise); |
690 | kernel_ctx.define_int("BWD_DATA" , 1); |
691 | kernel_ctx.define_int("G" , conf.ngroups); |
692 | kernel_ctx.define_int("MB" , conf.mb); |
693 | kernel_ctx.define_int("IC" , conf.ic); |
694 | kernel_ctx.define_int("ICB" , conf.icb); |
695 | kernel_ctx.define_int("ID" , conf.id); |
696 | kernel_ctx.define_int("IH" , conf.ih); |
697 | kernel_ctx.define_int("IW" , conf.iw); |
698 | kernel_ctx.define_int("OC" , conf.oc); |
699 | kernel_ctx.define_int("OD" , conf.od); |
700 | kernel_ctx.define_int("OH" , conf.oh); |
701 | kernel_ctx.define_int("OW" , conf.ow); |
702 | kernel_ctx.define_int("KD" , conf.kd); |
703 | kernel_ctx.define_int("KH" , conf.kh); |
704 | kernel_ctx.define_int("KW" , conf.kw); |
705 | kernel_ctx.define_int("SD" , conf.stride_d); |
706 | kernel_ctx.define_int("SH" , conf.stride_h); |
707 | kernel_ctx.define_int("SW" , conf.stride_w); |
708 | kernel_ctx.define_int("PD" , conf.f_pad); |
709 | kernel_ctx.define_int("PH" , conf.t_pad); |
710 | kernel_ctx.define_int("PW" , conf.l_pad); |
711 | kernel_ctx.define_int("PD_R" , conf.back_pad); |
712 | kernel_ctx.define_int("PH_R" , conf.b_pad); |
713 | kernel_ctx.define_int("PW_R" , conf.r_pad); |
714 | kernel_ctx.define_int("DD" , conf.dilate_d); |
715 | kernel_ctx.define_int("DH" , conf.dilate_h); |
716 | kernel_ctx.define_int("DW" , conf.dilate_w); |
717 | kernel_ctx.define_int("OC_PADDED" , utils::rnd_up(conf.oc, conf.oc_block)); |
718 | kernel_ctx.define_int("IC_PADDED" , utils::rnd_up(conf.ic, conf.ic_block)); |
719 | kernel_ctx.define_int("G_WO_PADDING" , conf.ngroups_without_padding); |
720 | kernel_ctx.define_int("OC_WO_PADDING" , conf.oc_without_padding); |
721 | kernel_ctx.define_int("IC_WO_PADDING" , conf.ic_without_padding); |
722 | kernel_ctx.define_int("MB_BLOCK" , conf.mb_block); |
723 | kernel_ctx.define_int("IH_BLOCK" , conf.ih_block); |
724 | kernel_ctx.define_int("IW_BLOCK" , conf.iw_block); |
725 | kernel_ctx.define_int("IWB" , utils::div_up(conf.iw, conf.iw_block)); |
726 | kernel_ctx.define_int("SUB_GROUP_SIZE" , conf.sub_group_size); |
727 | kernel_ctx.define_int("OC_BLOCK" , conf.oc_block); |
728 | kernel_ctx.define_int("IC_BLOCK" , conf.ic_block); |
729 | kernel_ctx.define_int("WITH_BIAS" , conf.with_bias); |
730 | |
731 | kernel_ctx.define_int("GWS_0" , conf.gws_d[0]); |
732 | kernel_ctx.define_int("GWS_1" , conf.gws_d[1]); |
733 | kernel_ctx.define_int("GWS_2" , conf.gws_d[2]); |
734 | |
735 | kernel_ctx.define_int("GWS_ORIG_0" , conf.gws_orig_d[0]); |
736 | kernel_ctx.define_int("GWS_ORIG_1" , conf.gws_orig_d[1]); |
737 | kernel_ctx.define_int("GWS_ORIG_2" , conf.gws_orig_d[2]); |
738 | |
739 | kernel_ctx.define_int("LWS_0" , conf.lws_d[0]); |
740 | kernel_ctx.define_int("LWS_1" , conf.lws_d[1]); |
741 | kernel_ctx.define_int("LWS_2" , conf.lws_d[2]); |
742 | |
743 | kernel_ctx.set_data_type(conf.src_data_type); |
744 | |
745 | switch (conf.ver) { |
746 | case ver_16mb16c: kernel_ctx.define_int("VER_16MB16C" , 1); break; |
747 | case ver_8ow16c: kernel_ctx.define_int("VER_8OW16C" , 1); break; |
748 | default: break; |
749 | } |
750 | |
751 | kernel_ctx.add_option("-cl-std=CL2.0" ); |
752 | |
753 | return status::success; |
754 | } |
755 | |
756 | status_t gen9_convolution_bwd_data_t::execute_backward_data( |
757 | const exec_ctx_t &ctx) const { |
758 | |
759 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
760 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
761 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
762 | auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
763 | |
764 | const auto &conf = pd()->conf; |
765 | |
766 | compute::kernel_arg_list_t arg_list; |
767 | arg_list.set(0, diff_src); |
768 | arg_list.set(1, weights); |
769 | arg_list.set(2, diff_dst); |
770 | arg_list.set(3, bias); |
771 | |
772 | auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d); |
773 | |
774 | status_t status = parallel_for(ctx, nd_range, kernel_, arg_list); |
775 | |
776 | return status; |
777 | } |
778 | |
779 | status_t gen9_convolution_bwd_weights_t::pd_t::init_conf(engine_t *engine) { |
780 | using namespace dnnl::impl::format_tag; |
781 | using namespace data_type; |
782 | |
783 | const convolution_desc_t &cd = *desc(); |
784 | const memory_desc_wrapper src_mdw(src_md()); |
785 | const memory_desc_wrapper weights_mdw(diff_weights_md()); |
786 | const memory_desc_wrapper dst_mdw(diff_dst_md()); |
787 | const memory_desc_wrapper bias_mdw(diff_weights_md(1)); |
788 | |
789 | set_default_conf(conf, cd, *src_md(), *diff_weights_md(), *diff_dst_md(), |
790 | *diff_weights_md(1), *attr()); |
791 | |
792 | const bool is_nhwc |
793 | = src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) != format_tag::undef |
794 | || dst_mdw.matches_one_of_tag(nwc, nhwc, ndhwc) |
795 | != format_tag::undef; |
796 | |
797 | const bool is_1stconv = conf.ic_without_padding == 3; |
798 | const bool is_depthwise = conf.with_groups && (conf.ic_without_padding == 1) |
799 | && (conf.oc_without_padding == 1); |
800 | |
801 | conf.is_nhwc = is_nhwc; |
802 | conf.is_depthwise = is_depthwise; |
803 | |
804 | if (is_1stconv || (conf.with_groups && conf.ngroups > 1) || is_depthwise |
805 | || is_nhwc) { |
806 | conf.ic = conf.ic_without_padding; |
807 | conf.oc = is_1stconv ? utils::rnd_up(conf.oc_without_padding, 16) |
808 | : conf.oc_without_padding; |
809 | } else { |
810 | conf.ic = utils::rnd_up(conf.ic_without_padding, 16); |
811 | conf.oc = utils::rnd_up(conf.oc_without_padding, 16); |
812 | } |
813 | |
814 | conf.ngroups_without_padding = conf.ngroups; |
815 | if (is_depthwise && !is_nhwc) |
816 | conf.ngroups = utils::rnd_up(conf.ngroups, 16); |
817 | const bool is_dw_16g = (conf.is_depthwise && conf.ngroups % 16 == 0); |
818 | |
819 | const bool is_16ic = conf.ic % 16 == 0; |
820 | const bool is_16oc = conf.oc % 16 == 0; |
821 | const bool use_16mb_unroll = !is_nhwc |
822 | && !(conf.mb == 1 || conf.mb % 16 != 0) && !is_1stconv |
823 | && ((is_16ic && is_16oc) || is_dw_16g); |
824 | |
825 | conf.mb_block = 1; |
826 | conf.oc_block = 1; |
827 | conf.ic_block = 1; |
828 | conf.od_block = 1; |
829 | conf.oh_block = 1; |
830 | conf.ow_block = 1; |
831 | conf.osp_chunk = 1; |
832 | conf.mb_chunk = 1; |
833 | if (is_nhwc) |
834 | conf.ver = ver_nhwc; |
835 | else if (use_16mb_unroll) |
836 | conf.ver = ver_16mb16c; |
837 | else if (conf.mb % 16 != 0 && ((is_16oc && is_16ic) || is_dw_16g)) |
838 | conf.ver = ver_8ow16c; |
839 | else if (is_1stconv && is_16oc) |
840 | conf.ver = ver_1stconv; |
841 | else |
842 | return status::unimplemented; |
843 | |
844 | switch (conf.ver) { |
845 | case ver_1stconv: |
846 | case ver_8ow16c: |
847 | case ver_nhwc: |
848 | conf.mb_block = 1; |
849 | conf.oc_block = 16; |
850 | conf.ic_block = is_1stconv ? 1 : 16; |
851 | conf.ow_block = 8; |
852 | break; |
853 | case ver_16mb16c: |
854 | conf.mb_block = 16; |
855 | conf.oc_block = 16; |
856 | conf.ic_block = 16; |
857 | conf.ow_block = 1; |
858 | break; |
859 | } |
860 | |
861 | bwd_w_compute_block_sizes(conf, engine); |
862 | |
863 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
864 | |
865 | //TODO: Fix Gtests and reenable |
866 | const bool is_xe_hp_plus |
867 | = compute_engine->is_xe_hp() || compute_engine->is_xe_hpg(); |
868 | const bool has_non_uniform_wg |
869 | = compute_engine->mayiuse_non_uniform_work_groups(); |
870 | |
871 | conf.sub_group_size = 16; |
872 | conf.lws_d[0] = is_xe_hp_plus ? 32 : 16; |
873 | conf.lws_d[1] = 1; |
874 | conf.lws_d[2] = 1; |
875 | |
876 | if (conf.is_depthwise) { |
877 | conf.gws_d[0] = utils::rnd_up(conf.ngroups, 16); |
878 | } else { |
879 | conf.gws_d[0] = is_1stconv ? conf.ocb * conf.ngroups |
880 | : conf.ocb * (conf.icb / 16) * conf.ngroups; |
881 | } |
882 | conf.gws_d[1] = is_1stconv && !is_nhwc |
883 | ? utils::div_up(conf.kh * conf.kw * conf.kd * conf.ic, 16) |
884 | : conf.kh * conf.kw * conf.kd; |
885 | conf.gws_d[2] = conf.nchunk * utils::div_up(conf.ic, conf.icb) |
886 | * utils::div_up(conf.oc, conf.ocb); |
887 | |
888 | maybe_fix_non_uniform_work_sizes(has_non_uniform_wg, conf); |
889 | |
890 | format_tag_t src_tag, dst_tag, wei_tag; |
891 | |
892 | switch (conf.ver) { |
893 | case ver_nhwc: |
894 | src_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc); |
895 | dst_tag = utils::pick(conf.ndims - 3, nwc, nhwc, ndhwc); |
896 | if (is_1stconv) { |
897 | wei_tag = conf.with_groups ? utils::pick( |
898 | conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) |
899 | : utils::pick(conf.ndims - 3, Owi16o, |
900 | Ohwi16o, Odhwi16o); |
901 | } else if (conf.is_depthwise) { |
902 | wei_tag = utils::pick( |
903 | conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g); |
904 | } else { |
905 | wei_tag = conf.with_groups |
906 | ? utils::pick(conf.ndims - 3, gIOw16i16o, gIOhw16i16o, |
907 | gIOdhw16i16o) |
908 | : utils::pick(conf.ndims - 3, IOw16i16o, IOhw16i16o, |
909 | IOdhw16i16o); |
910 | } |
911 | break; |
912 | case ver_1stconv: |
913 | assert(!conf.is_depthwise); |
914 | src_tag = utils::pick(conf.ndims - 3, ncw, nchw, ncdhw); |
915 | dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c); |
916 | wei_tag = conf.with_groups |
917 | ? utils::pick(conf.ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) |
918 | : utils::pick(conf.ndims - 3, Owi16o, Ohwi16o, Odhwi16o); |
919 | break; |
920 | case ver_16mb16c: |
921 | src_tag = utils::pick( |
922 | conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c); |
923 | dst_tag = utils::pick( |
924 | conf.ndims - 3, NCw16n16c, NChw16n16c, NCdhw16n16c); |
925 | wei_tag = conf.is_depthwise |
926 | ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g) |
927 | : (conf.with_groups ? utils::pick(conf.ndims - 3, |
928 | gIOw16i16o, gIOhw16i16o, gIOdhw16i16o) |
929 | : utils::pick(conf.ndims - 3, IOw16i16o, |
930 | IOhw16i16o, IOdhw16i16o)); |
931 | break; |
932 | case ver_8ow16c: |
933 | src_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c); |
934 | dst_tag = utils::pick(conf.ndims - 3, nCw16c, nChw16c, nCdhw16c); |
935 | wei_tag = conf.is_depthwise |
936 | ? utils::pick(conf.ndims - 3, Goiw16g, Goihw16g, Goidhw16g) |
937 | : (conf.with_groups ? utils::pick(conf.ndims - 3, |
938 | gIOw16i16o, gIOhw16i16o, gIOdhw16i16o) |
939 | : utils::pick(conf.ndims - 3, IOw16i16o, |
940 | IOhw16i16o, IOdhw16i16o)); |
941 | break; |
942 | default: return status::unimplemented; |
943 | } |
944 | |
945 | if (src_mdw.format_kind() == format_kind::any) { |
946 | conf.src_tag = src_tag; |
947 | } else { |
948 | conf.src_tag = src_mdw.matches_one_of_tag(src_tag); |
949 | } |
950 | if (conf.src_tag != src_tag) return status::unimplemented; |
951 | |
952 | if (weights_mdw.format_kind() == format_kind::any) { |
953 | conf.wei_tag = wei_tag; |
954 | } else { |
955 | conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag); |
956 | } |
957 | if (conf.wei_tag != wei_tag) return status::unimplemented; |
958 | |
959 | if (dst_mdw.format_kind() == format_kind::any) { |
960 | conf.dst_tag = dst_tag; |
961 | } else { |
962 | conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag); |
963 | } |
964 | if (conf.dst_tag != dst_tag) return status::unimplemented; |
965 | |
966 | conf.is_src_nchw = utils::one_of(src_tag, ncw, nchw, ncdhw); |
967 | conf.is_src_nhwc = utils::one_of(src_tag, nwc, nhwc, ndhwc); |
968 | |
969 | bool ok = set_default_formats_common( |
970 | conf.src_tag, conf.wei_tag, conf.dst_tag); |
971 | if (!ok) return status::unimplemented; |
972 | if (is_1stconv && !is_nhwc) { |
973 | if (data_type::bf16 == conf.weights_data_type) { |
974 | conf.reorder_wei = true; |
975 | auto temp_wei_md = *diff_weights_md(); |
976 | temp_wei_md.data_type = data_type::f32; |
977 | |
978 | primitive_attr_t r_attr(default_attr()); |
979 | if (!r_attr.is_initialized()) return status::out_of_memory; |
980 | |
981 | CHECK(reorder_primitive_desc_create(rpd_wei_, engine, &temp_wei_md, |
982 | diff_weights_md(), &r_attr)); |
983 | } |
984 | |
985 | if (conf.with_bias && data_type::bf16 == conf.bias_data_type) { |
986 | conf.reorder_bias = true; |
987 | auto temp_bias_md = *diff_weights_md(1); |
988 | temp_bias_md.data_type = data_type::f32; |
989 | primitive_attr_t r_attr(default_attr()); |
990 | if (!r_attr.is_initialized()) return status::out_of_memory; |
991 | |
992 | CHECK(reorder_primitive_desc_create(rpd_bia_, engine, &temp_bias_md, |
993 | diff_weights_md(1), &r_attr)); |
994 | } |
995 | } |
996 | |
997 | return status::success; |
998 | } |
999 | |
1000 | status_t gen9_convolution_bwd_weights_t::pd_t::init_kernel_ctx( |
1001 | compute::kernel_ctx_t &kernel_ctx) const { |
1002 | kernel_ctx.define_int("IS_DW" , conf.is_depthwise); |
1003 | kernel_ctx.define_int("BWD_WEIGHTS" , 1); |
1004 | kernel_ctx.define_int("G" , conf.ngroups); |
1005 | kernel_ctx.define_int("MB" , conf.mb); |
1006 | kernel_ctx.define_int("IC" , conf.ic); |
1007 | kernel_ctx.define_int("ICB" , conf.icb); |
1008 | kernel_ctx.define_int("ID" , conf.id); |
1009 | kernel_ctx.define_int("IH" , conf.ih); |
1010 | kernel_ctx.define_int("IW" , conf.iw); |
1011 | kernel_ctx.define_int("OC" , conf.oc); |
1012 | kernel_ctx.define_int("OCB" , conf.ocb); |
1013 | kernel_ctx.define_int("OD" , conf.od); |
1014 | kernel_ctx.define_int("OH" , conf.oh); |
1015 | kernel_ctx.define_int("OW" , conf.ow); |
1016 | kernel_ctx.define_int("KD" , conf.kd); |
1017 | kernel_ctx.define_int("KH" , conf.kh); |
1018 | kernel_ctx.define_int("KW" , conf.kw); |
1019 | kernel_ctx.define_int("SD" , conf.stride_d); |
1020 | kernel_ctx.define_int("SH" , conf.stride_h); |
1021 | kernel_ctx.define_int("SW" , conf.stride_w); |
1022 | kernel_ctx.define_int("PD" , conf.f_pad); |
1023 | kernel_ctx.define_int("PH" , conf.t_pad); |
1024 | kernel_ctx.define_int("PW" , conf.l_pad); |
1025 | kernel_ctx.define_int("PD_R" , conf.back_pad); |
1026 | kernel_ctx.define_int("PH_R" , conf.b_pad); |
1027 | kernel_ctx.define_int("PW_R" , conf.r_pad); |
1028 | kernel_ctx.define_int("DD" , conf.dilate_d); |
1029 | kernel_ctx.define_int("DH" , conf.dilate_h); |
1030 | kernel_ctx.define_int("DW" , conf.dilate_w); |
1031 | kernel_ctx.define_int("OC_PADDED" , conf.oc); |
1032 | kernel_ctx.define_int("OC_WO_PADDING" , conf.oc_without_padding); |
1033 | kernel_ctx.define_int("G_WO_PADDING" , conf.ngroups_without_padding); |
1034 | |
1035 | kernel_ctx.define_int("OW_BLOCK" , conf.ow_block); |
1036 | kernel_ctx.define_int("ODB" , conf.odb); |
1037 | kernel_ctx.define_int("OHB" , conf.ohb); |
1038 | kernel_ctx.define_int("OWB" , conf.owb); |
1039 | |
1040 | kernel_ctx.define_int("WITH_BIAS" , conf.with_bias); |
1041 | kernel_ctx.define_int("SUB_GROUP_SIZE" , conf.sub_group_size); |
1042 | kernel_ctx.define_int("MB_BLOCK" , conf.mb_block); |
1043 | kernel_ctx.define_int("OC_BLOCK" , conf.oc_block); |
1044 | kernel_ctx.define_int("IC_BLOCK" , conf.ic_block); |
1045 | kernel_ctx.define_int("NCHUNK" , conf.nchunk); |
1046 | kernel_ctx.define_int("OSP_CHUNK" , conf.osp_chunk); |
1047 | kernel_ctx.define_int("MB_CHUNK" , conf.mb_chunk); |
1048 | kernel_ctx.define_int( |
1049 | "MB_CHUNK_SIZE" , utils::div_up(conf.mb, conf.mb_chunk)); |
1050 | kernel_ctx.define_int("OW_BLOCK" , conf.ow_block); |
1051 | |
1052 | kernel_ctx.define_int("GWS_0" , conf.gws_d[0]); |
1053 | kernel_ctx.define_int("GWS_1" , conf.gws_d[1]); |
1054 | kernel_ctx.define_int("GWS_2" , conf.gws_d[2]); |
1055 | |
1056 | kernel_ctx.define_int("GWS_ORIG_0" , conf.gws_orig_d[0]); |
1057 | kernel_ctx.define_int("GWS_ORIG_1" , conf.gws_orig_d[1]); |
1058 | kernel_ctx.define_int("GWS_ORIG_2" , conf.gws_orig_d[2]); |
1059 | |
1060 | kernel_ctx.define_int("LWS_0" , conf.lws_d[0]); |
1061 | kernel_ctx.define_int("LWS_1" , conf.lws_d[1]); |
1062 | kernel_ctx.define_int("LWS_2" , conf.lws_d[2]); |
1063 | |
1064 | kernel_ctx.add_option("-cl-std=CL2.0" ); |
1065 | |
1066 | kernel_ctx.set_data_type(data_type::f32); |
1067 | def_data_type(kernel_ctx, src_md()->data_type, "SRC" ); |
1068 | |
1069 | def_data_type(kernel_ctx, diff_dst_md()->data_type, "DST" ); |
1070 | |
1071 | def_data_type(kernel_ctx, |
1072 | diff_weights_md(conf.with_bias ? 1 : 0)->data_type, "BIA" ); |
1073 | |
1074 | def_data_type(kernel_ctx, data_type::f32, "WEI" ); |
1075 | |
1076 | switch (conf.ver) { |
1077 | case ver_16mb16c: kernel_ctx.define_int("VER_16MB16C" , 1); break; |
1078 | case ver_1stconv: |
1079 | case ver_8ow16c: kernel_ctx.define_int("VER_8OW16C" , 1); break; |
1080 | default: break; |
1081 | } |
1082 | |
1083 | return status::success; |
1084 | } |
1085 | |
1086 | status_t gen9_convolution_bwd_weights_t::pd_t::init_scratchpad() { |
1087 | auto scratchpad = scratchpad_registry().registrar(); |
1088 | if (!conf.reorder_wei && !conf.reorder_bias) return status::success; |
1089 | if (conf.reorder_wei) { |
1090 | auto temp_wei_md = *diff_weights_md(); |
1091 | temp_wei_md.data_type = data_type::f32; |
1092 | memory_desc_wrapper wei_md_d(temp_wei_md); |
1093 | scratchpad.book(memory_tracking::names::key_conv_bwd_w_1st_wei_reorder, |
1094 | wei_md_d.size(), 1, OCL_BUFFER_ALIGNMENT); |
1095 | scratchpad.book(memory_tracking::names::key_nested_multiple, |
1096 | rpd_wei_->scratchpad_registry()); |
1097 | } |
1098 | if (!conf.reorder_bias) return status::success; |
1099 | auto temp_bias_md = *diff_weights_md(1); |
1100 | temp_bias_md.data_type = data_type::f32; |
1101 | memory_desc_wrapper bia_md_d(temp_bias_md); |
1102 | scratchpad.book(memory_tracking::names::key_conv_bwd_w_1st_bia_reorder, |
1103 | bia_md_d.size(), 1, OCL_BUFFER_ALIGNMENT); |
1104 | scratchpad.book(memory_tracking::names::key_nested_multiple + 1, |
1105 | rpd_bia_->scratchpad_registry()); |
1106 | |
1107 | return status::success; |
1108 | } |
1109 | |
1110 | status_t gen9_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
1111 | |
1112 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
1113 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
1114 | auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
1115 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
1116 | |
1117 | const auto &conf = pd()->conf; |
1118 | |
1119 | compute::kernel_arg_list_t arg_list; |
1120 | arg_list.set(0, src); |
1121 | arg_list.set(1, weights); |
1122 | arg_list.set(2, bias); |
1123 | arg_list.set(3, dst); |
1124 | append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_); |
1125 | |
1126 | auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d); |
1127 | |
1128 | status_t status = parallel_for(ctx, nd_range, kernel_, arg_list); |
1129 | |
1130 | if (!post_ops_preserves_zeroes(ctx, pd()->attr()->post_ops_)) { |
1131 | ctx.zero_pad_output(DNNL_ARG_DST); |
1132 | } |
1133 | return status; |
1134 | } |
1135 | |
1136 | status_t gen9_convolution_bwd_weights_t::execute_backward_weights( |
1137 | const exec_ctx_t &ctx) const { |
1138 | auto *compute_stream |
1139 | = utils::downcast<compute::compute_stream_t *>(ctx.stream()); |
1140 | |
1141 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
1142 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
1143 | auto &diff_weights = CTX_OUT_STORAGE(DNNL_ARG_DIFF_WEIGHTS); |
1144 | auto &diff_bias = CTX_OUT_STORAGE(DNNL_ARG_DIFF_BIAS); |
1145 | |
1146 | const auto &conf = pd()->conf; |
1147 | |
1148 | const uint8_t zero = 0; |
1149 | std::unique_ptr<memory_t> wspace_wei; |
1150 | std::unique_ptr<memory_t> wspace_bia; |
1151 | auto temp_wei_md = *pd()->diff_weights_md(); |
1152 | auto temp_bia_md = *pd()->diff_weights_md(1); |
1153 | std::unique_ptr<memory_storage_t> wspace_ptr_wei; |
1154 | std::unique_ptr<memory_storage_t> wspace_ptr_bia; |
1155 | if (conf.reorder_wei) { |
1156 | wspace_ptr_wei = ctx.get_scratchpad_grantor().get_memory_storage( |
1157 | memory_tracking::names::key_conv_bwd_w_1st_wei_reorder); |
1158 | |
1159 | temp_wei_md.data_type = data_type::f32; |
1160 | } |
1161 | if (conf.reorder_bias) { |
1162 | wspace_ptr_bia = ctx.get_scratchpad_grantor().get_memory_storage( |
1163 | memory_tracking::names::key_conv_bwd_w_1st_bia_reorder); |
1164 | |
1165 | temp_bia_md.data_type = data_type::f32; |
1166 | } |
1167 | |
1168 | memory_desc_wrapper wei_mdw(temp_wei_md); |
1169 | CHECK(compute_stream->fill( |
1170 | conf.reorder_wei ? *wspace_ptr_wei : diff_weights, zero, |
1171 | wei_mdw.size())); |
1172 | if (conf.with_bias) { |
1173 | memory_desc_wrapper bia_mdw(temp_bia_md); |
1174 | CHECK(compute_stream->fill( |
1175 | conf.reorder_bias ? *wspace_ptr_bia : diff_bias, zero, |
1176 | bia_mdw.size())); |
1177 | } |
1178 | |
1179 | compute::kernel_arg_list_t arg_list; |
1180 | arg_list.set(0, src); |
1181 | arg_list.set(1, conf.reorder_wei ? *wspace_ptr_wei : diff_weights); |
1182 | arg_list.set(2, conf.reorder_bias ? *wspace_ptr_bia : diff_bias); |
1183 | arg_list.set(3, diff_dst); |
1184 | |
1185 | status_t status = parallel_for(ctx, |
1186 | compute::nd_range_t(conf.gws_d, conf.lws_d), kernel_, arg_list); |
1187 | if (status != status::success) return status; |
1188 | auto exec_reorder = [&](memory_t *in, memory_t *out, |
1189 | const std::shared_ptr<primitive_t> &prim, |
1190 | int r_num) -> status_t { |
1191 | exec_args_t r_args; |
1192 | r_args[DNNL_ARG_FROM] = memory_arg_t {in, true}; |
1193 | r_args[DNNL_ARG_TO] = memory_arg_t {out, false}; |
1194 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
1195 | nested_scratchpad_t ns( |
1196 | ctx, memory_tracking::names::key_nested_multiple + r_num, prim); |
1197 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
1198 | return prim->execute(r_ctx); |
1199 | }; |
1200 | |
1201 | if (conf.reorder_wei) { |
1202 | CHECK(safe_ptr_assign(wspace_wei, |
1203 | new memory_t(ctx.stream()->engine(), &temp_wei_md, |
1204 | std::move(wspace_ptr_wei)))); |
1205 | CHECK(exec_reorder(wspace_wei.get(), ctx.output(DNNL_ARG_DIFF_WEIGHTS), |
1206 | wei_reorder_, 0)); |
1207 | } |
1208 | if (conf.reorder_bias) { |
1209 | CHECK(safe_ptr_assign(wspace_bia, |
1210 | new memory_t(ctx.stream()->engine(), &temp_bia_md, |
1211 | std::move(wspace_ptr_bia)))); |
1212 | CHECK(exec_reorder(wspace_bia.get(), ctx.output(DNNL_ARG_DIFF_BIAS), |
1213 | bia_reorder_, 1)); |
1214 | } |
1215 | |
1216 | return status::success; |
1217 | } |
1218 | |
1219 | } // namespace ocl |
1220 | } // namespace gpu |
1221 | } // namespace impl |
1222 | } // namespace dnnl |
1223 | |
1224 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1225 | |