1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/gen9_wino_convolution.hpp"
18
19#include "common/c_types_map.hpp"
20#include "common/dnnl_traits.hpp"
21#include "common/math_utils.hpp"
22#include "common/type_helpers.hpp"
23#include "gpu/compute/device_info.hpp"
24#include "gpu/ocl/ocl_memory_storage.hpp"
25
26using namespace dnnl::impl::memory_tracking::names;
27
28namespace dnnl {
29namespace impl {
30namespace gpu {
31namespace ocl {
32
33using namespace dnnl::impl::data_type;
34using namespace dnnl::impl::format_tag;
35
36static bool is_impl_optimal(conv_conf_t &conf, const convolution_desc_t &cd,
37 const compute::gpu_arch_t arch) {
38 if (cd.alg_kind == alg_kind::convolution_winograd) return true;
39
40 int ow_blocks = conf.wino_ow / conf.ow_block;
41 float ow_util = (float)conf.ow / conf.wino_ow;
42 int oh_blocks = conf.wino_oh / conf.oh_block;
43 float oh_util = (float)conf.oh / conf.wino_oh;
44 int oc_blocks = conf.ocb;
45 float oc_util = (float)conf.oc_without_padding / conf.wino_oc;
46 float ic_util = (float)conf.ic_without_padding / conf.wino_ic;
47
48 int blocks = ow_blocks * oh_blocks * oc_blocks;
49 float utilization = ow_util * oh_util * oc_util * ic_util;
50 float score;
51
52 switch (arch) {
53 case compute::gpu_arch_t::gen9:
54 case compute::gpu_arch_t::gen11:
55 score = blocks * utilization;
56 if (score >= 128 && utilization >= 0.50) return true;
57 return false;
58 case compute::gpu_arch_t::xe_lp:
59 // Performance is poor with large oc*ic and small spatial, this is
60 // likely due to overflowing cache and no blocking on ic.
61 score = (float)conf.oc * conf.ic / (oh_blocks * ow_blocks);
62 if (score < 32 * 1024 && utilization >= 0.50) return true;
63 return false;
64 default: return false;
65 }
66}
67
68static void fwd_compute_block_sizes(
69 conv_conf_t &conf, const compute::gpu_arch_t arch) {
70
71 if (conf.ver == ver_16mb16c) {
72 conf.mb_block = (conf.src_data_type == data_type::f16)
73 ? (conf.mb % 32 == 0 ? 32 : 16)
74 : 16;
75 } else {
76 conf.mb_block = 1;
77 }
78
79 //Using F(m, r) for r = 3 and tile_size = m + r - 1
80 const int m = utils::div_up(conf.oh, 6) < utils::div_up(conf.oh, 4)
81 ? 6
82 : conf.oh > 2 ? 4 : 2;
83 const int r = 3;
84 conf.is_fused = true;
85
86 conf.wino_m = m;
87 conf.wino_r = r;
88 conf.tile_size = m + r - 1;
89
90 const bool is_pre_gen12 = utils::one_of(
91 arch, compute::gpu_arch_t::gen9, compute::gpu_arch_t::gen11);
92
93 conf.vect_size = is_pre_gen12
94 ? static_cast<int>(16 / types::data_type_size(conf.src_data_type))
95 : 8;
96 conf.oc_block = 16;
97 conf.ic_block = nstl::min(conf.ic, 16);
98 if (conf.src_data_type == data_type::f16)
99 conf.wino_ic_block = 32;
100 else if (is_pre_gen12 && conf.ow * conf.oh <= 256)
101 conf.wino_ic_block = 32;
102 else if (arch >= compute::gpu_arch_t::xe_hpc)
103 // XeHPC does not support subgroup size 8;
104 conf.wino_ic_block = 32;
105 else
106 conf.wino_ic_block = 16;
107
108 conf.ocb = utils::div_up(conf.oc, conf.oc_block);
109
110 if (conf.is_fused) {
111 conf.wino_oc_block = 16;
112 conf.oh_block = conf.wino_m;
113 conf.ow_block = conf.ow > 14 ? 14 : utils::rnd_up(conf.ow, 2);
114 } else {
115 conf.wino_oc_block = 32;
116 conf.oh_block = 8;
117 conf.ow_block = conf.wino_m;
118 }
119
120 // Used for the internal data transform
121 conf.wino_ow = utils::rnd_up(conf.ow, conf.ow_block);
122 conf.wino_iw = conf.wino_ow;
123 conf.wino_oh = utils::rnd_up(conf.oh, conf.oh_block);
124 conf.wino_ih = conf.wino_oh + conf.t_pad + conf.b_pad;
125 conf.wino_ic = utils::rnd_up(conf.ic, conf.wino_ic_block);
126 conf.wino_oc = utils::rnd_up(conf.oc, conf.wino_oc_block);
127}
128
129status_t gen9_wino_convolution_fwd_t::pd_t::init_conf(
130 compute::compute_engine_t *engine) {
131
132 const convolution_desc_t &cd = *desc();
133 const memory_desc_wrapper src_mdw(src_md());
134 const memory_desc_wrapper weights_mdw(weights_md());
135 const memory_desc_wrapper dst_mdw(dst_md());
136 const memory_desc_wrapper bias_mdw(weights_md(1));
137
138 set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(),
139 *weights_md(1), *attr());
140
141 conf.ic = utils::rnd_up(conf.ic_without_padding, 16);
142 conf.oc = utils::rnd_up(conf.oc_without_padding, 16);
143
144 const bool is_wino_shape = conf.ndims == 4 && conf.kh == 3 && conf.kw == 3
145 && conf.ngroups == 1 && conf.stride_h == 1 && conf.stride_w == 1
146 && conf.dilate_h == 0 && conf.dilate_w == 0 && conf.l_pad < conf.kw
147 && conf.r_pad < conf.kw && conf.t_pad < conf.kh
148 && conf.b_pad < conf.kh;
149 if (!is_wino_shape) return status::unimplemented;
150
151 const bool is_16oc = conf.oc % 16 == 0;
152 const bool is_16ic = conf.ic % 16 == 0;
153
154 if (src_mdw.matches_one_of_tag(nhwc)
155 && (dst_mdw.matches_one_of_tag(nhwc)
156 || dst_mdw.format_kind() == format_kind::any)) {
157 // Technically this implementation currently requires ic is a multiple
158 // of VTRANS_BLOCK = 4. This condition was not implemented yet due to no
159 // known use case, and small IC is expected to have poor performance
160 // because of extra work created by the current blocking.
161 if (conf.ic_without_padding % 16 != 0
162 || conf.oc_without_padding % 16 != 0)
163 return status::unimplemented;
164 conf.ver = ver_nhwc;
165 } else if ((is_16oc && is_16ic)) {
166 conf.ver = (conf.mb % 16 == 0) ? ver_16mb16c : ver_8ow16c;
167 } else {
168 return status::unimplemented;
169 }
170
171 const compute::gpu_arch_t arch = engine->device_info()->gpu_arch();
172 fwd_compute_block_sizes(conf, arch);
173 if (!is_impl_optimal(conf, cd, arch)) return status::unimplemented;
174
175 size_t U_sz = conf.tile_size * conf.kh * conf.wino_ic * conf.wino_oc;
176 size_t M_sz = 0, V_sz = 0;
177 if (!conf.is_fused) {
178 M_sz = conf.tile_size * conf.mb * conf.wino_oc * conf.wino_oh
179 * conf.wino_ow;
180 V_sz = conf.tile_size * conf.mb * conf.wino_ic * conf.wino_ih
181 * conf.wino_iw;
182 }
183
184 // Limit max problem size since this method uses more memory
185 if (U_sz + M_sz + V_sz > 300000000) return status::unimplemented;
186
187 //Using F(m, r) for r = 3 and tile_size = m + r - 1
188 if (!conf.is_fused) {
189 conf.mb_block = 1;
190 conf.lws_d[0] = 8;
191 conf.lws_d[1] = 1;
192 conf.lws_d[2] = 1;
193 conf.gws_d[0] = (conf.wino_oc / conf.wino_oc_block) * conf.lws_d[0];
194 conf.gws_d[1] = conf.wino_ow * (conf.wino_oh / conf.oh_block);
195 conf.gws_d[2] = (conf.mb / conf.mb_block) * conf.tile_size;
196
197 conf.U_lws_d[0] = 1;
198 conf.U_lws_d[1] = 1;
199 conf.U_lws_d[2] = 1;
200 conf.U_gws_d[0] = 1;
201 conf.U_gws_d[1] = 3; // kh or kw depending
202 conf.U_gws_d[2] = conf.wino_ic * conf.wino_oc;
203
204 conf.V_lws_d[0] = 1;
205 conf.V_lws_d[1] = 1;
206 conf.V_lws_d[2] = 1;
207 conf.V_gws_d[0] = conf.wino_ow;
208 conf.V_gws_d[1] = conf.wino_ih;
209 conf.V_gws_d[2] = conf.wino_ic / conf.ic_block * conf.mb;
210
211 conf.M_lws_d[0] = 1;
212 conf.M_lws_d[1] = 1;
213 conf.M_lws_d[2] = 1;
214 conf.M_gws_d[0] = utils::div_up(conf.ow, conf.ow_block);
215 conf.M_gws_d[1] = conf.oh;
216 conf.M_gws_d[2] = conf.oc / conf.oc_block * conf.mb;
217 } else {
218 conf.mb_block = 1;
219 conf.lws_d[0] = conf.wino_ic_block / 2;
220 conf.lws_d[1] = 8;
221 conf.lws_d[2] = 1;
222 conf.gws_d[0]
223 = utils::div_up(conf.wino_ow, conf.ow_block) * conf.lws_d[0];
224 conf.gws_d[1]
225 = utils::div_up(conf.wino_oh, conf.oh_block) * conf.lws_d[1];
226 conf.gws_d[2] = (conf.mb / conf.mb_block)
227 * (conf.wino_oc / conf.wino_oc_block);
228
229 conf.U_lws_d[0] = conf.wino_ic_block / 2;
230 conf.U_lws_d[1] = 1;
231 conf.U_lws_d[2] = 1;
232 conf.U_gws_d[0] = conf.wino_ic * conf.wino_oc / conf.vect_size;
233 conf.U_gws_d[1] = 3;
234 conf.U_gws_d[2] = 1; // kh or kw depending
235 }
236
237 format_tag_t src_tag, dst_tag, wei_tag;
238
239 switch (conf.ver) {
240 case ver_16mb16c:
241 src_tag = NChw16n16c;
242 dst_tag = NChw16n16c;
243 wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
244 break;
245 case ver_8ow16c:
246 src_tag = nChw16c;
247 dst_tag = nChw16c;
248 wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
249 break;
250 case ver_nhwc:
251 src_tag = nhwc;
252 dst_tag = nhwc;
253 wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o;
254 break;
255 default: return status::unimplemented;
256 }
257
258 if (src_mdw.format_kind() == format_kind::any) {
259 conf.src_tag = src_tag;
260 } else {
261 conf.src_tag = src_mdw.matches_one_of_tag(src_tag);
262 }
263 if (conf.src_tag != src_tag) return status::unimplemented;
264
265 if (weights_mdw.format_kind() == format_kind::any) {
266 conf.wei_tag = wei_tag;
267 } else {
268 conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag);
269 }
270 if (conf.wei_tag != wei_tag) return status::unimplemented;
271
272 if (dst_mdw.format_kind() == format_kind::any) {
273 conf.dst_tag = dst_tag;
274 } else {
275 conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag);
276 }
277 if (conf.dst_tag != dst_tag) return status::unimplemented;
278
279 return status::success;
280}
281
282void gen9_wino_convolution_fwd_t::pd_t::init_scratchpad() {
283 auto scratchpad = scratchpad_registry().registrar();
284
285 auto wei_data_t = this->desc()->weights_desc.data_type;
286 size_t U_sz = conf.tile_size * conf.kh * conf.wino_ic * conf.wino_oc;
287 scratchpad.book(key_wino_U, U_sz, types::data_type_size(wei_data_t),
288 OCL_BUFFER_ALIGNMENT);
289
290 if (!conf.is_fused) {
291 auto dst_data_t = this->desc()->dst_desc.data_type;
292 size_t M_sz = conf.tile_size * conf.mb * conf.wino_oc * conf.wino_oh
293 * conf.wino_ow;
294 scratchpad.book(key_wino_M, M_sz, types::data_type_size(dst_data_t),
295 OCL_BUFFER_ALIGNMENT);
296
297 auto src_data_t = this->desc()->src_desc.data_type;
298 size_t V_sz = conf.tile_size * conf.mb * conf.wino_ic * conf.wino_ih
299 * conf.wino_iw;
300 scratchpad.book(key_wino_V, V_sz, types::data_type_size(src_data_t),
301 OCL_BUFFER_ALIGNMENT);
302 }
303}
304
305status_t gen9_wino_convolution_fwd_t::pd_t::init_kernel_ctx(
306 compute::kernel_ctx_t &kernel_ctx) const {
307 kernel_ctx.define_int("G", conf.ngroups);
308 kernel_ctx.define_int("MB", conf.mb);
309 kernel_ctx.define_int("IC", conf.ic);
310 kernel_ctx.define_int("ID", conf.id);
311 kernel_ctx.define_int("IH", conf.ih);
312 kernel_ctx.define_int("IW", conf.iw);
313 kernel_ctx.define_int("OC", conf.oc);
314 kernel_ctx.define_int("OD", conf.od);
315 kernel_ctx.define_int("OH", conf.oh);
316 kernel_ctx.define_int("OW", conf.ow);
317 kernel_ctx.define_int("KD", conf.kd);
318 kernel_ctx.define_int("KH", conf.kh);
319 kernel_ctx.define_int("KW", conf.kw);
320 kernel_ctx.define_int("PH", conf.t_pad);
321 kernel_ctx.define_int("PW", conf.l_pad);
322 kernel_ctx.define_int("OCB", conf.ocb);
323 kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
324 kernel_ctx.define_int("OH_BLOCK", conf.oh_block);
325 kernel_ctx.define_int("OW_BLOCK", conf.ow_block);
326 kernel_ctx.define_int("OW_LAST", utils::rnd_dn(conf.ow, conf.ow_block));
327 kernel_ctx.define_int("OWB", utils::div_up(conf.ow, conf.ow_block));
328 kernel_ctx.define_int("OHB", utils::div_up(conf.oh, conf.oh_block));
329 kernel_ctx.define_int("OC_WO_PADDING", conf.oc_without_padding);
330 kernel_ctx.define_int("WINO_M", conf.wino_m);
331 kernel_ctx.define_int("WINO_R", conf.wino_r);
332 kernel_ctx.define_int("WINO_IC_BLOCK", conf.wino_ic_block);
333 kernel_ctx.define_int("WINO_OC_BLOCK", conf.wino_oc_block);
334 kernel_ctx.define_int("WINO_IC", conf.wino_ic);
335 kernel_ctx.define_int("WINO_OC", conf.wino_oc);
336 kernel_ctx.define_int("WINO_IH", conf.wino_ih);
337 kernel_ctx.define_int("WINO_IW", conf.wino_iw);
338 kernel_ctx.define_int("WINO_OH", conf.wino_oh);
339 kernel_ctx.define_int("WINO_OW", conf.wino_ow);
340 kernel_ctx.define_int("OC_BLOCK", conf.oc_block);
341 kernel_ctx.define_int("IC_BLOCK", conf.ic_block);
342 kernel_ctx.define_int("VECT_DT_N", conf.vect_size);
343
344 kernel_ctx.set_data_type(conf.src_data_type);
345
346 kernel_ctx.define_int("VER_8OW16C", conf.ver == ver_8ow16c);
347 kernel_ctx.define_int("VER_16MB16C", conf.ver == ver_16mb16c);
348
349 kernel_ctx.define_int("SRC_NHWC", utils::one_of(conf.src_tag, nhwc));
350 kernel_ctx.define_int(
351 "SRC_16N16C", utils::one_of(conf.src_tag, NChw16n16c));
352 kernel_ctx.define_int("SRC_W16C", utils::one_of(conf.src_tag, nChw16c));
353
354 kernel_ctx.define_int(
355 "WEI_16I16O", utils::one_of(conf.wei_tag, gOIhw16i16o, OIhw16i16o));
356 kernel_ctx.define_int("WEI_16I16O_FLIPPED",
357 utils::one_of(conf.wei_tag, gIOhw16i16o, IOhw16i16o));
358
359 kernel_ctx.define_int("DST_NHWC", utils::one_of(conf.src_tag, nhwc));
360 kernel_ctx.define_int(
361 "DST_16N16C", utils::one_of(conf.dst_tag, NChw16n16c));
362 kernel_ctx.define_int("DST_W16C", utils::one_of(conf.dst_tag, nChw16c));
363
364 kernel_ctx.define_int("WITH_BIAS", conf.with_bias);
365
366 dnnl_dims_t dst_dims;
367 dst_dims[0] = conf.mb;
368 dst_dims[1] = conf.oc_without_padding;
369 dst_dims[2] = conf.ndims > 4 ? conf.od : conf.oh;
370 dst_dims[3] = conf.ndims > 4 ? conf.oh : conf.ow;
371 dst_dims[4] = conf.ow;
372 def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_, &dst_dims);
373
374 kernel_ctx.print_options();
375 return status::success;
376}
377
378status_t gen9_wino_convolution_fwd_t::execute_forward(
379 const exec_ctx_t &ctx) const {
380
381 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
382 auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
383 auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);
384 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
385
386 const auto &conf = pd()->conf;
387 const auto &attr_info = conf.attr_info;
388
389 std::unique_ptr<memory_storage_t> wei_trans
390 = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_U);
391 compute::kernel_arg_list_t wei_transform_args;
392 wei_transform_args.set(0, *wei_trans);
393 wei_transform_args.set(1, weights);
394 auto wei_trans_nd_range = compute::nd_range_t(conf.U_gws_d, conf.U_lws_d);
395 status_t status = parallel_for(
396 ctx, wei_trans_nd_range, wei_trans_kernel_, wei_transform_args);
397
398 if (conf.is_fused) {
399 compute::kernel_arg_list_t arg_list;
400 arg_list.set(0, dst);
401 arg_list.set(1, src);
402 arg_list.set(2, *wei_trans);
403 arg_list.set(3, bias);
404 append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_);
405 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
406 status = parallel_for(ctx, nd_range, kernel_, arg_list);
407 } else {
408 std::unique_ptr<memory_storage_t> src_trans
409 = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_V);
410 compute::kernel_arg_list_t src_transform_args;
411 src_transform_args.set(0, *src_trans);
412 src_transform_args.set(1, src);
413 auto src_trans_nd_range
414 = compute::nd_range_t(conf.V_gws_d, conf.V_lws_d);
415 status = parallel_for(
416 ctx, src_trans_nd_range, src_trans_kernel_, src_transform_args);
417
418 std::unique_ptr<memory_storage_t> M_buf
419 = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_M);
420 compute::kernel_arg_list_t arg_list;
421 arg_list.set(0, *M_buf);
422 arg_list.set(1, *src_trans);
423 arg_list.set(2, *wei_trans);
424 auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);
425 status = parallel_for(ctx, nd_range, kernel_, arg_list);
426
427 compute::kernel_arg_list_t dst_transform_args;
428 dst_transform_args.set(0, dst);
429 dst_transform_args.set(1, *M_buf);
430 dst_transform_args.set(2, bias);
431 append_post_ops_to_arg_list(
432 ctx, dst_transform_args, 3, pd()->attr()->post_ops_);
433 auto dst_trans_nd_range
434 = compute::nd_range_t(conf.M_gws_d, conf.M_lws_d);
435 status = parallel_for(
436 ctx, dst_trans_nd_range, dst_trans_kernel_, dst_transform_args);
437 }
438
439 if (attr_info.with_eltwise
440 && !gpu_eltwise_fwd_pd_t::eltwise_preserves_zero(
441 attr_info.eltwise_alg, attr_info.eltwise_alpha,
442 attr_info.eltwise_beta)) {
443 ctx.zero_pad_output(DNNL_ARG_DST);
444 }
445 return status;
446}
447} // namespace ocl
448} // namespace gpu
449} // namespace impl
450} // namespace dnnl
451
452// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
453