1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gen9_wino_convolution.hpp" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_traits.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "gpu/compute/device_info.hpp" |
24 | #include "gpu/ocl/ocl_memory_storage.hpp" |
25 | |
26 | using namespace dnnl::impl::memory_tracking::names; |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace gpu { |
31 | namespace ocl { |
32 | |
33 | using namespace dnnl::impl::data_type; |
34 | using namespace dnnl::impl::format_tag; |
35 | |
36 | static bool is_impl_optimal(conv_conf_t &conf, const convolution_desc_t &cd, |
37 | const compute::gpu_arch_t arch) { |
38 | if (cd.alg_kind == alg_kind::convolution_winograd) return true; |
39 | |
40 | int ow_blocks = conf.wino_ow / conf.ow_block; |
41 | float ow_util = (float)conf.ow / conf.wino_ow; |
42 | int oh_blocks = conf.wino_oh / conf.oh_block; |
43 | float oh_util = (float)conf.oh / conf.wino_oh; |
44 | int oc_blocks = conf.ocb; |
45 | float oc_util = (float)conf.oc_without_padding / conf.wino_oc; |
46 | float ic_util = (float)conf.ic_without_padding / conf.wino_ic; |
47 | |
48 | int blocks = ow_blocks * oh_blocks * oc_blocks; |
49 | float utilization = ow_util * oh_util * oc_util * ic_util; |
50 | float score; |
51 | |
52 | switch (arch) { |
53 | case compute::gpu_arch_t::gen9: |
54 | case compute::gpu_arch_t::gen11: |
55 | score = blocks * utilization; |
56 | if (score >= 128 && utilization >= 0.50) return true; |
57 | return false; |
58 | case compute::gpu_arch_t::xe_lp: |
59 | // Performance is poor with large oc*ic and small spatial, this is |
60 | // likely due to overflowing cache and no blocking on ic. |
61 | score = (float)conf.oc * conf.ic / (oh_blocks * ow_blocks); |
62 | if (score < 32 * 1024 && utilization >= 0.50) return true; |
63 | return false; |
64 | default: return false; |
65 | } |
66 | } |
67 | |
68 | static void fwd_compute_block_sizes( |
69 | conv_conf_t &conf, const compute::gpu_arch_t arch) { |
70 | |
71 | if (conf.ver == ver_16mb16c) { |
72 | conf.mb_block = (conf.src_data_type == data_type::f16) |
73 | ? (conf.mb % 32 == 0 ? 32 : 16) |
74 | : 16; |
75 | } else { |
76 | conf.mb_block = 1; |
77 | } |
78 | |
79 | //Using F(m, r) for r = 3 and tile_size = m + r - 1 |
80 | const int m = utils::div_up(conf.oh, 6) < utils::div_up(conf.oh, 4) |
81 | ? 6 |
82 | : conf.oh > 2 ? 4 : 2; |
83 | const int r = 3; |
84 | conf.is_fused = true; |
85 | |
86 | conf.wino_m = m; |
87 | conf.wino_r = r; |
88 | conf.tile_size = m + r - 1; |
89 | |
90 | const bool is_pre_gen12 = utils::one_of( |
91 | arch, compute::gpu_arch_t::gen9, compute::gpu_arch_t::gen11); |
92 | |
93 | conf.vect_size = is_pre_gen12 |
94 | ? static_cast<int>(16 / types::data_type_size(conf.src_data_type)) |
95 | : 8; |
96 | conf.oc_block = 16; |
97 | conf.ic_block = nstl::min(conf.ic, 16); |
98 | if (conf.src_data_type == data_type::f16) |
99 | conf.wino_ic_block = 32; |
100 | else if (is_pre_gen12 && conf.ow * conf.oh <= 256) |
101 | conf.wino_ic_block = 32; |
102 | else if (arch >= compute::gpu_arch_t::xe_hpc) |
103 | // XeHPC does not support subgroup size 8; |
104 | conf.wino_ic_block = 32; |
105 | else |
106 | conf.wino_ic_block = 16; |
107 | |
108 | conf.ocb = utils::div_up(conf.oc, conf.oc_block); |
109 | |
110 | if (conf.is_fused) { |
111 | conf.wino_oc_block = 16; |
112 | conf.oh_block = conf.wino_m; |
113 | conf.ow_block = conf.ow > 14 ? 14 : utils::rnd_up(conf.ow, 2); |
114 | } else { |
115 | conf.wino_oc_block = 32; |
116 | conf.oh_block = 8; |
117 | conf.ow_block = conf.wino_m; |
118 | } |
119 | |
120 | // Used for the internal data transform |
121 | conf.wino_ow = utils::rnd_up(conf.ow, conf.ow_block); |
122 | conf.wino_iw = conf.wino_ow; |
123 | conf.wino_oh = utils::rnd_up(conf.oh, conf.oh_block); |
124 | conf.wino_ih = conf.wino_oh + conf.t_pad + conf.b_pad; |
125 | conf.wino_ic = utils::rnd_up(conf.ic, conf.wino_ic_block); |
126 | conf.wino_oc = utils::rnd_up(conf.oc, conf.wino_oc_block); |
127 | } |
128 | |
129 | status_t gen9_wino_convolution_fwd_t::pd_t::init_conf( |
130 | compute::compute_engine_t *engine) { |
131 | |
132 | const convolution_desc_t &cd = *desc(); |
133 | const memory_desc_wrapper src_mdw(src_md()); |
134 | const memory_desc_wrapper weights_mdw(weights_md()); |
135 | const memory_desc_wrapper dst_mdw(dst_md()); |
136 | const memory_desc_wrapper bias_mdw(weights_md(1)); |
137 | |
138 | set_default_conf(conf, cd, *src_md(), *weights_md(), *dst_md(), |
139 | *weights_md(1), *attr()); |
140 | |
141 | conf.ic = utils::rnd_up(conf.ic_without_padding, 16); |
142 | conf.oc = utils::rnd_up(conf.oc_without_padding, 16); |
143 | |
144 | const bool is_wino_shape = conf.ndims == 4 && conf.kh == 3 && conf.kw == 3 |
145 | && conf.ngroups == 1 && conf.stride_h == 1 && conf.stride_w == 1 |
146 | && conf.dilate_h == 0 && conf.dilate_w == 0 && conf.l_pad < conf.kw |
147 | && conf.r_pad < conf.kw && conf.t_pad < conf.kh |
148 | && conf.b_pad < conf.kh; |
149 | if (!is_wino_shape) return status::unimplemented; |
150 | |
151 | const bool is_16oc = conf.oc % 16 == 0; |
152 | const bool is_16ic = conf.ic % 16 == 0; |
153 | |
154 | if (src_mdw.matches_one_of_tag(nhwc) |
155 | && (dst_mdw.matches_one_of_tag(nhwc) |
156 | || dst_mdw.format_kind() == format_kind::any)) { |
157 | // Technically this implementation currently requires ic is a multiple |
158 | // of VTRANS_BLOCK = 4. This condition was not implemented yet due to no |
159 | // known use case, and small IC is expected to have poor performance |
160 | // because of extra work created by the current blocking. |
161 | if (conf.ic_without_padding % 16 != 0 |
162 | || conf.oc_without_padding % 16 != 0) |
163 | return status::unimplemented; |
164 | conf.ver = ver_nhwc; |
165 | } else if ((is_16oc && is_16ic)) { |
166 | conf.ver = (conf.mb % 16 == 0) ? ver_16mb16c : ver_8ow16c; |
167 | } else { |
168 | return status::unimplemented; |
169 | } |
170 | |
171 | const compute::gpu_arch_t arch = engine->device_info()->gpu_arch(); |
172 | fwd_compute_block_sizes(conf, arch); |
173 | if (!is_impl_optimal(conf, cd, arch)) return status::unimplemented; |
174 | |
175 | size_t U_sz = conf.tile_size * conf.kh * conf.wino_ic * conf.wino_oc; |
176 | size_t M_sz = 0, V_sz = 0; |
177 | if (!conf.is_fused) { |
178 | M_sz = conf.tile_size * conf.mb * conf.wino_oc * conf.wino_oh |
179 | * conf.wino_ow; |
180 | V_sz = conf.tile_size * conf.mb * conf.wino_ic * conf.wino_ih |
181 | * conf.wino_iw; |
182 | } |
183 | |
184 | // Limit max problem size since this method uses more memory |
185 | if (U_sz + M_sz + V_sz > 300000000) return status::unimplemented; |
186 | |
187 | //Using F(m, r) for r = 3 and tile_size = m + r - 1 |
188 | if (!conf.is_fused) { |
189 | conf.mb_block = 1; |
190 | conf.lws_d[0] = 8; |
191 | conf.lws_d[1] = 1; |
192 | conf.lws_d[2] = 1; |
193 | conf.gws_d[0] = (conf.wino_oc / conf.wino_oc_block) * conf.lws_d[0]; |
194 | conf.gws_d[1] = conf.wino_ow * (conf.wino_oh / conf.oh_block); |
195 | conf.gws_d[2] = (conf.mb / conf.mb_block) * conf.tile_size; |
196 | |
197 | conf.U_lws_d[0] = 1; |
198 | conf.U_lws_d[1] = 1; |
199 | conf.U_lws_d[2] = 1; |
200 | conf.U_gws_d[0] = 1; |
201 | conf.U_gws_d[1] = 3; // kh or kw depending |
202 | conf.U_gws_d[2] = conf.wino_ic * conf.wino_oc; |
203 | |
204 | conf.V_lws_d[0] = 1; |
205 | conf.V_lws_d[1] = 1; |
206 | conf.V_lws_d[2] = 1; |
207 | conf.V_gws_d[0] = conf.wino_ow; |
208 | conf.V_gws_d[1] = conf.wino_ih; |
209 | conf.V_gws_d[2] = conf.wino_ic / conf.ic_block * conf.mb; |
210 | |
211 | conf.M_lws_d[0] = 1; |
212 | conf.M_lws_d[1] = 1; |
213 | conf.M_lws_d[2] = 1; |
214 | conf.M_gws_d[0] = utils::div_up(conf.ow, conf.ow_block); |
215 | conf.M_gws_d[1] = conf.oh; |
216 | conf.M_gws_d[2] = conf.oc / conf.oc_block * conf.mb; |
217 | } else { |
218 | conf.mb_block = 1; |
219 | conf.lws_d[0] = conf.wino_ic_block / 2; |
220 | conf.lws_d[1] = 8; |
221 | conf.lws_d[2] = 1; |
222 | conf.gws_d[0] |
223 | = utils::div_up(conf.wino_ow, conf.ow_block) * conf.lws_d[0]; |
224 | conf.gws_d[1] |
225 | = utils::div_up(conf.wino_oh, conf.oh_block) * conf.lws_d[1]; |
226 | conf.gws_d[2] = (conf.mb / conf.mb_block) |
227 | * (conf.wino_oc / conf.wino_oc_block); |
228 | |
229 | conf.U_lws_d[0] = conf.wino_ic_block / 2; |
230 | conf.U_lws_d[1] = 1; |
231 | conf.U_lws_d[2] = 1; |
232 | conf.U_gws_d[0] = conf.wino_ic * conf.wino_oc / conf.vect_size; |
233 | conf.U_gws_d[1] = 3; |
234 | conf.U_gws_d[2] = 1; // kh or kw depending |
235 | } |
236 | |
237 | format_tag_t src_tag, dst_tag, wei_tag; |
238 | |
239 | switch (conf.ver) { |
240 | case ver_16mb16c: |
241 | src_tag = NChw16n16c; |
242 | dst_tag = NChw16n16c; |
243 | wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o; |
244 | break; |
245 | case ver_8ow16c: |
246 | src_tag = nChw16c; |
247 | dst_tag = nChw16c; |
248 | wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o; |
249 | break; |
250 | case ver_nhwc: |
251 | src_tag = nhwc; |
252 | dst_tag = nhwc; |
253 | wei_tag = conf.with_groups ? gOIhw16i16o : OIhw16i16o; |
254 | break; |
255 | default: return status::unimplemented; |
256 | } |
257 | |
258 | if (src_mdw.format_kind() == format_kind::any) { |
259 | conf.src_tag = src_tag; |
260 | } else { |
261 | conf.src_tag = src_mdw.matches_one_of_tag(src_tag); |
262 | } |
263 | if (conf.src_tag != src_tag) return status::unimplemented; |
264 | |
265 | if (weights_mdw.format_kind() == format_kind::any) { |
266 | conf.wei_tag = wei_tag; |
267 | } else { |
268 | conf.wei_tag = weights_mdw.matches_one_of_tag(wei_tag); |
269 | } |
270 | if (conf.wei_tag != wei_tag) return status::unimplemented; |
271 | |
272 | if (dst_mdw.format_kind() == format_kind::any) { |
273 | conf.dst_tag = dst_tag; |
274 | } else { |
275 | conf.dst_tag = dst_mdw.matches_one_of_tag(dst_tag); |
276 | } |
277 | if (conf.dst_tag != dst_tag) return status::unimplemented; |
278 | |
279 | return status::success; |
280 | } |
281 | |
282 | void gen9_wino_convolution_fwd_t::pd_t::init_scratchpad() { |
283 | auto scratchpad = scratchpad_registry().registrar(); |
284 | |
285 | auto wei_data_t = this->desc()->weights_desc.data_type; |
286 | size_t U_sz = conf.tile_size * conf.kh * conf.wino_ic * conf.wino_oc; |
287 | scratchpad.book(key_wino_U, U_sz, types::data_type_size(wei_data_t), |
288 | OCL_BUFFER_ALIGNMENT); |
289 | |
290 | if (!conf.is_fused) { |
291 | auto dst_data_t = this->desc()->dst_desc.data_type; |
292 | size_t M_sz = conf.tile_size * conf.mb * conf.wino_oc * conf.wino_oh |
293 | * conf.wino_ow; |
294 | scratchpad.book(key_wino_M, M_sz, types::data_type_size(dst_data_t), |
295 | OCL_BUFFER_ALIGNMENT); |
296 | |
297 | auto src_data_t = this->desc()->src_desc.data_type; |
298 | size_t V_sz = conf.tile_size * conf.mb * conf.wino_ic * conf.wino_ih |
299 | * conf.wino_iw; |
300 | scratchpad.book(key_wino_V, V_sz, types::data_type_size(src_data_t), |
301 | OCL_BUFFER_ALIGNMENT); |
302 | } |
303 | } |
304 | |
305 | status_t gen9_wino_convolution_fwd_t::pd_t::init_kernel_ctx( |
306 | compute::kernel_ctx_t &kernel_ctx) const { |
307 | kernel_ctx.define_int("G" , conf.ngroups); |
308 | kernel_ctx.define_int("MB" , conf.mb); |
309 | kernel_ctx.define_int("IC" , conf.ic); |
310 | kernel_ctx.define_int("ID" , conf.id); |
311 | kernel_ctx.define_int("IH" , conf.ih); |
312 | kernel_ctx.define_int("IW" , conf.iw); |
313 | kernel_ctx.define_int("OC" , conf.oc); |
314 | kernel_ctx.define_int("OD" , conf.od); |
315 | kernel_ctx.define_int("OH" , conf.oh); |
316 | kernel_ctx.define_int("OW" , conf.ow); |
317 | kernel_ctx.define_int("KD" , conf.kd); |
318 | kernel_ctx.define_int("KH" , conf.kh); |
319 | kernel_ctx.define_int("KW" , conf.kw); |
320 | kernel_ctx.define_int("PH" , conf.t_pad); |
321 | kernel_ctx.define_int("PW" , conf.l_pad); |
322 | kernel_ctx.define_int("OCB" , conf.ocb); |
323 | kernel_ctx.define_int("MB_BLOCK" , conf.mb_block); |
324 | kernel_ctx.define_int("OH_BLOCK" , conf.oh_block); |
325 | kernel_ctx.define_int("OW_BLOCK" , conf.ow_block); |
326 | kernel_ctx.define_int("OW_LAST" , utils::rnd_dn(conf.ow, conf.ow_block)); |
327 | kernel_ctx.define_int("OWB" , utils::div_up(conf.ow, conf.ow_block)); |
328 | kernel_ctx.define_int("OHB" , utils::div_up(conf.oh, conf.oh_block)); |
329 | kernel_ctx.define_int("OC_WO_PADDING" , conf.oc_without_padding); |
330 | kernel_ctx.define_int("WINO_M" , conf.wino_m); |
331 | kernel_ctx.define_int("WINO_R" , conf.wino_r); |
332 | kernel_ctx.define_int("WINO_IC_BLOCK" , conf.wino_ic_block); |
333 | kernel_ctx.define_int("WINO_OC_BLOCK" , conf.wino_oc_block); |
334 | kernel_ctx.define_int("WINO_IC" , conf.wino_ic); |
335 | kernel_ctx.define_int("WINO_OC" , conf.wino_oc); |
336 | kernel_ctx.define_int("WINO_IH" , conf.wino_ih); |
337 | kernel_ctx.define_int("WINO_IW" , conf.wino_iw); |
338 | kernel_ctx.define_int("WINO_OH" , conf.wino_oh); |
339 | kernel_ctx.define_int("WINO_OW" , conf.wino_ow); |
340 | kernel_ctx.define_int("OC_BLOCK" , conf.oc_block); |
341 | kernel_ctx.define_int("IC_BLOCK" , conf.ic_block); |
342 | kernel_ctx.define_int("VECT_DT_N" , conf.vect_size); |
343 | |
344 | kernel_ctx.set_data_type(conf.src_data_type); |
345 | |
346 | kernel_ctx.define_int("VER_8OW16C" , conf.ver == ver_8ow16c); |
347 | kernel_ctx.define_int("VER_16MB16C" , conf.ver == ver_16mb16c); |
348 | |
349 | kernel_ctx.define_int("SRC_NHWC" , utils::one_of(conf.src_tag, nhwc)); |
350 | kernel_ctx.define_int( |
351 | "SRC_16N16C" , utils::one_of(conf.src_tag, NChw16n16c)); |
352 | kernel_ctx.define_int("SRC_W16C" , utils::one_of(conf.src_tag, nChw16c)); |
353 | |
354 | kernel_ctx.define_int( |
355 | "WEI_16I16O" , utils::one_of(conf.wei_tag, gOIhw16i16o, OIhw16i16o)); |
356 | kernel_ctx.define_int("WEI_16I16O_FLIPPED" , |
357 | utils::one_of(conf.wei_tag, gIOhw16i16o, IOhw16i16o)); |
358 | |
359 | kernel_ctx.define_int("DST_NHWC" , utils::one_of(conf.src_tag, nhwc)); |
360 | kernel_ctx.define_int( |
361 | "DST_16N16C" , utils::one_of(conf.dst_tag, NChw16n16c)); |
362 | kernel_ctx.define_int("DST_W16C" , utils::one_of(conf.dst_tag, nChw16c)); |
363 | |
364 | kernel_ctx.define_int("WITH_BIAS" , conf.with_bias); |
365 | |
366 | dnnl_dims_t dst_dims; |
367 | dst_dims[0] = conf.mb; |
368 | dst_dims[1] = conf.oc_without_padding; |
369 | dst_dims[2] = conf.ndims > 4 ? conf.od : conf.oh; |
370 | dst_dims[3] = conf.ndims > 4 ? conf.oh : conf.ow; |
371 | dst_dims[4] = conf.ow; |
372 | def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_, &dst_dims); |
373 | |
374 | kernel_ctx.print_options(); |
375 | return status::success; |
376 | } |
377 | |
378 | status_t gen9_wino_convolution_fwd_t::execute_forward( |
379 | const exec_ctx_t &ctx) const { |
380 | |
381 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
382 | auto &weights = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
383 | auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
384 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
385 | |
386 | const auto &conf = pd()->conf; |
387 | const auto &attr_info = conf.attr_info; |
388 | |
389 | std::unique_ptr<memory_storage_t> wei_trans |
390 | = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_U); |
391 | compute::kernel_arg_list_t wei_transform_args; |
392 | wei_transform_args.set(0, *wei_trans); |
393 | wei_transform_args.set(1, weights); |
394 | auto wei_trans_nd_range = compute::nd_range_t(conf.U_gws_d, conf.U_lws_d); |
395 | status_t status = parallel_for( |
396 | ctx, wei_trans_nd_range, wei_trans_kernel_, wei_transform_args); |
397 | |
398 | if (conf.is_fused) { |
399 | compute::kernel_arg_list_t arg_list; |
400 | arg_list.set(0, dst); |
401 | arg_list.set(1, src); |
402 | arg_list.set(2, *wei_trans); |
403 | arg_list.set(3, bias); |
404 | append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_); |
405 | auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d); |
406 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
407 | } else { |
408 | std::unique_ptr<memory_storage_t> src_trans |
409 | = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_V); |
410 | compute::kernel_arg_list_t src_transform_args; |
411 | src_transform_args.set(0, *src_trans); |
412 | src_transform_args.set(1, src); |
413 | auto src_trans_nd_range |
414 | = compute::nd_range_t(conf.V_gws_d, conf.V_lws_d); |
415 | status = parallel_for( |
416 | ctx, src_trans_nd_range, src_trans_kernel_, src_transform_args); |
417 | |
418 | std::unique_ptr<memory_storage_t> M_buf |
419 | = ctx.get_scratchpad_grantor().get_memory_storage(key_wino_M); |
420 | compute::kernel_arg_list_t arg_list; |
421 | arg_list.set(0, *M_buf); |
422 | arg_list.set(1, *src_trans); |
423 | arg_list.set(2, *wei_trans); |
424 | auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d); |
425 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
426 | |
427 | compute::kernel_arg_list_t dst_transform_args; |
428 | dst_transform_args.set(0, dst); |
429 | dst_transform_args.set(1, *M_buf); |
430 | dst_transform_args.set(2, bias); |
431 | append_post_ops_to_arg_list( |
432 | ctx, dst_transform_args, 3, pd()->attr()->post_ops_); |
433 | auto dst_trans_nd_range |
434 | = compute::nd_range_t(conf.M_gws_d, conf.M_lws_d); |
435 | status = parallel_for( |
436 | ctx, dst_trans_nd_range, dst_trans_kernel_, dst_transform_args); |
437 | } |
438 | |
439 | if (attr_info.with_eltwise |
440 | && !gpu_eltwise_fwd_pd_t::eltwise_preserves_zero( |
441 | attr_info.eltwise_alg, attr_info.eltwise_alpha, |
442 | attr_info.eltwise_beta)) { |
443 | ctx.zero_pad_output(DNNL_ARG_DST); |
444 | } |
445 | return status; |
446 | } |
447 | } // namespace ocl |
448 | } // namespace gpu |
449 | } // namespace impl |
450 | } // namespace dnnl |
451 | |
452 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
453 | |