1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/conv/config.hpp"
18
19#include <cctype>
20#include <cstring>
21
22#include "common/type_helpers.hpp"
23#include "gpu/jit/conv/block_helper.hpp"
24#include "gpu/jit/conv/config_lookup_table.hpp"
25#include "gpu/jit/conv/grf_usage.hpp"
26#include "gpu/jit/conv/normalization.hpp"
27#include "gpu/jit/ir/block_2d_utils.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace gpu {
32namespace jit {
33
34// Helper functions.
35layout_t make_layout(const memory_desc_t &md) {
36 if (md.format_kind == format_kind::any) return layout_t();
37 return layout_t(md, /*do_normalize=*/false);
38}
39
40layout_t make_layout(const memory_desc_t &md, const std::string &tag) {
41 return layout_t(md, tag, /*do_normalize=*/false);
42}
43
44layout_t make_layout(const type_t &type, const std::vector<dim_t> &dims,
45 const std::string &tag) {
46 return layout_t(type, 0, tag, dims, /*do_normalize=*/false);
47}
48
49void set_default_format(memory_desc_t &md, const std::string &tag) {
50 if (md.format_kind != format_kind::any) return;
51 md = make_layout(md, tag).to_dnnl(md.dims);
52}
53
54bool matches_tag(const layout_t &layout, const std::string &tag) {
55 if (layout.is_empty()) return false;
56 auto tag_layout = make_layout(layout.type(), layout.dims(), tag);
57 if (layout != tag_layout) return false;
58 return true;
59}
60
61bool matches_tag_strict(const layout_t &layout, const std::string &tag) {
62 if (layout.is_empty()) return false;
63 auto tag_layout = make_layout(layout.type(), layout.dims(), tag);
64 if (!layout.is_strictly_equal(tag_layout)) return false;
65 return true;
66}
67
68bool matches_tag(const memory_desc_t &md, const std::string &tag) {
69 if (md.format_kind == format_kind::any) return false;
70 return matches_tag(make_layout(md), tag);
71}
72
73bool matches_tag_strict(const memory_desc_t &md, const std::string &tag) {
74 if (md.format_kind == format_kind::any) return false;
75 return matches_tag_strict(make_layout(md), tag);
76}
77
78layout_t init_layout(memory_desc_t &user_md, const std::string &optimal_tag) {
79 auto optimal = make_layout(user_md, optimal_tag);
80 if (user_md.format_kind != format_kind::any) {
81 auto user = make_layout(user_md);
        // If the layouts are physically different, return the layout passed
        // by the user; status::unimplemented will be returned later.
84 if (user != optimal) return user;
85 } else {
86 user_md = optimal.to_dnnl(user_md.dims);
87 }
88 return optimal;
89}
90
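// Prepends a group dimension to a format tag: every dimension letter is
// shifted up by one and a leading "a" is added, e.g. "xba" becomes "axcb".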
91std::string prepend_groups_to_tag(const std::string &tag) {
92 auto ret = tag;
93 for (auto &c : ret) {
94 bool is_lower_dim = ('a' <= c && c < 'a' + DNNL_MAX_NDIMS);
95 bool is_upper_dim = ('A' <= c && c < 'A' + DNNL_MAX_NDIMS);
96 if (!is_lower_dim && !is_upper_dim) continue;
97 c += 1;
98 }
99 return "a" + ret;
100}
101
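// Heuristic: the channel dimension is considered "small" when it occupies at
// most 16 bytes (or at most 8 elements for 4-byte and wider types).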
102bool is_small_ic(const conv_problem_t &prb) {
103 int size = (int)types::data_type_size(prb.src_data_type);
104 if (size >= 4)
105 return prb.ic <= 8;
106 else
107 return prb.ic * size <= 16;
108}
109
110bool is_small_oc(const conv_problem_t &prb) {
111 int size = (int)types::data_type_size(prb.dst_data_type);
112 if (size >= 4)
113 return prb.oc <= 8;
114 else
115 return prb.oc * size <= 16;
116}
117
118bool is_dw_large_mb(const conv_problem_t &prb) {
119 return prb.is_dw && prb.mb >= 16;
120}
121
122status_t conv_problem_t::init(
123 const engine_t *engine, const convolution_pd_t *conv_pd) {
124 using namespace compute;
125
126 if (conv_pd->has_zero_dim_memory()) return status::unimplemented;
127
128 this->conv_pd = conv_pd;
129 attr = conv_pd->attr();
130 is_fwd = conv_pd->is_fwd();
131 is_bwd_d = conv_pd->is_bwd_d();
132 is_bwd_w = conv_pd->is_bwd_w();
133 with_bias = conv_pd->with_bias();
134 with_groups = conv_pd->with_groups();
135 with_sum = with_sum_post_op();
136
137 src_data_type = conv_pd->invariant_src_md()->data_type;
138 wei_data_type = conv_pd->invariant_wei_md()->data_type;
139 bia_data_type = conv_pd->invariant_bia_md()->data_type;
140 dst_data_type = conv_pd->invariant_dst_md()->data_type;
141 fpmath_mode = attr->fpmath_mode_;
142
143 ndims = conv_pd->ndims();
144
145 mb = conv_pd->MB();
146 g = conv_pd->G();
147 ic = ir_utils::safe_divide(conv_pd->IC(), g);
148 oc = ir_utils::safe_divide(conv_pd->OC(), g);
149
150 // Input spatial.
151 id = conv_pd->ID();
152 ih = conv_pd->IH();
153 iw = conv_pd->IW();
154
155 // Output spatial.
156 od = conv_pd->OD();
157 oh = conv_pd->OH();
158 ow = conv_pd->OW();
159
160 // Kernel sizes.
161 kd = conv_pd->KD();
162 kh = conv_pd->KH();
163 kw = conv_pd->KW();
164
165 // Strides.
166 sd = conv_pd->KSD();
167 sh = conv_pd->KSH();
168 sw = conv_pd->KSW();
169
170 // Padding.
171 pd = conv_pd->padFront();
172 ph = conv_pd->padT();
173 pw = conv_pd->padL();
174
175 // Dilation.
176 dd = conv_pd->KDD();
177 dh = conv_pd->KDH();
178 dw = conv_pd->KDW();
179
180 try_reduce_to_1d();
181
182 is_dw = with_groups && (g > 1) && (oc == 1) && (ic == 1);
183 ksp = kd * kh * kw;
184 isp = id * ih * iw;
185 osp = od * oh * ow;
186
187 auto *compute_engine = utils::downcast<const compute_engine_t *>(engine);
188 auto *device_info = compute_engine->device_info();
189 gpu_arch_t gpu_arch = device_info->gpu_arch();
190 auto hw = convert_dnnl_arch_to_ngen(gpu_arch);
191
192 CHECK(init_abc_data_types(hw));
193 CHECK(init_acc_data_type());
194 CHECK(init_zero_points_config());
195
196 return status::success;
197}
198
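// Folds trivial spatial dimensions into W (D -> H -> W) and collapses 1x1
// stride-1 convolutions with matching input/output sizes into a single
// spatial dimension (tracked via reduced_dim).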
199void conv_problem_t::try_reduce_to_1d() {
200 bool is_1x1 = (kd * kh * kw == 1);
201 bool is_eq_oi = (od == id && oh == ih && ow == iw);
202 bool is_iw_1 = iw == 1 && kw == 1 && pw == 0 && ow == 1;
203 bool is_ih_1 = ih == 1 && kh == 1 && ph == 0 && oh == 1;
204 reduced_dim = 0;
205 auto shift_oh_to_ow = [&]() {
206 ow = oh;
207 iw = ih;
208 ih = 1;
209 oh = 1;
210 kw = kh;
211 kh = 1;
212 pw = ph;
213 ph = 0;
214 sw = sh;
215 sh = 1;
216 dw = dh;
217 dh = 0;
218 reduced_dim += 1;
219 };
220 auto shift_od_to_oh = [&]() {
221 oh = od;
222 ih = id;
223 id = 1;
224 od = 1;
225 kh = kd;
226 kd = 1;
227 ph = pd;
228 pd = 0;
229 sh = sd;
230 sd = 1;
231 dh = dd;
232 dd = 0;
233 reduced_dim += 1;
234 };
235
236 if (is_iw_1) { shift_oh_to_ow(); }
237 if (is_ih_1 || is_iw_1) { shift_od_to_oh(); }
238 if (is_iw_1 && is_ih_1) { shift_oh_to_ow(); }
239
240 if (is_1x1 && is_stride1() && is_eq_oi) {
241 ir_assert(pd == 0 && ph == 0 && pw == 0);
242 ow = od * oh * ow;
243 iw = id * ih * iw;
244 od = id = kd = 1;
245 oh = ih = kh = 1;
246 reduced_dim = 3;
247 }
248}
249
250status_t conv_problem_t::init_zero_points_config() {
251 zp_cfg = zero_points_config_t();
252 zp_cfg.do_src_compensation
253 = !attr->zero_points_.has_default_values(DNNL_ARG_SRC);
254 zp_cfg.do_dst_compensation
255 = !attr->zero_points_.has_default_values(DNNL_ARG_DST);
256 zp_cfg.is_runtime_src_zero_points
257 = !attr->zero_points_.defined(DNNL_ARG_SRC);
258 zp_cfg.is_runtime_dst_zero_points
259 = !attr->zero_points_.defined(DNNL_ARG_DST);
260 zp_cfg.is_common_src_zero_point = attr->zero_points_.common(DNNL_ARG_SRC);
261 zp_cfg.is_common_dst_zero_point = attr->zero_points_.common(DNNL_ARG_DST);
262 zp_cfg.common_src_zero_point = 0;
263 zp_cfg.common_dst_zero_point = 0;
264 return status::success;
265}
266
267std::string conv_problem_t::desc_str(bool print_mb) const {
268 std::ostringstream oss;
269 if (print_mb) oss << "mb" << mb;
270 if (g > 1) oss << "g" << g;
271 oss << "ic" << ic;
272
273 std::vector<int> xd = {id, od, kd, sd, dd, pd};
274 std::vector<int> xh = {ih, oh, kh, sh, dh, ph};
275 std::vector<int> xw = {iw, ow, kw, sw, dw, pw};
276 std::vector<int> xdef = {1, 1, 1, 1, 0, 0};
277 bool has_d = !ir_utils::is_equal(xd, xdef);
278 bool has_h = !ir_utils::is_equal(xh, xdef);
279 bool is_square = ir_utils::is_equal(xh, xw);
280 bool is_cubic = is_square && ir_utils::is_equal(xd, xh);
281 bool print_d = has_d;
282 bool print_h = has_h && !is_cubic;
283 bool print_w = !is_cubic && !is_square;
284
285 if (print_d) oss << "id" << id;
286 if (print_h) oss << "ih" << ih;
287 if (print_w) oss << "iw" << iw;
288 oss << "oc" << oc;
289 if (print_d) oss << "od" << od;
290 if (print_h) oss << "oh" << oh;
291 if (print_w) oss << "ow" << ow;
292 if (print_d) oss << "kd" << kd;
293 if (print_h) oss << "kh" << kh;
294 if (print_w) oss << "kw" << kw;
295 if (print_d && sd != 1) oss << "sd" << sd;
296 if (print_h && sh != 1) oss << "sh" << sh;
297 if (print_w && sw != 1) oss << "sw" << sw;
298 if (print_d && dd != 0) oss << "dd" << dd;
299 if (print_h && dh != 0) oss << "dh" << dh;
300 if (print_w && dw != 0) oss << "dw" << dw;
301 if (print_d) oss << "pd" << pd;
302 if (print_h) oss << "ph" << ph;
303 if (print_w) oss << "pw" << pw;
304 return oss.str();
305}
306
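// Returns the default maximum thread group size, limited by both the
// EU/thread capacity of a workgroup and the maximum workgroup size for the
// given SIMD width.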
307int get_default_max_tg_size(const hw_config_t &hw_cfg, int regs, int simd) {
308 const compute::gpu_arch_t arch = convert_ngen_arch_to_dnnl(hw_cfg.hw());
309 const int max_eus_per_wg = compute::device_info_t::max_eus_per_wg(arch);
310 const int threads_per_eu
311 = compute::device_info_t::threads_per_eu(arch, regs > 128);
312 const int wg_per_thr = simd * compute::device_info_t::threads_per_eu(arch)
313 / threads_per_eu;
314
315 // Optimal thread group size may differ from hardware thread count due
316 // to simd_size used in computation.
317 return std::min(max_eus_per_wg * utils::rnd_down_pow2(threads_per_eu),
318 static_cast<int>(hw_cfg.max_wg_size() / wg_per_thr));
319}
320
321std::vector<dim_t> get_prelu_weights_dims(
322 uint32_t mask, const memory_desc_t &md) {
323 std::vector<dim_t> dims(md.dims, md.dims + md.ndims);
324 for (int i = 0; i < md.ndims; ++i)
325 dims[i] = (mask & (1 << i)) ? dims[i] : 1;
326 return dims;
327}
328
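// Builds a memory format tag from per-dimension inner/outer block sizes.
// letters[i] is the tag letter of dimension i and idxs defines the relative
// order of the blocks; blocked dimensions get an upper-case plain letter.
// For example, 16x16 inner blocking over "ab" yields a tag like "ABx16a16b".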
329std::string build_tag(const std::vector<int> &inner_blocks,
330 const std::vector<int> &outer_blocks, const std::vector<char> &letters,
331 const std::vector<int> &idxs) {
332 size_t n = letters.size();
333 ir_assert(inner_blocks.size() == n);
334 ir_assert(outer_blocks.size() == n);
335 ir_assert(idxs.size() == n);
336
337 std::string tag;
338 std::vector<bool> seen(n);
339
340 // Iterate through outer blocks.
341 for (int i = (int)n - 1; i >= 0; i--) {
342 int idx = idxs[i];
343 int blk = outer_blocks[idx];
344 if (blk == 1) continue;
345 seen[idx] = true;
346 tag += std::to_string(blk) + letters[idx];
347 }
348
349 // Iterate through inner blocks.
350 for (int i = (int)n - 1; i >= 0; i--) {
351 int idx = idxs[i];
352 int blk = inner_blocks[idx];
353 if (blk == 1) continue;
354 seen[idx] = true;
355 tag += std::to_string(blk) + letters[idx];
356 }
357
358 if (tag.empty()) {
359 // Assume this is an activations tag, use NHWC by default.
360 tag = "axb";
361 } else {
362 tag = 'x' + tag;
363 for (int i = (int)n - 1; i >= 0; i--) {
364 char c = letters[i];
365 if (c == ' ') continue;
366 if (seen[i]) c = std::toupper(c);
367 tag = c + tag;
368 }
369 }
370
371 return tag;
372}
373
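// Picks the largest candidate block (b0, b1, b2 - expected in increasing
// order, zeros are skipped) that still fits dim; with prefer_rnd_up set, a
// block of up to twice dim is allowed.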
374int pick_block_impl(bool prefer_rnd_up, int dim, int b0, int b1, int b2) {
375 int blocks[3] = {b0, b1, b2};
376 int prev_blk = 1;
377 for (int i = 0; i < 3; i++) {
378 if (blocks[i] == 0) continue;
379 if (prefer_rnd_up) {
380 if (dim <= blocks[i] / 2) return prev_blk;
381 } else {
382 if (dim < blocks[i]) return prev_blk;
383 }
384 prev_blk = blocks[i];
385 }
386 return prev_blk;
387}
388
389int pick_block_rnd_up(int dim, int b0, int b1 = 0, int b2 = 0) {
390 return pick_block_impl(true, dim, b0, b1, b2);
391}
392
393int pick_block(int dim, int b0, int b1 = 0, int b2 = 0) {
394 return pick_block_impl(false, dim, b0, b1, b2);
395}
396
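// Describes (batch x channel) blocking of an activations layout and converts
// it to a format tag.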
397struct nc_block_t {
398 nc_block_t(int n_block, int c_block, bool nc_order = true)
399 : n_block_(n_block), c_block_(c_block), nc_order_(nc_order) {}
400
401 std::string tag() const {
402 std::vector<int> idxs = {1, 0};
403 if (!nc_order_) std::swap(idxs[0], idxs[1]);
404 return build_tag({n_block_, c_block_}, {1, 1}, {'a', 'b'}, idxs);
405 }
406
    // Ideally, this should depend only on the data type, direction, mb, c,
    // and g to enable the same src/dst formats and to avoid reorders between
    // convolutions.
409 static nc_block_t get_default_blocking(type_t type, bool is_dw, int n,
410 int c, int g, bool is_input, bool is_small_c,
411 int min_block_size = 0, bool nc_order = true,
412 bool force_default_c_blk = false) {
413 bool is_small_c_input
414 = (type.size() <= 2 && is_input && g == 1 && is_small_c);
415 auto default_c_blk = type.size() == 1 ? 32 : 16;
416 auto c_block = [&]() {
417 if (force_default_c_blk) return default_c_blk;
418 // Special case for small input channel shapes with dpas.
419 if (is_small_c_input) {
420 int packed_dword_elems = 4 / type.size();
421 return std::max(packed_dword_elems, utils::rnd_up_pow2(c));
422 }
423 auto blk_dim = is_dw ? g : g * c;
424 return pick_block_rnd_up(blk_dim, default_c_blk);
425 }();
426
        // Non-depthwise convolutions currently require the channel count to
        // be a multiple of c_block (or vice versa). If that implementation
        // restriction is lifted, this logic can be removed.
430 if (g > 1 && !is_dw && c % c_block != 0 && c_block % c != 0)
431 c_block = 1;
432
433 auto default_n_blk = type.size() < 4 ? 32 : 16;
434 auto n_block = [&]() {
435 if (c_block == 1)
436 return 1;
437 else if (is_small_c_input)
438 return pick_block(n, 8, 16);
439 else
440 return pick_block(n, 16, default_n_blk);
441 }();
442
        // Enforce a minimum block size, used to enable better message
        // behavior.
444 while (n_block * c_block * type.size() < min_block_size) {
445 // Prefer increasing blocks in dimensions with available data, and
446 // otherwise just increase c_block to meet requirements. Limit
447 // blocking dimensions to avoid untested edge cases.
448 if (c_block < c && c_block < default_c_blk)
449 c_block *= 2;
450 else if (n_block < n && n_block < default_n_blk)
451 n_block *= 2;
452 else {
453 c_block = utils::div_up(min_block_size, type.size() * n_block);
454 if (c_block > default_c_blk) c_block = default_c_blk;
455 break;
456 }
457 }
458
459 return nc_block_t(n_block, c_block, nc_order);
460 }
461
462private:
463 int n_block_;
464 int c_block_;
465 bool nc_order_;
466};
467
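// Describes (group x output channel x input channel) blocking of a weights
// layout and converts it to a format tag.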
468struct goi_block_t {
469 goi_block_t(fma_kind_t fma_kind, bool is_dw, bool is_bwd_d, int g_block,
470 int o_block, int i_block, int o_block_outer, int i_block_outer)
471 : fma_kind_(fma_kind)
472 , is_dw_(is_dw)
473 , is_bwd_d_(is_bwd_d)
474 , g_block_(g_block)
475 , o_block_(o_block)
476 , i_block_(i_block)
477 , o_block_outer_(o_block_outer)
478 , i_block_outer_(i_block_outer) {}
479
480 std::string tag() const {
481 std::vector<char> wei_letters(3, ' ');
482 char wei_letter = 'a';
483 for (int i = (is_dw_ ? 0 : 1); i < 3; i++) {
484 wei_letters[i] = wei_letter++;
485 }
486 std::vector<int> wei_idxs = {0, 1, 2}; // g, ic, oc
487 // dpas requires ic to go before oc in innermost blocks for weights.
488 if (fma_kind_ != fma_kind_t::mad) std::swap(wei_idxs[1], wei_idxs[2]);
489 if (is_bwd_d_) std::swap(wei_idxs[1], wei_idxs[2]);
490 return build_tag({g_block_, o_block_, i_block_},
491 {1, o_block_outer_, i_block_outer_}, wei_letters, wei_idxs);
492 }
493
494 static goi_block_t get_default_blocking(type_t type, int vec_size,
495 fma_kind_t fma_kind, bool is_bwd_d, bool is_small_ic, int g, int o,
496 int i) {
497 int x = o;
498 int y = i;
499 int g_block = 1;
500 int o_block = 1;
501 int i_block = 1;
502 int o_block_outer = 1;
503 int i_block_outer = 1;
504 int *x_block = &o_block;
505 int *y_block = &i_block;
506 int *x_block_outer = &o_block_outer;
507 int *y_block_outer = &i_block_outer;
508 // Backward by data requires flipped ic/oc in weights.
509 if (is_bwd_d) {
510 std::swap(x, y);
511 std::swap(x_block, y_block);
512 std::swap(x_block_outer, y_block_outer);
513 }
514 get_default_blocking(type, vec_size, fma_kind, is_bwd_d, is_small_ic, g,
515 x, y, g_block, *x_block, *y_block, *x_block_outer,
516 *y_block_outer);
517 return goi_block_t(fma_kind, is_dw(g, o, i), is_bwd_d, g_block, o_block,
518 i_block, o_block_outer, i_block_outer);
519 }
520
521 static void get_default_blocking(type_t type, int vec_size,
522 fma_kind_t fma_kind, bool is_bwd_d, bool is_small_ic, int g, int x,
523 int y, int &g_block, int &x_block, int &y_block, int &x_block_outer,
524 int &y_block_outer) {
525 if (is_dw(g, x, y)) {
526 g_block = type.is_x8() ? 32 : 16;
527 } else if (fma_kind == fma_kind_t::mad) {
528 x_block = vec_size;
529 y_block = pick_block(y, 8, 16);
530 } else {
531 int packed_dword_elems = 4 / type.size();
532 x_block = vec_size;
533 y_block = packed_dword_elems;
534 if (is_bwd_d || !is_small_ic) y_block_outer = 8;
535 }
536 }
537
538private:
539 static bool is_dw(int g, int o, int i) {
540 return (g > 1 && o == 1 && i == 1);
541 }
542
543 fma_kind_t fma_kind_;
544 bool is_dw_;
545 bool is_bwd_d_;
546 int g_block_;
547 int o_block_;
548 int i_block_;
549 int o_block_outer_;
550 int i_block_outer_;
551};
552
553// TODO: Remove this logic and switch to an IR generation-driven flow.
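// Returns whether 2D block send messages can be used to access the A (is_a)
// or B tensor with the given layout, based on the hardware, blocking, layout
// and size restrictions checked below.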
554bool can_use_2d_send(const conv_config_t &cfg, const layout_t &l, bool is_a) {
555 const auto &prb = cfg.prb();
556 bool is_b = !is_a;
557 if (!cfg.is_ge_xe_hpc()) return false;
558
559 bool with_blocking
560 = !cfg.iter_dims().is_empty() || !cfg.loop_dims().is_empty();
561
562 // Can't use 2D block messages for non-trivial strided dimensions.
563 if (is_a && (prb.is_fwd || prb.is_bwd_w) && prb.sw != 1
564 && (prb.kw != 1 || prb.pw != 0)) {
565 if (with_blocking) {
566 if (cfg.iter_dim({"osp", "ow", "iw"}) > 1) return false;
567 } else if (prb.mb < 16) {
568 return false;
569 }
570 }
571 if (is_a && prb.is_bwd_d && prb.sw != 1) {
572 if (with_blocking && cfg.iter_dim({"osp", "ow", "iw"}) > 1) {
573 return false;
574 } else if (prb.mb < 16) {
575 return false;
576 }
577 }
578
579 // Can't use 2D block messages for compound blocks.
580 if (is_a && with_blocking) {
581 bool has_mb_block = (cfg.iter_dim("mb") > 1);
582 bool has_sp_block = (cfg.iter_dim({"osp", "ow", "iw"}) > 1);
583 if (has_mb_block && has_sp_block) return false;
584 }
585
    // 2D block messages do not support the VNNI format with 4-byte elements.
587 if (type_t(prb.b_data_type).size() >= 4) return false;
588
589 auto is_plain_ok = [&]() {
590 if (is_a || prb.is_bwd_w) return matches_tag_strict(l, "axb");
591 if (is_b && l.is_empty()) return true;
592 if (is_b && prb.is_fwd) return matches_tag_strict(l, "xba");
593 if (is_b && prb.is_bwd_d) return matches_tag_strict(l, "xab");
594 return false;
595 };
596
597 if (!is_plain_ok()) return false;
598
599 // Check 2D block message limitations.
600 // Layouts for A:
601 // FWD: NHWC (src)
602 // BWD_D: NHWC (dst)
603 // BWD_W: NHWC (src)
604 // Layouts for B:
605 // FWD: HWIO (wei)
606 // BWD_D: HWOI (wei)
607 // BWD_W: NHWC (dst)
608 int a_width = (prb.is_fwd || prb.is_bwd_w) ? prb.ic : prb.oc;
609 int b_width = (prb.is_fwd || prb.is_bwd_w) ? prb.oc : prb.ic;
610 int a_max_height
611 = std::max((prb.is_fwd || prb.is_bwd_w) ? prb.iw : prb.ow, prb.mb);
612 int b_max_height = prb.is_fwd
613 ? prb.ic
614 : (prb.is_bwd_d ? prb.oc : std::max(prb.ow, prb.mb));
615 int a_max_pitch = (prb.is_fwd || prb.is_bwd_w) ? (prb.ic * prb.isp)
616 : (prb.oc * prb.osp);
617 int b_max_pitch
618 = (prb.is_fwd || prb.is_bwd_d) ? b_width : (prb.oc * prb.osp);
619 int data_type_size = (is_a ? prb.a_data_type_size : prb.b_data_type_size);
620 int width = (is_a ? a_width : b_width);
621 int max_height = (is_a ? a_max_height : b_max_height);
622 int max_pitch = (is_a ? a_max_pitch : b_max_pitch);
623 if (!block_2d_width_ok(width, data_type_size)) return false;
624 if (!block_2d_height_ok(max_height)) return false;
625 if (!block_2d_pitch_ok(cfg.hw_cfg(), width, data_type_size)) return false;
626 if (!block_2d_pitch_ok(cfg.hw_cfg(), max_pitch, data_type_size))
627 return false;
628 return true;
629}
630
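// Selects compute format tags for src/wei/dst and, when an internal weights
// reorder is allowed, a separate user-facing weights tag.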
631void init_data_tags(const conv_config_t &cfg, bool allow_src_reorder,
632 bool allow_wei_reorder, bool allow_dst_reorder,
633 const memory_desc_t &src_md, const memory_desc_t &wei_md,
634 const memory_desc_t &dst_md,
635
636 std::string &src_tag, std::string &wei_tag, std::string &dst_tag,
637 std::string &user_wei_tag) {
638 const auto &prb = cfg.prb();
639 auto src_compute_type = prb.is_bwd_d ? prb.c_data_type : prb.a_data_type;
640 auto dst_compute_type = prb.is_fwd
641 ? prb.c_data_type
642 : (prb.is_bwd_d ? prb.a_data_type : prb.b_data_type);
643 auto wei_compute_type = prb.is_bwd_w ? prb.c_data_type : prb.b_data_type;
644
645 int src_type_size = (int)types::data_type_size(src_compute_type);
646
    // Prefer larger messages for large-mb bwd_w.
    bool is_bwd_w_message_opt = prb.is_bwd_w && src_type_size <= 2
            && allow_src_reorder && prb.mb >= 16;
    int min_block_size = is_bwd_w_message_opt ? 128 : 0;
    bool nc_order = !is_bwd_w_message_opt;
652
653 nc_block_t src_blk = nc_block_t::get_default_blocking(src_compute_type,
654 prb.is_dw, prb.mb, prb.ic, prb.g, prb.is_fwd || prb.is_bwd_w,
655 is_small_ic(prb), min_block_size, nc_order);
656 // TODO: Force use of default_c_blk for bwd_w with bias due to reduction
657 // limitation to register granularity
658 nc_block_t dst_blk = nc_block_t::get_default_blocking(dst_compute_type,
659 prb.is_dw, prb.mb, prb.oc, prb.g, prb.is_bwd_d || prb.is_bwd_w,
660 is_small_oc(prb), 0, true, prb.is_bwd_w && prb.with_bias);
661
662 auto wei_blk = goi_block_t::get_default_blocking(wei_compute_type,
663 cfg.vec_size(), cfg.fma_kind(), prb.is_bwd_d, is_small_ic(prb),
664 prb.g, prb.oc, prb.ic);
665
666 src_tag = src_blk.tag();
667 wei_tag = wei_blk.tag();
668 dst_tag = dst_blk.tag();
669
670 // Use OhwIXoYi weights for small-channel forward convolution to ensure
671 // c-after-w order of reduction blocks to match the source layout.
672 if (is_small_ic(prb) && !prb.is_dw && prb.is_fwd && cfg.is_dp_fma()) {
673 const char *patterns[] = {"ABx", "AxB", "Abx", "Axb", nullptr};
674 bool found = false;
675 for (auto *p = patterns; *p; p += 2) {
676 auto pos = wei_tag.find(*p);
677 if (pos == std::string::npos) continue;
678 wei_tag = wei_tag.replace(pos, std::strlen(*p), *(p + 1));
679 found = true;
680 break;
681 }
682 ir_assert(found) << wei_tag;
683 }
684
685 // Align weights layout between forward/backward by data in some cases via
686 // internal reorder to eliminate user-side reorder.
687 auto fwd_wei_blk = goi_block_t::get_default_blocking(wei_compute_type,
688 cfg.vec_size(), cfg.fma_kind(), /*is_bwd_d=*/false,
689 is_small_ic(prb), prb.g, prb.oc, prb.ic);
690 auto fwd_wei_tag = fwd_wei_blk.tag();
691 if (fwd_wei_tag != wei_tag && allow_wei_reorder) {
692 user_wei_tag = fwd_wei_tag;
693 }
694
695 // Override compute layouts when using nhwc with block 2D messages.
696 bool a_2d_ok = can_use_2d_send(cfg, make_layout(prb.a_md()), true);
697 bool b_2d_ok = can_use_2d_send(cfg, make_layout(prb.b_md()), false);
698 if (a_2d_ok && b_2d_ok) {
699 if (prb.is_bwd_d && !is_small_ic(prb)) {
700 wei_tag = "xab";
701 } else {
702 wei_tag = "xba";
703 }
704 user_wei_tag = "xba";
705 }
706
707 // Override compute layouts for nhwc case.
708 bool src_matches = matches_tag(src_md, src_tag);
709 bool dst_matches = matches_tag(dst_md, dst_tag);
710 bool src_axb = matches_tag(src_md, "axb");
711 bool dst_axb = matches_tag(dst_md, "axb");
712 if (src_axb && dst_axb && (!src_matches || !dst_matches)) {
713 if (!allow_src_reorder) src_tag = "axb";
714 if (!allow_dst_reorder) dst_tag = "axb";
715 }
716
717 // Override compute layouts for plain outputs.
718 if (prb.is_fwd && dst_axb) dst_tag = "axb";
719 if (prb.is_bwd_d && src_axb) src_tag = "axb";
720}
721
722status_t init_tensor_layouts(conv_config_t &cfg, convolution_pd_t *pd) {
723 const auto &prb = cfg.prb();
724 // Compute layout tags and user layout tags. If a compute layout is
725 // different from a user layout then an extra pre/post reorder will be
726 // executed before/after convolution.
727 std::string src_tag, user_src_tag;
728 std::string wei_tag, user_wei_tag;
729 std::string dst_tag, user_dst_tag;
730
731 auto &src_md = *pd->invariant_src_md();
732 auto &wei_md = *pd->invariant_wei_md();
733 auto &dst_md = *pd->invariant_dst_md();
734 auto &bia_md = *pd->invariant_bia_md();
735
    // If src or dst is nhwc and the other one has format 'any', set it to
    // nhwc too (except for the first convolution).
738 bool is_small_ic_non_dw = is_small_ic(prb) && prb.g == 1;
739 bool is_small_oc_non_dw = is_small_oc(prb) && prb.g == 1;
740 bool propagate_nhwc = (matches_tag(src_md, "axb") && !is_small_ic_non_dw)
741 || matches_tag(dst_md, "axb");
742 if (propagate_nhwc) {
743 set_default_format(src_md, "axb");
744 set_default_format(dst_md, "axb");
745 }
746
747 bool allow_src_reorder = false;
748 // Allow internal weights reorder in some cases. The goal is to have
749 // aligned weights layouts between fwd/bwd_d/bwd_w to reduce potential
750 // weights reorders during training. In general it's more efficient than
751 // the external reorder.
752 bool allow_wei_reorder = cfg.is_ge_xe_hpc() && cfg.is_dp_fma();
753 bool allow_dst_reorder = false;
754 bool src_abx = matches_tag(src_md, "abx");
755 bool src_axb = matches_tag(src_md, "axb");
756 if ((src_abx || src_axb) && (prb.is_fwd || prb.is_bwd_w)
757 && is_small_ic_non_dw) {
758 allow_src_reorder = true;
759 }
760
761 init_data_tags(cfg, allow_src_reorder, allow_wei_reorder, allow_dst_reorder,
762 src_md, wei_md, dst_md, src_tag, wei_tag, dst_tag, user_wei_tag);
763
764 if (allow_src_reorder) {
765 if (src_abx) user_src_tag = "abx";
766 if (src_axb) user_src_tag = "axb";
767 }
768
769 // Prefer nhwc for small-channel inputs.
770 if (user_src_tag.empty() && prb.is_fwd && is_small_ic_non_dw) {
771 if (!matches_tag(src_md, src_tag)) user_src_tag = "axb";
772 }
773 if (user_dst_tag.empty() && prb.is_bwd_d && is_small_oc_non_dw) {
774 if (!matches_tag(dst_md, dst_tag)) user_dst_tag = "axb";
775 }
776
777 // Allow internal reorder from oihw/ohwi to more optimal weights layout.
778 if (allow_wei_reorder) {
779 if (matches_tag(wei_md, "abx")) user_wei_tag = "abx";
780 if (matches_tag(wei_md, "axb")) user_wei_tag = "axb";
781 }
782
783 if (user_src_tag.empty()) user_src_tag = src_tag;
784 if (user_wei_tag.empty()) user_wei_tag = wei_tag;
785 if (user_dst_tag.empty()) user_dst_tag = dst_tag;
786
787 bool wei_prepend_groups = (prb.with_groups && !prb.is_dw);
788 if (wei_prepend_groups) {
789 wei_tag = prepend_groups_to_tag(wei_tag);
790 user_wei_tag = prepend_groups_to_tag(user_wei_tag);
791 }
792
793 // Select user layouts.
794 auto user_src_layout = init_layout(src_md, user_src_tag);
795 auto user_wei_layout = init_layout(wei_md, user_wei_tag);
796 auto user_dst_layout = init_layout(dst_md, user_dst_tag);
797
798 layout_t user_bia_layout;
799 if (prb.with_bias) user_bia_layout = init_layout(bia_md, "a");
800
801 if (!user_src_layout.is_strictly_equal(make_layout(src_md, user_src_tag)))
802 return status::unimplemented;
803 if (!user_dst_layout.is_strictly_equal(make_layout(dst_md, user_dst_tag)))
804 return status::unimplemented;
805 if (!user_wei_layout.is_strictly_equal(make_layout(wei_md, user_wei_tag)))
806 return status::unimplemented;
807
808 auto src_layout = (src_tag != user_src_tag) ? make_layout(src_md, src_tag)
809 : user_src_layout;
810 auto wei_layout = (wei_tag != user_wei_tag) ? make_layout(wei_md, wei_tag)
811 : user_wei_layout;
812 auto dst_layout = (dst_tag != user_dst_tag) ? make_layout(dst_md, dst_tag)
813 : user_dst_layout;
814 auto bia_layout = user_bia_layout;
815
816 if (prb.is_bwd_w) {
817 if (prb.wei_data_type == data_type::bf16)
818 wei_layout = wei_layout.retype(type_t::f32());
819 if (prb.bia_data_type == data_type::bf16)
820 bia_layout = bia_layout.retype(type_t::f32());
821 }
822
823 auto &src = cfg.src_layout();
824 auto &wei = cfg.wei_layout();
825 auto &dst = cfg.dst_layout();
826 auto &bia = cfg.bia_layout();
827
828 src.set_compute_unnormalized(src_layout);
829 src.set_user_unnormalized(user_src_layout);
830 wei.set_compute_unnormalized(wei_layout);
831 wei.set_user_unnormalized(user_wei_layout);
832 dst.set_compute_unnormalized(dst_layout);
833 dst.set_user_unnormalized(user_dst_layout);
834 bia.set_compute_unnormalized(bia_layout);
835 bia.set_user_unnormalized(user_bia_layout);
836
837 // Normalize layouts: add group dimension for all layouts and reduce/fuse
838 // spatial dimensions when applicable.
839 normalize_conv_layouts(src_layout, wei_layout, dst_layout, bia_layout,
840 prb.with_groups, prb.g, prb.ic, prb.oc, prb.is_dw, prb.reduced_dim,
841 /*fuse_spatial=*/false,
842 /*add_groups=*/true);
843 normalize_conv_layouts(user_src_layout, user_wei_layout, user_dst_layout,
844 user_bia_layout, prb.with_groups, prb.g, prb.ic, prb.oc, prb.is_dw,
845 prb.reduced_dim,
846 /*fuse_spatial=*/false,
847 /*add_groups=*/true);
848
849 src.set_compute(src_layout);
850 src.set_user(user_src_layout);
851 wei.set_compute(wei_layout);
852 wei.set_user(user_wei_layout);
853 dst.set_compute(dst_layout);
854 dst.set_user(user_dst_layout);
855 bia.set_compute(bia_layout);
856 bia.set_user(user_bia_layout);
857
858 return status::success;
859}
860
861bool hw_ok(const hw_config_t &hw_cfg) {
862 // Disable pre-XeHP until performance parity is reached with OpenCL
863 // kernels.
864 if (hw_cfg.hw() < ngen::HW::XeHP) return false;
865 return true;
866}
867
868bool data_types_ok(const conv_problem_t &prb, const hw_config_t &hw_cfg) {
869 auto src = prb.src_data_type;
870 auto wei = prb.wei_data_type;
871 auto dst = prb.dst_data_type;
872 auto bia = prb.bia_data_type;
873 bool is_bf16 = utils::one_of(data_type::bf16, src, wei, dst, bia);
874 if (!prb.is_f64_conv() && utils::one_of(data_type::f64, src, wei, dst, bia))
875 return false;
876 if (is_bf16 && hw_cfg.hw() <= ngen::HW::XeLP) return false;
877 if (prb.is_f64_conv()
878 && utils::one_of(hw_cfg.hw(), ngen::HW::XeLP, ngen::HW::XeHPG))
879 return false;
880 if (prb.is_fwd) return true;
881 if (prb.is_bwd_d) return true;
882 if (prb.is_bwd_w) {
883 bool ok = true;
884 data_type_t default_acc_type
885 = src == data_type::f64 ? data_type::f64 : data_type::f32;
886 ok &= utils::one_of(
887 src, data_type::bf16, data_type::f32, data_type::f64);
888 ok &= (dst == src);
889 ok &= utils::one_of(wei, src, default_acc_type);
890
891 if (prb.with_bias) { ok &= utils::one_of(bia, src, data_type::f32); }
892 return ok;
893 }
894 return false;
895}
896
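// Checks that the requested zero points are supported: non-int8 input implies
// no zero points, weights zero points are not supported, and src/dst masks
// must be 0 (common) or 1 << 1 (per-channel).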
897bool zero_points_ok(const conv_problem_t &prb) {
898 auto *pd = prb.conv_pd;
899 auto *attr = pd->attr();
900
901 // TODO: implement the rest of the cases and remove this 'if'
902 bool ic_kdhw
903 = (prb.ic <= 8) && (prb.kd * prb.kh * prb.kw > 1) && !prb.is_dw;
904 if (!attr->zero_points_.has_default_values(DNNL_ARG_SRC) && ic_kdhw)
905 return false;
906
907 using namespace data_type;
908 const auto input_type = (prb.is_fwd) ? pd->invariant_src_md()->data_type
909 : pd->invariant_dst_md()->data_type;
910 int mask_src = 0, mask_dst = 0;
911 attr->zero_points_.get(DNNL_ARG_SRC, &mask_src);
912 attr->zero_points_.get(DNNL_ARG_DST, &mask_dst);
913
914 return IMPLICATION(!utils::one_of(input_type, s8, u8),
915 attr->zero_points_.has_default_values())
916 && attr->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
917 && (mask_src == 0 || mask_src == 1 << 1)
918 && (mask_dst == 0 || mask_dst == 1 << 1);
919}
920
921std::vector<int> get_scale_args(const conv_problem_t &prb) {
922 conv_arg_helper_t h(prb);
923 std::vector<int> ret = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
924 return ret;
925}
926
927bool post_ops_ok(const conv_problem_t &prb, const hw_config_t &hw_cfg) {
928 auto *attr = prb.attr;
929
930 // No post-ops are supported for f64
931 if (prb.is_f64_conv() && !attr->has_default_values()) return false;
932
933 if (prb.is_fwd || prb.is_bwd_d) {
934 using sm = primitive_attr_t::skip_mask_t;
935 auto attr_skip_mask = sm::post_ops | sm::sum_dt
936 | sm::zero_points_runtime | sm::scales_runtime;
937 if (!attr->has_default_values(attr_skip_mask)) return false;
938 } else {
939 if (!attr->has_default_values()) return false;
940 }
941
942 if (!attr->scales_.has_default_values())
943 if (!prb.is_s32_accumulator()) return false;
944 auto scale_args = get_scale_args(prb);
945 if (!attr->scales_.has_default_values(scale_args)) return false;
946 for (int arg : scale_args) {
947 int mask = attr->scales_.get(arg).mask_;
948 // XXX: per_oc for BWD_D is treated as per_ic assuming it's called from
949 // deconvolution.
950 int c_idx = prb.with_groups;
951 if (arg == DNNL_ARG_WEIGHTS) {
952 if (!utils::one_of(mask, 0, 1 << c_idx)) return false;
953 } else {
954 if (mask != 0) return false;
955 }
956 }
957
958 for (int i = 0; i < attr->post_ops_.len(); i++) {
959 auto &po = attr->post_ops_.entry_[i];
960 if (po.is_eltwise()) {
961 if (!jit_eltwise_injector_f32_is_supported(po.eltwise.alg))
962 return false;
963 else if (po.eltwise.alg == alg_kind::eltwise_tanh
964 && hw_cfg.hw() == ngen::HW::XeHPG
965 && hw_cfg.eu_count() <= 128)
                // Workaround for a hard-to-reproduce issue in end-to-end
                // workloads. It is unclear what the actual issue is, as the
                // kernel always works correctly in benchdnn.
969 return false;
970 }
971 }
972 return true;
973}
974
975const memory_desc_t *output_md(const convolution_pd_t *pd) {
976 if (pd->is_fwd()) return pd->dst_md();
977 if (pd->is_bwd_d()) return pd->diff_src_md();
978 if (pd->is_bwd_w()) return pd->diff_weights_md();
979 ir_error_not_expected();
980 return nullptr;
981}
982
983void maybe_override_from_lookup_table(conv_config_t &cfg) {
984 static conv_config_lookup_table_t table;
985 auto *s_params = table.find(cfg);
986 if (s_params) cfg.override_set(s_params);
987}
988
989void maybe_override_from_env(conv_config_t &cfg) {
990 auto cfg_env = ir_utils::getenv_str("cfg", "");
991 if (cfg_env.empty()) return;
992 cfg.override_set(cfg_env.c_str());
993}
994
995void maybe_override(conv_config_t &cfg) {
996 maybe_override_from_lookup_table(cfg);
997#ifdef GEN_CONV_DEBUG
998 maybe_override_from_env(cfg);
999#endif
1000}
1001
1002status_t init_fma_kind(conv_config_t &cfg) {
1003 const auto &prb = cfg.prb();
1004 auto fma_kind = fma_kind::get_supported_kind(
1005 cfg.hw(), prb.a_data_type, prb.b_data_type, prb.acc_data_type);
1006 // Force mad for some cases.
1007 if (prb.is_dw || (prb.g > 1 && prb.ic < 4 && prb.oc < 4 && prb.mb < 8))
1008 fma_kind = fma_kind_t::mad;
1009 if (fma_kind == fma_kind_t::unknown) return status::unimplemented;
1010 cfg.set_fma_kind(fma_kind);
1011 return status::success;
1012}
1013
1014status_t init_simd(conv_config_t &cfg) {
1015 if (cfg.exec_cfg_param().is_overridden("simd")) return status::success;
1016
1017 const auto &prb = cfg.prb();
1018 int simd = fma_kind::get_simd_size(cfg.hw(), cfg.fma_kind(),
1019 prb.a_data_type, prb.b_data_type, prb.acc_data_type);
1020 cfg.set_simd(simd);
1021 return status::success;
1022}
1023
1024status_t init_vec_size(conv_config_t &cfg) {
1025 const auto &prb = cfg.prb();
1026 int vec_size = cfg.simd();
1027 if (cfg.fma_kind() == fma_kind_t::mad) {
1028 int grf_elems = cfg.grf_size() / prb.acc_data_type_size;
1029 int vec_dim = (prb.is_fwd || prb.is_bwd_w) ? prb.oc : prb.ic;
1030 if (vec_size > grf_elems && vec_dim <= 8) vec_size = grf_elems;
1031 }
1032 cfg.set_vec_size(vec_size);
1033 return status::success;
1034}
1035
1036bool post_op_layouts_ok(const conv_problem_t &prb) {
1037 auto *pd = prb.conv_pd;
1038 auto *attr = pd->attr();
1039
1040 for (int i = 0; i < attr->post_ops_.len(); i++) {
1041 auto &po = attr->post_ops_.entry_[i];
1042 if (po.is_binary() || po.is_prelu()) {
1043 int mask = po.is_prelu()
1044 ? po.prelu.mask
1045 : utils::get_dims_mask(pd->invariant_dst_md()->dims,
1046 po.binary.src1_desc.dims, prb.ndims, true);
1047 // These cases don't have message-related limitations.
1048 if ((mask & (1 << 1)) == 0 || mask == (1 << 1)) continue;
1049 auto rhs_layout = po.is_prelu() ? layout_t(type_t::f32(), 0,
1050 get_prelu_weights_dims(po.prelu.mask,
1051 *pd->invariant_dst_md()))
1052 : layout_t(po.binary.src1_desc);
            // No blocks means it is a scalar, which can always be loaded.
1054 if (rhs_layout.blocks().empty()) return true;
1055
1056 auto rhs0 = rhs_layout.blocks()[0];
1057 // Innermost block must:
1058 // - be across output channels
1059 // - be dense
1060 if (rhs0.dim_idx != 1 || dim_t(rhs0.stride) != 1) return false;
1061 }
1062 }
1063 return true;
1064}
1065
1066status_t init_pd_time_cfg(const conv_problem_t &prb, conv_config_t &cfg,
1067 const engine_t *engine, convolution_pd_t *pd, primitive_attr_t *attr) {
1068 hw_config_t hw_cfg(engine);
1069
1070 if (!hw_ok(hw_cfg)) return status::unimplemented;
1071 if (!data_types_ok(prb, hw_cfg)) return status::unimplemented;
1072 if (!post_ops_ok(prb, hw_cfg)) return status::unimplemented;
1073 if (!zero_points_ok(prb)) return status::unimplemented;
1074
1075 cfg.set_prb(prb);
1076 cfg.set_exec_cfg(exec_config_t(hw_cfg));
1077
1078 maybe_override(cfg);
1079
1080 CHECK(init_fma_kind(cfg));
1081 CHECK(init_simd(cfg));
1082 CHECK(init_vec_size(cfg));
1083 CHECK(init_tensor_layouts(cfg, pd));
1084
1085 CHECK(attr->set_default_formats(&prb.c_md()));
1086
1087 if (!post_op_layouts_ok(prb)) return status::unimplemented;
1088
1089 return status::success;
1090}
1091
1092void init_hint(conv_config_t &cfg) {
1093 const auto &prb = cfg.prb();
1094 if (prb.is_fwd && is_small_ic(prb)) {
1095 int max_tg = 16;
1096 auto hint = cfg.hint();
1097 if (hint.max_tg_size() > max_tg) {
1098 hint.set_max_tg_size(max_tg);
1099 cfg.set_hint(hint);
1100 }
1101 }
1102}
1103
1104void init_pipeline(conv_config_t &cfg) {
1105 if (cfg.pipeline().is_overridden()) return;
1106
1107 const auto &prb = cfg.prb();
1108 bool do_unroll = true;
1109 if (prb.is_fwd) {
1110 const int max_unroll = 9;
1111 if (prb.ksp > max_unroll) do_unroll = false;
1112 if (is_small_ic(prb)) do_unroll = false;
1113 } else if (prb.is_bwd_d) {
1114 // Do not perform full unrolling when there are too many inner
1115 // iterations.
1116 int kernel_limit = prb.is_f32_conv() ? 4 : 9;
1117 if (prb.ksp > kernel_limit) do_unroll = false;
1118
1119 // Do not perform full unrolling with non-unit stride unless special
1120 // stride optimization is enabled. These cases have non-trivial
1121 // post-increment updates which result in unrolling all reduction loops
1122 // and exceeding the instruction cache.
1123 if (!prb.is_stride1() && !cfg.bwd_d_optimize_strided_iw())
1124 do_unroll = false;
1125 } else if (prb.is_bwd_w) {
1126 int mb_iter_blk = cfg.iter_dim("mb");
1127 do_unroll = (cfg.is_ge_xe_hpc() && cfg.is_dp_fma() && mb_iter_blk > 1);
1128 }
1129 // Unrolling with mad or dp4a results in too large kernels.
1130 if (utils::one_of(cfg.fma_kind(), fma_kind_t::mad, fma_kind_t::dp4a)
1131 && (cfg.hw() >= ngen::HW::XeHPG || prb.mb != 1))
1132 do_unroll = false;
1133 cfg.pipeline().set(do_unroll);
1134}
1135
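// Enables 2D block messages for nhwc layouts when both A and B support them
// and the estimated thread count is large enough to benefit from the
// corresponding blocking.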
1136void init_send_2d_nhwc(conv_config_t &cfg) {
1137 const auto &prb = cfg.prb();
1138
1139 bool a_ok = can_use_a_2d_send(cfg);
1140 bool b_ok = can_use_b_2d_send(cfg);
1141
1142 int64_t est_threads = 1;
1143 est_threads *= prb.g;
1144 est_threads *= prb.ic;
1145 est_threads *= prb.ksp;
1146 est_threads *= prb.mb;
1147 est_threads *= prb.oc;
1148 est_threads *= prb.osp;
1149
1150 // Estimated max reduction size per thread for BWD_W.
1151 const int bwd_w_max_k_per_thr = 1000;
1152 // Estimated M x N elements per thread.
1153 const int mn_per_thr = 16 * 16;
1154 // Crosspoint to enable 2D send and blocking.
1155 const int min_threads_to_enable_2d = 1024;
1156
1157 int k_fwd = prb.ic;
1158 int k_bwd_d = prb.oc;
1159 int k_bwd_w = std::min(bwd_w_max_k_per_thr, prb.mb * prb.osp);
1160 int k = prb.pick_by_dir(k_fwd, k_bwd_d, k_bwd_w);
1161 est_threads /= mn_per_thr;
1162 est_threads /= k;
1163
1164 if (est_threads < min_threads_to_enable_2d) {
1165 cfg.set_send_2d_nhwc(false);
1166 return;
1167 }
1168
1169 cfg.set_send_2d_nhwc(a_ok && b_ok);
1170}
1171
1172void init_fuse_spatial(conv_config_t &cfg) {
1173 if (cfg.fuse_spatial_param().is_overridden()) return;
1174
1175 const auto &prb = cfg.prb();
1176 if (!prb.is_fwd || is_small_ic(prb)) return;
1177
1178 // Spatial fusion may be suboptimal for small batch due to:
1179 // - Using smaller messages (load blocks are not fully dense anymore)
1180 // - Extra division arithmetic to work with fused indices
1181 if (cfg.src_layout().compute().inner_block(0) == 1) {
1182 if (!prb.is_fwd || cfg.is_ge_xe_hpc()) return;
        // Without an mb block, enable fusion only for overwhelmingly large
        // cubic spatial dims.
1184 if (prb.is_int8_dst() || (prb.osp < 4096)
1185 || !(prb.oh == prb.ow && prb.ow == prb.od)) {
1186 return;
1187 }
1188 }
1189
1190 cfg.set_fuse_spatial(true);
1191}
1192
1193void init_hoist_masks_from_compute_loop(conv_config_t &cfg) {
1194 if (cfg.send_2d_nhwc()) {
1195 cfg.set_hoist_masks_from_compute_loop(true);
1196 return;
1197 }
1198 if (!cfg.fuse_spatial()) return;
1199 if (cfg.hw() < ngen::HW::XeHPC) return;
1200
1201 // Both nhwc layouts and mask hoisting require extra GRF memory so avoid
1202 // enabling both.
1203 if (matches_tag(cfg.a_layout().compute_unnormalized(), "axb")) return;
1204
1205 cfg.set_hoist_masks_from_compute_loop(true);
1206}
1207
1208void init_ow_kw_grf_cache(conv_config_t &cfg) {
1209 const auto &prb = cfg.prb();
1210 if (!prb.is_fwd || !is_small_ic(prb) || prb.kw < 3 || is_dw_large_mb(prb))
1211 return;
1212 if (cfg.is_dp_fma()) return;
1213 if (cfg.fuse_spatial()) return;
1214
1215 const int iw_blk_limit = 40;
1216 const int max_ow_blk = 16;
1217 int max_iw_blk
1218 = (prb.sw * (max_ow_blk - 1) + (prb.kw - 1) * (1 + prb.dw) + 1);
1219 if (max_iw_blk > iw_blk_limit) return;
1220
1221 cfg.set_ow_kw_grf_cache(true);
1222}
1223
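// Populates the block helper with hardware/FMA settings, problem dimensions,
// their GEMM-view (M/N/K) roles and base iteration blocks derived from the
// compute layouts.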
1224void init_common_blocking(conv_config_t &cfg, block_helper_t &bh) {
1225 const auto &prb = cfg.prb();
1226
1227 auto &src_layout = cfg.src_layout().compute();
1228 auto &wei_layout = cfg.wei_layout().compute();
1229 auto &dst_layout = cfg.dst_layout().compute();
1230
1231 bh.set_hw_config(cfg.hw_cfg());
1232 bh.set_fma_kind(cfg.fma_kind());
1233 bh.set_simd_size(cfg.simd());
1234 bh.set_vec_size(cfg.vec_size());
1235 bh.set_max_tg_size(cfg.hint().max_tg_size());
1236 bh.set_max_tg_overridden(cfg.hint().max_tg_overridden());
1237 bh.set_abc_types(prb.a_data_type, prb.b_data_type, prb.acc_data_type);
1238
1239 bh.set_dim("mb", prb.mb);
1240 bh.set_dim("g", prb.g);
1241 bh.set_dim("oc", prb.oc);
    // Take blocked ic channels into account when selecting block sizes.
1243 bh.set_dim("ic",
1244 prb.is_bwd_w ? std::max(src_layout.dims()[2], wei_layout.dims()[2])
1245 : prb.ic);
1246 bh.set_dims({"kd", "kh", "kw"}, {prb.kd, prb.kh, prb.kw});
1247
1248 bh.set_b_dims({"g"});
1249
1250 if (prb.is_fwd) {
1251 if (cfg.fuse_spatial()) {
1252 bh.set_dims({"osp"}, {prb.osp});
1253 bh.set_m_dims({"mb", "osp"});
1254 } else {
1255 bh.set_dims({"od", "oh", "ow"}, {prb.od, prb.oh, prb.ow});
1256 bh.set_m_dims({"mb", "od", "oh", "ow"});
1257 }
1258 bh.set_n_dims({"oc"});
1259 bh.set_k_dims({"ic", "kd", "kh", "kw"});
1260 } else if (prb.is_bwd_d) {
1261 ir_assert(!cfg.fuse_spatial());
1262 bh.set_dims({"id", "ih", "iw"}, {prb.id, prb.ih, prb.iw});
1263 bh.set_m_dims({"mb", "id", "ih", "iw"});
1264 bh.set_n_dims({"ic"});
1265 bh.set_k_dims({"oc", "kd", "kh", "kw"});
1266 } else if (prb.is_bwd_w) {
1267 ir_assert(!cfg.fuse_spatial());
1268 bh.set_dims({"od", "oh", "ow"}, {prb.od, prb.oh, prb.ow});
1269 bh.set_m_dims({"ic", "kd", "kh", "kw"});
1270 bh.set_n_dims({"oc"});
1271 bh.set_k_dims({"mb", "od", "oh", "ow"});
1272 } else {
1273 ir_error_not_expected();
1274 }
1275
1276 for (auto &kv : bh.dims()) {
1277 bh.set_pad_block(kv.first, cfg.pad_block(kv.first));
1278 }
1279
1280 // Set base blocks to align kernel blocking with layout blocking.
1281 if (prb.is_fwd) {
1282 bh.set_base_iter_block("mb", src_layout.inner_block(0));
1283 int src_g_blk = prb.is_dw ? src_layout.inner_block(1) : 1;
1284 int wei_g_blk = prb.is_dw ? wei_layout.inner_block(0) : 1;
1285 bh.set_base_iter_block("g", src_g_blk, wei_g_blk);
1286 int src_ic_blk = src_layout.inner_block(2);
1287 int wei_ic_blk = wei_layout.inner_block(2);
1288 bh.set_base_iter_block("ic", src_ic_blk, wei_ic_blk);
1289 if (cfg.is_g_mad()) {
1290 bh.set_base_iter_block(
1291 "oc", dst_layout.inner_block(2), wei_layout.inner_block(1));
1292 bh.dim("oc").set_iter_dim(bh.dim("oc").base_iter_block());
1293 }
1294 } else if (prb.is_bwd_d) {
1295 bh.set_base_iter_block("mb", dst_layout.inner_block(0));
1296 int dst_g_blk = dst_layout.inner_block(1);
1297 int wei_g_blk = wei_layout.inner_block(0);
1298 bh.set_base_iter_block("g", dst_g_blk, wei_g_blk);
1299 int dst_oc_blk = dst_layout.inner_block(2);
1300 int wei_oc_blk = wei_layout.inner_block(1);
1301 bh.set_base_iter_block("oc", dst_oc_blk, wei_oc_blk);
1302 } else if (prb.is_bwd_w) {
1303 bh.set_base_iter_block("g", wei_layout.inner_block(0));
1304 int wei_oc_blk = wei_layout.inner_block(2);
1305 int dst_oc_blk = dst_layout.inner_block(2);
1306 bh.set_base_iter_block("oc", wei_oc_blk, dst_oc_blk);
1307 int src_ic_blk = src_layout.inner_block(2);
1308 int wei_ic_blk = wei_layout.inner_block(2);
1309 bh.set_base_iter_block("ic", src_ic_blk, wei_ic_blk);
1310 int src_mb_blk = src_layout.inner_block(0);
1311 int dst_mb_blk = dst_layout.inner_block(0);
1312 bh.set_base_iter_block("mb", src_mb_blk, dst_mb_blk);
1313 }
1314}
1315
1316bool should_use_spatial_blocking(const conv_config_t &cfg,
1317 dim_value_t mb_max_iter_dim, int d, int h, int w) {
1318 const auto &prb = cfg.prb();
1319 if (!cfg.is_ge_xe_hpc()) return true;
1320 if (mb_max_iter_dim == 1) return true;
1321 if (cfg.send_2d_nhwc() && prb.is_bwd_d && prb.sw != 1) return false;
1322 int sp = (prb.ksp == 1 && prb.is_fwd) ? (d * h * w) : w;
1323 int block = 16;
1324 double mb_ratio = (double)prb.mb / utils::rnd_up(prb.mb, block);
1325 double sp_ratio = (double)sp / utils::rnd_up(sp, block);
1326 return sp_ratio >= mb_ratio;
1327}
1328
1329void init_fwd(conv_config_t &cfg, block_helper_t &bh) {
1330 using namespace ir_utils;
1331
1332 const auto &prb = cfg.prb();
1333 const char *osp_name = cfg.fuse_spatial() ? "osp" : "ow";
1334
    // Set the iteration block for cases with no mb block and large spatial
    // dims.
1336 if (!cfg.is_ge_xe_hpc() && cfg.src_layout().compute().inner_block(0) == 1
1337 && prb.mb > 1 && (prb.oh == prb.ow && prb.ow == prb.od)
1338 && prb.osp >= 512) {
1339 bh.set_base_iter_block(osp_name, 16);
1340 }
1341
1342 if (cfg.ow_kw_grf_cache()) {
1343 bh.set_base_iter_block("mb", 1);
1344 bh.dim("mb").set_iter_dim(1);
1345 bh.set_max_iter_dim("mb", 1);
1346 bh.set_max_m_tg_dim(2);
1347 bh.set_max_n_tg_dim(2);
1348 }
1349 if (cfg.is_g_mad()) {
1350 bh.set_base_iter_block("mb", 1);
1351 bh.dim("mb").set_iter_dim(1);
1352 bh.set_max_iter_dim("mb", 1);
1353 bh.set_max_iter_dim(osp_name, 4);
1354 }
1355 bh.set_loop_dim("kd", prb.kd);
1356 bh.set_loop_dim("kh", prb.kh);
1357 if (is_small_ic(prb) && !is_dw_large_mb(prb)
1358 && (prb.g == 1 || prb.ic == prb.oc)) {
1359 bh.set_block_dims({"kw"});
1360 } else {
1361 bh.set_loop_dim("kw", prb.kw);
1362 // mad is not tested with thread group k-slicing.
1363 if (cfg.is_dp_fma()) {
1364 bh.allow_k_tg_slicing();
1365 bh.set_max_k_tg_dim(8);
1366 }
1367 }
1368
1369 bh.set_block_dims({"g", "oc", "ic", "mb", osp_name});
1370 bh.set_vector_dim(prb.is_dw || cfg.is_g_mad() ? "g" : "oc");
1371 bh.allow_fuse({"ic", "kw"});
1372 bh.allow_split({"oc", "ic", "kw"});
1373
1374 int mb_base_iter_blk = bh.dim("mb").base_iter_block();
1375 // mb blocking is always outer so we can safely use a smaller divisor to
1376 // have more flexible blocking for some cases.
1377 int mb_base_iter_divisor = is_dw_large_mb(prb) ? 32 : 8;
1378 mb_base_iter_blk = math::gcd(mb_base_iter_divisor, mb_base_iter_blk);
1379
1380 bh.set_base_iter_block("mb", mb_base_iter_blk);
1381
1382 bool use_sp_blocking = false;
1383 if (matches_tag(cfg.src_layout().compute_unnormalized(), "axb")) {
1384 use_sp_blocking = should_use_spatial_blocking(
1385 cfg, bh.max_iter_dim("mb"), prb.od, prb.oh, prb.ow);
1386 } else if (cfg.src_layout().compute().inner_block(0) == 1) {
1387 use_sp_blocking = true;
1388 } else if (prb.is_dw && !is_dw_large_mb(prb)) {
1389 use_sp_blocking = true;
1390 } else if (cfg.is_g_mad() || cfg.ow_kw_grf_cache()) {
1391 use_sp_blocking = true;
1392 }
1393
1394 if (use_sp_blocking) {
1395 if (prb.is_dw) bh.set_pref_tg_block(osp_name);
1396 bh.allow_split({osp_name, "mb"});
1397 bh.reorder({osp_name, "mb"});
1398 if (!prb.is_int8_dst() && !cfg.fuse_spatial() && prb.mb < 16
1399 && prb.iw % 8 != 0 && !prb.is_dw) {
1400 bh.set_max_m_tg_dim(1);
1401 }
1402 } else {
1403 const int large_sp_threshold = cfg.is_ge_xe_hpc() ? 128 : 256;
1404 if (!prb.is_dw && prb.ow > large_sp_threshold)
1405 bh.set_pref_tg_block("oc");
1406 else if (cfg.is_dp_fma() && prb.mb >= 16)
1407 bh.set_pref_tg_block(osp_name);
1408 bh.reorder({"mb", osp_name});
1409 auto spatial_dim = cfg.fuse_spatial() ? prb.osp : prb.ow;
1410 if (!cfg.send_2d_nhwc() && prb.mb >= 128
1411 && (spatial_dim % 4 != 0 || spatial_dim < 64))
1412 bh.allow_split({"mb"});
1413 }
1414
1415 if (prb.mb < 8 && !bh.any_pref_tg_block())
1416 bh.set_pref_tg_block(prb.ow > prb.oc ? osp_name : "oc");
1417
1418 bh.reorder({"ic", "kw"});
1419
1420 if (cfg.send_2d_nhwc()) {
1421 // Use 64-byte reduction step to avoid partial cache line loads.
1422 bh.set_base_iter_block("ic", 64 / prb.a_data_type_size);
1423 bh.set_reduce_m_block_hint(false);
1424 }
1425
1426 bh.compute();
1427}
1428
1429void init_bwd_d(conv_config_t &cfg, block_helper_t &bh) {
1430 using namespace ir_utils;
1431
1432 const auto &prb = cfg.prb();
1433 bh.set_loop_dim("kw", prb.kw);
1434 bh.set_loop_dim("kd", prb.kd);
1435 bh.set_loop_dim("kh", prb.kh);
1436 bh.set_block_dims({"g", "oc", "ic", "mb", "iw"});
1437 bh.set_vector_dim(prb.is_dw ? "g" : "ic");
1438 bh.allow_split({"oc", "ic"});
1439
1440 bool use_w_blocking = false;
1441 if (matches_tag(cfg.dst_layout().compute_unnormalized(), "axb")) {
1442 use_w_blocking = should_use_spatial_blocking(
1443 cfg, bh.max_iter_dim("mb"), prb.id, prb.ih, prb.iw);
1444 } else if (cfg.dst_layout().compute().inner_block(0) == 1) {
1445 use_w_blocking = true;
1446 }
1447
1448 if (use_w_blocking) {
1449 bh.allow_fuse({"iw", "mb"});
1450 bh.allow_split({"iw", "mb"});
1451 bh.reorder({"iw", "mb"});
1452 } else {
1453 bh.reorder({"mb", "iw"});
1454 bh.set_base_iter_block("mb", 8);
1455 }
1456
1457 if (cfg.send_2d_nhwc()) {
1458 bh.set_base_iter_block("oc", 64 / prb.a_data_type_size);
1459 if (!prb.is_stride1()) bh.allow_split({"mb"});
1460 bh.set_reduce_m_block_hint(false);
1461 }
1462
1463 bh.compute();
1464}
1465
1466void init_bwd_w(conv_config_t &cfg, block_helper_t &bh) {
1467 const auto &prb = cfg.prb();
1468 bh.allow_k_grid_slicing();
1469
1470 bh.set_block_dims({"g", "oc", "ic", "mb", "oh", "ow"});
1471 bh.set_vector_dim(prb.is_dw ? "g" : "oc");
1472
1473 if (prb.oc <= 32) bh.set_max_iter_dim("oc", 16);
1474 if (prb.ic <= 32) bh.set_max_iter_dim("ic", 16);
1475
1476 if (is_small_ic(prb) && !prb.is_dw) {
1477 bh.set_block_dims({"kw"});
1478 bh.set_max_tg_dim("kw", 1);
1479 bh.set_max_iter_dim("kw", 8);
1480 }
1481
    // Avoid 2D spatial blocking when 1D blocking is enough. Extra oh/od
    // loops may result in assembly bloat due to pipeline unrolling.
1485 if (prb.mb >= 32 && prb.ow >= 16) {
1486 bh.set_max_loop_dim("oh", 1);
1487 bh.set_max_loop_dim("od", 1);
1488 }
1489
1490 bh.set_max_iter_dim("oh", 1);
1491
1492 bh.allow_split({"oc", "ic", "mb", "ow"});
1493 bh.allow_fuse({"ic", "kw"});
1494 bh.allow_fuse({"mb", "oh", "ow"});
1495 bh.set_max_loop_dim("mb", 2);
1496 bh.set_base_iter_block("mb", math::gcd(16, bh.dim("mb").base_iter_block()));
1497
1498 bh.reorder({"mb", "ow", "oh"});
1499
1500 if (cfg.send_2d_nhwc()) bh.set_reduce_m_block_hint(false);
1501
1502 bh.compute();
1503}
1504
1505void init_blocking(conv_config_t &cfg) {
1506 const auto &prb = cfg.prb();
1507 block_helper_t bh;
1508 init_common_blocking(cfg, bh);
1509 if (prb.is_fwd) {
1510 init_fwd(cfg, bh);
1511 } else if (prb.is_bwd_d) {
1512 init_bwd_d(cfg, bh);
1513 } else if (prb.is_bwd_w) {
1514 init_bwd_w(cfg, bh);
1515 } else {
1516 ir_error_not_expected();
1517 }
1518
1519 auto &dims = cfg.dims();
1520 auto &iter_dims = cfg.iter_dims();
1521 auto &thread_group_dims = cfg.thread_group_dims();
1522 auto &loop_dims = cfg.loop_dims();
1523
1524 for (auto &kv : bh.dims()) {
1525 auto &name = kv.first;
1526 auto &d = kv.second;
1527 if (!dims.is_overridden()) dims.set(name, d.size());
1528 if (!iter_dims.is_overridden()) iter_dims.set(name, d.iter_dim());
1529 if (!thread_group_dims.is_overridden())
1530 thread_group_dims.set(name, d.tg_dim());
1531 if (!loop_dims.is_overridden()) loop_dims.set(name, d.loop_dim());
1532 if (cfg.shrink_tg_dims()) {
1533 int dim = cfg.dim(name);
1534 int iter = cfg.iter_dim(name);
1535 int tg = cfg.thread_group_dim(name);
1536 int loop = cfg.loop_dim(name);
1537 int pad_blk = cfg.pad_block(name);
1538 while (tg > 1) {
1539 int padded = utils::rnd_up(
1540 dim, math::lcm(iter * tg * loop, pad_blk));
1541 if (dim * 2 > padded) break;
1542 tg = std::max(1, tg / 2);
1543 }
1544 cfg.thread_group_dims().set(name, tg);
1545 }
1546 }
1547}
1548
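// Returns the nullptr-terminated list of convolution dimensions mapped to
// kernel grid dimension idx (0..2) for the given propagation kind.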
1549const char **get_kernel_grid_conv_dims(const conv_problem_t &prb, int idx) {
1550 static const char *fwd_0[] = {"oc", nullptr};
1551 static const char *fwd_1[] = {"g", "osp", "od", "oh", "ow", nullptr};
1552 static const char *fwd_2[] = {"mb", nullptr};
1553 static const char *bwd_d_0[] = {"ic", nullptr};
1554 static const char *bwd_d_1[] = {"g", "id", "ih", "iw", nullptr};
1555 static const char *bwd_d_2[] = {"mb", nullptr};
1556 static const char *bwd_w_0[] = {"oc", nullptr};
1557 static const char *bwd_w_1[]
1558 = {"ic", "kd", "kh", "kw", "od", "oh", "ow", nullptr};
1559 static const char *bwd_w_2[] = {"g", "mb", nullptr};
1560 static const char **fwd[] = {fwd_0, fwd_1, fwd_2};
1561 static const char **bwd_d[] = {bwd_d_0, bwd_d_1, bwd_d_2};
1562 static const char **bwd_w[] = {bwd_w_0, bwd_w_1, bwd_w_2};
1563 ir_assert(idx >= 0 && idx < 3);
1564 if (prb.is_fwd) return fwd[idx];
1565 if (prb.is_bwd_d) return bwd_d[idx];
1566 if (prb.is_bwd_w) return bwd_w[idx];
1567 ir_error_not_expected();
1568 return nullptr;
1569}
1570
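// Returns the nullptr-terminated list of convolution dimensions mapped to
// thread group grid dimension idx (0..2) for the given propagation kind.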
1571const char **get_thread_group_grid_conv_dims(
1572 const conv_problem_t &prb, int idx) {
1573 static const char *fwd_0[] = {"oc", nullptr};
1574 static const char *fwd_1[] = {"mb", "osp", "ow", nullptr};
1575 static const char *fwd_2[] = {"ic", nullptr};
1576 static const char *bwd_d_0[] = {"ic", nullptr};
1577 static const char *bwd_d_1[] = {"mb", "iw", nullptr};
1578 static const char *bwd_d_2[] = {"oc", nullptr};
1579 static const char *bwd_w_0[] = {"oc", nullptr};
1580 static const char *bwd_w_1[] = {"ic", nullptr};
1581 static const char *bwd_w_2[] = {nullptr};
1582 static const char **fwd[] = {fwd_0, fwd_1, fwd_2};
1583 static const char **bwd_d[] = {bwd_d_0, bwd_d_1, bwd_d_2};
1584 static const char **bwd_w[] = {bwd_w_0, bwd_w_1, bwd_w_2};
1585 ir_assert(idx >= 0 && idx < 3);
1586 if (prb.is_fwd) return fwd[idx];
1587 if (prb.is_bwd_d) return bwd_d[idx];
1588 if (prb.is_bwd_w) return bwd_w[idx];
1589 ir_error_not_expected();
1590 return nullptr;
1591}
1592
1593void init_padded_dims(conv_config_t &cfg) {
1594 for (auto &kv : cfg.dims().get()) {
1595 auto &name = kv.first;
1596 int dim = cfg.dim(name);
1597 int iter = cfg.iter_dim(name);
1598 int tg = cfg.thread_group_dim(name);
1599 int loop = cfg.loop_dim(name);
1600 int blk = iter * tg * loop;
1601 int pad_blk = cfg.pad_block(name);
1602 int padded = utils::rnd_up(dim, math::lcm(blk, pad_blk));
1603 cfg.padded_dims().set(name, padded);
1604 }
1605}
1606
void init_kernel_grid(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    auto get = [&](const char *name) {
        int padded = cfg.padded_dim(name);
        int iter = cfg.iter_dim(name);
        int loop = cfg.loop_dim(name);
        int tg = cfg.thread_group_dim(name);
        int tg_block = iter * loop * tg;
        return ir_utils::safe_divide(padded, tg_block);
    };

    const int grid_ndims = 3;
    std::vector<int> dims = {1, 1, 1};
    for (int i = 0; i < grid_ndims; i++) {
        auto **dd = get_kernel_grid_conv_dims(prb, i);
        for (auto **d = dd; *d; d++)
            dims[i] *= get(*d);
    }
    cfg.set_kernel_grid(grid_info_t(dims, "grid_idx"));
}

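// Computes the thread group grid from the per-dimension thread group
// blocking.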
void init_thread_group_grid(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    auto get = [&](const char *name) {
        return cfg.thread_group_dims().get(name);
    };

    const int grid_ndims = 3;
    std::vector<int> dims = {1, 1, 1};
    for (int i = 0; i < grid_ndims; i++) {
        auto **dd = get_thread_group_grid_conv_dims(prb, i);
        for (auto **d = dd; *d; d++)
            dims[i] *= get(*d);
    }
    cfg.set_thread_group_grid(grid_info_t(dims, "tg_idx"));
}

// Enable optimization for strided BWD_D convolution.
void init_bwd_d_optimize_strided(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (!prb.is_bwd_d) return;
    if (prb.is_stride1()) return;

    cfg.set_bwd_d_optimize_strided(true);

    if (cfg.iter_dim("iw") > 1) return;
    if (prb.iw % prb.sw != 0) return;
    cfg.set_bwd_d_optimize_strided_iw(true);

    // Update blocks.
    int iw_tg_dim0 = cfg.thread_group_dim("iw");
    ir_assert(math::is_pow2(iw_tg_dim0));
    ir_assert(prb.iw % prb.sw == 0);
    for (int tg_dim = iw_tg_dim0; tg_dim >= 1; tg_dim /= 2) {
        if ((prb.iw / prb.sw) % tg_dim != 0) continue;

        cfg.thread_group_dims().set("iw", tg_dim);
        int mb_iter_dim = cfg.iter_dim("mb");
        int new_mb_tg_dim = cfg.thread_group_dim("mb") * iw_tg_dim0 / tg_dim;
        // TODO: Non-uniform thread groups are not supported.
        while (new_mb_tg_dim > 1
                && utils::rnd_up(prb.mb, mb_iter_dim * new_mb_tg_dim) - prb.mb
                        >= mb_iter_dim) {
            new_mb_tg_dim /= 2;
        }
        if (mb_iter_dim * new_mb_tg_dim <= prb.mb) {
            cfg.thread_group_dims().set("mb", new_mb_tg_dim);
        }
        break;
    }
}

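// Sets loop unroll factors; currently only the BWD_W mb/ow loops are
// unrolled.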
void init_unroll(conv_config_t &cfg) {
    if (cfg.unroll().is_overridden()) return;

    const auto &prb = cfg.prb();

    if (prb.is_bwd_w) {
        int mb_loop_dim = cfg.loop_dim("mb");
        int ow_loop_dim = cfg.loop_dim("ow");
        cfg.unroll().set("mb", mb_loop_dim);
        if (cfg.iter_dim("ow") > 1 && ow_loop_dim <= 8 && cfg.is_dp_fma()) {
            cfg.unroll().set("ow", ow_loop_dim);
        }
    }
}

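// Returns true if a buffer of `elems` elements of `type_size` bytes can be
// evenly split across `tg_size` threads for SLM buffering.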
bool can_split_across_thread_group(int tg_size, int elems, int type_size) {
    // The thread group grid is limited to power-of-two sizes, so only a
    // power-of-two number of elements can be reliably split across it.
    if (!math::is_pow2(elems)) return false;

    // Check that the buffer can be uniformly distributed across threads.
    if (elems % tg_size != 0) return false;

    // Check that the per-thread chunk can be stored to SLM with oword
    // messages.
    int bytes_per_thr = (elems / tg_size) * type_size;
    if (bytes_per_thr % 16 != 0) return false;

    return true;
}

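// Enables A/B SLM buffering (pre-XeHPC only) when the corresponding tile can
// be split across the thread group, and selects the number of SLM and GMEM
// load (GRF) buffers.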
void init_slm(conv_config_t &cfg) {
    if (cfg.slm().is_overridden()) return;

    const auto &prb = cfg.prb();
    if (cfg.hw() >= ngen::HW::XeHPC) return;

    int bufs = 0;
    int gmem_bufs = 0;
    bool enable_a = false;
    bool enable_b = false;
    auto &tg = cfg.thread_group_grid();
    int tg_size = tg.elems();
    bmnk_dim_helper_t h(cfg);
    int m_tg_blk = h.thread_group_dim('m') * h.iter_dim('m');
    int n_tg_blk = h.thread_group_dim('n') * h.iter_dim('n');
    int k_iter_blk = h.iter_dim('k');
    if (!cfg.ow_kw_grf_cache()) {
        // Check that SLM can be stored with oword messages.
        int bytes_per_tg = (m_tg_blk * k_iter_blk * prb.a_data_type_size);
        int align = prb.is_bwd_w ? 32 : 16;
        bool can_split_a = bytes_per_tg % align == 0
                && bytes_per_tg / tg_size >= k_iter_blk && k_iter_blk % 2 == 0;
        enable_a = (tg.dim(0) > 1) && can_split_a;
    }
    bool can_split_b = can_split_across_thread_group(
            tg_size, n_tg_blk * k_iter_blk, prb.b_data_type_size);
    enable_b = (tg.dim(1) > 1) && can_split_b;

    if (enable_a || enable_b) {
        bool is_small_tg = (tg.dim(0) * tg.dim(1) <= 8);
        int pref_bufs
                = ((is_small_tg || prb.is_f32_conv()) && prb.mb > 1 ? 2 : 3);
        if (cfg.pipeline().do_unroll()) {
            bufs = pref_bufs;
            gmem_bufs = (cfg.is_dp_fma() ? 2 : 1);
        } else {
            // Double/triple SLM buffering is not supported when only one
            // matrix is SLM-buffered.
            bufs = (enable_a == enable_b ? pref_bufs : 1);
            gmem_bufs = 1;
        }
    }
    cfg.slm().set(bufs, gmem_bufs, enable_a, enable_b);
}

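// Enables global memory prefetch (XeHPC and newer) when both A and B tiles
// can be split across the thread group.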
void init_prefetch(conv_config_t &cfg) {
    if (cfg.prefetch().is_overridden()) return;

    const auto &prb = cfg.prb();
    if (cfg.hw() < ngen::HW::XeHPC) return;

    auto &tg = cfg.thread_group_grid();
    int tg_size = tg.elems();
    bmnk_dim_helper_t h(cfg);
    int m_tg_blk = h.thread_group_dim('m') * h.iter_dim('m');
    int n_tg_blk = h.thread_group_dim('n') * h.iter_dim('n');
    int k_iter_blk = h.iter_dim('k');
    bool can_split_a = (tg.dim(0) == 1)
            || can_split_across_thread_group(
                    tg_size, m_tg_blk * k_iter_blk, prb.a_data_type_size);
    bool can_split_b = (tg.dim(1) == 1)
            || can_split_across_thread_group(
                    tg_size, n_tg_blk * k_iter_blk, prb.b_data_type_size);

    bool use_prefetch = (can_split_a && can_split_b);
    if (!use_prefetch) return;

    cfg.prefetch().set(prb.is_f32_conv() ? 2 : 3);
}

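// Decides whether GRF reorders are allowed for A and B based on data types,
// FMA kind, message kind and layout blocking.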
void init_allow_grf_reorder(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    cfg.set_allow_a_grf_reorder(!prb.matches_user_types());
    cfg.set_allow_b_grf_reorder(!prb.matches_user_types());

    bool is_mad = !cfg.is_dp_fma();
    if (is_mad && prb.is_s32_accumulator()) {
        cfg.set_allow_a_grf_reorder(true);
        cfg.set_allow_b_grf_reorder(true);
        return;
    }

    if (is_mad && prb.b_data_type == data_type::bf16) {
        cfg.set_allow_b_grf_reorder(true);
        return;
    }

    bool use_a_2d_send = can_use_a_2d_send(cfg);
    bool use_b_2d_send = can_use_b_2d_send(cfg);
    bool is_a_grf_blocked
            = (cfg.a_layout().compute().innermost_block_layout().size()
                            % cfg.grf_size()
                    == 0);
    if ((prb.is_fwd || prb.is_bwd_d) && !use_a_2d_send && !is_a_grf_blocked) {
        const char *dim_name = (prb.is_fwd ? "ic" : "oc");
        int dim = (prb.is_fwd ? prb.ic : prb.oc);
        int blk = cfg.iter_dim(dim_name);
        if (blk * prb.a_data_type_size % cfg.grf_size() != 0
                || dim != cfg.padded_dim(dim_name)) {
            cfg.set_allow_a_grf_reorder(true);
        }
    }
    if (cfg.send_2d_nhwc()) cfg.set_allow_a_grf_reorder(true);

    bool a_is_small_c = (prb.is_fwd || prb.is_bwd_w) ? is_small_ic(prb)
                                                     : is_small_oc(prb);
    if (cfg.is_dp_fma() && !prb.is_dw && a_is_small_c) {
        cfg.set_allow_a_grf_reorder(true);
    }

    if (prb.is_bwd_w && cfg.is_dp_fma()) {
        cfg.set_allow_a_grf_reorder(true);
        if (!use_b_2d_send) cfg.set_allow_b_grf_reorder(true);
    }
}

void init_allow_slm_tg_slicing(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (!prb.is_bwd_w) return;
    if (!utils::everyone_is(prb.a_data_type, prb.b_data_type, data_type::bf16))
        return;
    if (!cfg.is_dp_fma()) return;

    // Enable only for layouts with batch blocking.
    int src_mb_blk = cfg.src_layout().compute().inner_block(0);
    int src_ic_blk = cfg.src_layout().compute().inner_block(2);
    int dst_mb_blk = cfg.dst_layout().compute().inner_block(0);
    int dst_oc_blk = cfg.dst_layout().compute().inner_block(2);
    if (src_mb_blk < 16 || dst_mb_blk < 16) return;

    bmnk_dim_helper_t h(cfg);
    int k_iter_blk = h.iter_dim('k');
    int m_iter_blk = h.iter_dim('m');
    int n_iter_blk = h.iter_dim('n');
    int m_tg_dim = h.thread_group_dim('m');
    int n_tg_dim = h.thread_group_dim('n');
    int tg_size = m_tg_dim * n_tg_dim;

    // Backward by weights with dpas layouts requires GRF reorders for A/B
    // (e.g. 2c*16n16c -> 32c16n). When SLM is used, such reorders are
    // generated after the load from GMEM and before the store to SLM. For
    // optimal performance the load/store layouts need large dense blocks, so
    // in some cases only a sub-grid of the thread group is used (i.e. we rely
    // on TG slicing) to perform the load-store operation; otherwise we may
    // end up with reorders like 8n16c -> 16c*8n which result in scattered
    // loads/stores.
    // At the same time, using sub-grids results in higher GRF consumption, so
    // TG slicing is enabled only when the resulting sub-grid consists of at
    // least half of the total threads.
    int src_reorder_elems = k_iter_blk * src_ic_blk;
    int src_tg_elems = m_iter_blk * m_tg_dim * k_iter_blk;
    if (src_tg_elems % tg_size != 0) return;
    int src_elems_per_thr = src_tg_elems / tg_size;
    int src_slices = utils::div_up(src_reorder_elems, src_elems_per_thr);
    if (src_slices > 2) return;

    int dst_reorder_elems = k_iter_blk * dst_oc_blk;
    int dst_tg_elems = n_iter_blk * n_tg_dim * k_iter_blk;
    if (dst_tg_elems % tg_size != 0) return;
    int dst_elems_per_thr = dst_tg_elems / tg_size;
    int dst_slices = utils::div_up(dst_reorder_elems, dst_elems_per_thr);
    if (dst_slices > 2) return;

    cfg.set_allow_slm_tg_slicing(true);
}

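// BWD_W with bias: the bias is computed via an additional reduction of the
// B tensor.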
void init_reduce_b(conv_config_t &cfg) {
    const auto &prb = cfg.prb();

    if (prb.is_bwd_w && prb.with_bias) { cfg.set_reduce_b(true); }
}

void init_assign_sbids(conv_config_t &cfg) {
    if (cfg.is_dp_fma()) cfg.set_assign_sbids(true);
}

// Overwrites parameters that are implied by other parameters.
status_t fixup_config(conv_config_t &cfg) {
    const auto &prb = cfg.prb();

    // Downgrade dpasw -> dpas for some cases.
    if (cfg.fma_kind() == fma_kind_t::dpasw) {
        // dpasw is executed by fused EUs (across the X thread group
        // dimension), so do not use it when X is odd.
        if (cfg.thread_group_grid().dim(0) % 2 != 0)
            cfg.set_fma_kind(fma_kind_t::dpas);
        // dpasw can't be generated when A/B are loaded directly from GMEM
        // (no SLM buffering) with a GRF reorder.
        if (prb.is_bwd_w
                && (cfg.allow_a_grf_reorder() || cfg.allow_b_grf_reorder())
                && (!cfg.slm().a() || !cfg.slm().b()))
            cfg.set_fma_kind(fma_kind_t::dpas);
    }

    return status::success;
}

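// Returns true if `dim` is present in any of the three grid dimension lists
// returned by `get_func` for this problem.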
template <typename GetFuncT>
bool in_grid_dims(
        GetFuncT get_func, const conv_problem_t &prb, const std::string &dim) {
    for (int i = 0; i < 3; i++) {
        auto **dd = get_func(prb, i);
        for (auto **d = dd; *d; d++)
            if (*d == dim) return true;
    }
    return false;
}

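// Verifies that the blocking is consistent: 2D block messages are not
// combined with non-trivial strided iteration dimensions, and every blocked
// dimension is mapped to its grid.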
status_t check_config(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (prb.is_fwd) {
        if (cfg.send_2d_nhwc() && prb.sw != 1 && (prb.kw != 1 || prb.pw != 0)) {
            int osp_iter_blk = cfg.iter_dim("osp") * cfg.iter_dim("ow");
            ir_assert(osp_iter_blk == 1)
                    << "Can't use 2D block messages for non-trivial "
                       "strided dimensions.";
        }
    } else if (prb.is_bwd_d) {
        if (cfg.send_2d_nhwc() && prb.mb < 16 && prb.sw != 1) {
            ir_assert(cfg.iter_dim("iw") == 1)
                    << "Can't use 2D block messages for non-trivial "
                       "strided dimensions.";
        }
    } else if (prb.is_bwd_w) {
        if (cfg.send_2d_nhwc() && prb.sw != 1 && (prb.kw != 1 || prb.pw != 0)) {
            ir_assert(cfg.iter_dim("ow") == 1)
                    << "Can't use 2D block messages for non-trivial "
                       "strided dimensions.";
        }
    }

    for (auto &kv : cfg.dims().get()) {
        auto &name = kv.first;
        int tg = cfg.thread_group_dim(name);
        int grid = cfg.grid_dim(name);
        if (tg != 1)
            ir_assert(in_grid_dims(get_thread_group_grid_conv_dims, prb, name))
                    << name;
        if (grid != 1)
            ir_assert(in_grid_dims(get_kernel_grid_conv_dims, prb, name))
                    << name;
    }
    return status::success;
}

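// Gradually relaxes the configuration (GMEM GRF buffering, B/A subtiles, SLM
// buffer count, pipelining) until the estimated register usage fits the
// available GRF budget. Returns false if it still does not fit.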
bool try_reduce_grf_usage(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (!cfg.reduce_grf_usage()) return true;

    // TODO: Improve the register count estimate; it fails to account for
    // temporary values such as mask registers, among other things.
    int max_regs = cfg.regs();
    int est_regs = estimate_register_count(cfg);
    if (est_regs <= max_regs) return true;

    // Try to disable double GRF buffering of GMEM loads.
    if (cfg.slm().gmem_bufs() > 1) {
        cfg.slm().set_gmem_bufs(1);
        int est_regs = estimate_register_count(cfg);
        if (est_regs <= max_regs) return true;
    }

    bmnk_dim_helper_t h(cfg);

    // Try to use subtiles for B.
    if (!cfg.subtiles().is_overridden()) {
        int n_iter_blk = h.iter_dim('n');
        int max_b_subtiles
                = std::min((cfg.slm().b() ? 4 : 2), n_iter_blk / cfg.simd());
        // XXX: avoid layout mismatch for B loads.
        if (cfg.hw() >= ngen::HW::XeHPC && prb.is_bwd_w)
            max_b_subtiles = std::min(2, max_b_subtiles);
        while (cfg.subtiles().b() < max_b_subtiles) {
            cfg.subtiles().set_b(cfg.subtiles().b() * 2);
            int est_regs = estimate_register_count(cfg);
            if (est_regs <= max_regs) return true;
        }

        // Try to use subtiles for A.
        int m_iter_blk = h.iter_dim('m');
        int max_a_subtiles = std::min((cfg.slm().a() ? 4 : 2), m_iter_blk / 8);
        if (cfg.subtiles().b() > 1) max_a_subtiles = 1;
        while (cfg.subtiles().a() < max_a_subtiles) {
            cfg.subtiles().set_a(cfg.subtiles().a() * 2);
            int est_regs = estimate_register_count(cfg);
            if (est_regs <= max_regs) return true;
        }
    }

    if (!cfg.slm().is_overridden()) {
        // Try to switch to double SLM buffering.
        if (cfg.slm().bufs() == 3) {
            cfg.slm().set_bufs(2);
            int est_regs = estimate_register_count(cfg);
            if (est_regs <= max_regs) return true;
        }

        // Try to switch to single SLM buffering.
        if (cfg.slm().bufs() == 2) {
            cfg.slm().set_bufs(1);
            int est_regs = estimate_register_count(cfg);
            if (est_regs <= max_regs) return true;
        }
    }

    if (!cfg.pipeline().is_overridden()) {
        // Last-resort setting to reduce GRF usage.
        cfg.pipeline().set(false);
    }

    return estimate_register_count(cfg) <= max_regs;
}

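// Runs all initialization passes in dependency order and validates the
// resulting configuration.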
status_t try_init_cfg(conv_config_t &cfg) {
    init_hint(cfg);
    init_send_2d_nhwc(cfg);
    init_fuse_spatial(cfg);
    init_hoist_masks_from_compute_loop(cfg);
    init_ow_kw_grf_cache(cfg);
    init_blocking(cfg);
    init_bwd_d_optimize_strided(cfg);
    init_pipeline(cfg);
    init_padded_dims(cfg);
    init_kernel_grid(cfg);
    init_thread_group_grid(cfg);
    init_unroll(cfg);
    init_slm(cfg);
    init_prefetch(cfg);
    init_allow_grf_reorder(cfg);
    init_allow_slm_tg_slicing(cfg);
    init_reduce_b(cfg);
    init_assign_sbids(cfg);

    CHECK(fixup_config(cfg));
    CHECK(check_config(cfg));

    if (!try_reduce_grf_usage(cfg)) return status::unimplemented;

    return status::success;
}

// Returns max SLM size per thread group assuming max utilization (max
// concurrent threads per EU).
int max_slm_size(const conv_config_t &cfg) {
    ngen::HW hw = cfg.hw();
    int regs = cfg.regs();
    return compute::device_info_t::max_slm_size_per_tg(
            convert_ngen_arch_to_dnnl(hw), regs > 128);
}

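// Returns the SLM size in bytes consumed by the current A/B buffering
// configuration.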
int slm_size(const conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    auto &slm = cfg.slm();
    if (slm.bufs() == 0) return 0;

    bmnk_dim_helper_t h(cfg);
    int m_tg_blk = h.thread_group_dim('m') * h.iter_dim('m');
    int n_tg_blk = h.thread_group_dim('n') * h.iter_dim('n');
    int k_iter_blk = h.iter_dim('k');
    int a_slm_size = m_tg_blk * k_iter_blk * prb.a_data_type_size;
    int b_slm_size = n_tg_blk * k_iter_blk * prb.b_data_type_size;

    int ret = 0;
    if (slm.a()) ret += a_slm_size;
    if (slm.b()) ret += b_slm_size;
    ret *= slm.bufs();

    return ret;
}

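// Entry point for configuration initialization. Starts with the large GRF
// mode (when supported) and retries with a smaller thread group or the
// normal GRF mode until the configuration fits the hardware limits.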
status_t init_cfg(conv_config_t &cfg, const convolution_pd_t *pd) {
    cfg.set_pd(pd);

    // Try the large GRF mode first.
    int try_regs = cfg.hw_cfg().large_grf_support() ? 256 : 128;
    //if (prb.g == 1 && prb.is_f32_conv()) try_regs = 128;

    int def_max_tg_size
            = get_default_max_tg_size(cfg.hw_cfg(), try_regs, cfg.simd());
    conv_hint_t hint(def_max_tg_size);

    // Use a fixed number of iterations to avoid an infinite loop.
    int max_iters = 10;
    bool ok = false;
    for (int i = 0; i < max_iters; i++) {
        conv_config_t try_cfg = cfg;
        try_cfg.set_regs(try_regs);
        try_cfg.set_hint(hint);

        CHECK(try_init_cfg(try_cfg));

        // Reduce the thread group size if the SLM size is too large.
        if (try_cfg.check_slm_size()) {
            if (slm_size(try_cfg) > max_slm_size(try_cfg)) {
                hint.set_max_tg_size(hint.max_tg_size() / 2);
                continue;
            }
        }

        // If the kernel fits into 128 registers, switch to the normal GRF
        // mode, which is expected to perform better for such cases.
        int bound = (!try_cfg.is_dp_fma() ? 128 : 116);
        int estimated_peak_grf_usage = estimate_register_count(try_cfg);
        if (try_regs == 256 && estimated_peak_grf_usage <= bound) {
            try_regs = 128;
            continue;
        }
        cfg = try_cfg;
        ok = true;
        break;
    }

    return ok ? status::success : status::runtime_error;
}

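// Applies space-separated "key=value" overrides to the matching
// configuration parameters (e.g. for debugging).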
void conv_config_t::override_set(const std::string &s) {
    std::vector<param_t *> params;
    for (auto &gp : get_params_)
        params.push_back(gp(this));
    auto parts = ir_utils::split(s);
    for (auto &p : parts) {
        auto sub_parts = ir_utils::split(p, "=");
        ir_assert(sub_parts.size() == 2);
        auto &key = sub_parts[0];
        auto &value = sub_parts[1];
        bool found = false;
        for (auto *param : params) {
            if (param->accept_key(key)) {
                ir_info() << "Override " << param->name() << ": " << key << "="
                          << value << std::endl;
                param->override_set(key, value);
                found = true;
                break;
            }
        }
        ir_assert(found) << p;
    }
}

int get_thread_count(const conv_config_t &cfg) {
    return cfg.kernel_grid().elems() * cfg.thread_group_grid().elems();
}

// Returns the thread utilization as a percentage. If this value is low, the
// amount of parallelism is a fundamental limitation for the current work
// scheduling.
float get_thread_utilization(const conv_config_t &cfg) {
    auto arch = convert_ngen_arch_to_dnnl(cfg.hw());
    int slice_eu_count = compute::device_info_t::max_eus_per_wg(arch);
    int slice_count = cfg.hw_cfg().eu_count() / slice_eu_count;

    int min_wg_per_slice_wave
            = std::max(slice_eu_count / cfg.thread_group_grid().elems(), 1);
    int min_wg_per_wave = slice_count * min_wg_per_slice_wave;
    int wg = cfg.kernel_grid().elems();
    return ((float)wg / utils::rnd_up(wg, min_wg_per_wave)) * 100;
}

// Returns the wave utilization as a percentage. If this value is low, memory
// latency may be an issue due to limited use of SMT to hide the latency.
float get_wave_utilization(const conv_config_t &cfg) {
    auto arch = convert_ngen_arch_to_dnnl(cfg.hw());
    int threads_per_eu = compute::device_info_t::threads_per_eu(
            arch, cfg.hw_cfg().large_grf_support());
    int slice_eu_count = compute::device_info_t::max_eus_per_wg(arch);
    int slice_count = cfg.hw_cfg().eu_count() / slice_eu_count;

    int max_wg_per_slice_wave
            = slice_eu_count * threads_per_eu / cfg.thread_group_grid().elems();
    int max_wg_per_wave = slice_count * max_wg_per_slice_wave;
    int wg = cfg.kernel_grid().elems();
    return ((float)wg / utils::rnd_up(wg, max_wg_per_wave)) * 100;
}

std::string conv_config_t::str() const {
    using namespace ir_utils;

    std::ostringstream oss;
    // clang-format off
    oss << " HW config: " << exec_cfg().str(hint().max_tg_size()) << std::endl;
    oss << " Problem: " << prb().desc_str() << std::endl;
    const char *names[] = {"Source", "Weights", "Destination"};
    const layout_param_t *layouts[] = {&src_layout(), &wei_layout(), &dst_layout()};
    for (int i = 0; i < 3; i++) {
        std::string desc = std::string(names[i]) + " layout:";
        desc.insert(desc.size(), 28 - desc.size(), ' ');
        auto &compute_layout = layouts[i]->compute_unnormalized();
        auto &user_layout = layouts[i]->user_unnormalized();
        oss << " " << desc << compute_layout;
        if (user_layout != compute_layout) {
            oss << " (user: " << user_layout << ")";
        }
        oss << std::endl;
    }
    int estimated_peak_grf_usage = estimate_register_count(*this);
    oss << blocking_brief_str();
    oss << " Kernel grid: " << kernel_grid() << std::endl;
    oss << " Thread group: " << thread_group_grid() << std::endl;
    oss << " Threads: " << get_thread_count(*this) << " (utilization: "
        << get_thread_utilization(*this) << "% thread, "
        << get_wave_utilization(*this) << "% wave)" << std::endl;
    oss << " FMA kind: " << fma_kind::to_string(fma_kind()) << std::endl;
    oss << " SLM buffering: " << "A: " << to_string(slm().a()) << ", B: " << to_string(slm().b())
        << ", buffers: " << slm().bufs() << ", pad: " << to_string(pad_slm()) << std::endl;
    oss << " GRF buffers for GMEM load: " << slm().gmem_bufs() << std::endl;
    oss << " Prefetch: " << to_string(prefetch()) << ", buffers: " << prefetch().bufs() << std::endl;
    oss << " Do pipeline unroll: " << to_string(pipeline().do_unroll()) << std::endl;
    oss << " Assign SBIDs: " << to_string(assign_sbids()) << std::endl;
    oss << " Reduce GRF usage: " << to_string(reduce_grf_usage()) << std::endl;
    oss << " Reuse headers: " << to_string(pipeline().reuse_headers()) << std::endl;
    oss << " Allow GRF reorder: " << "A: " << to_string(allow_a_grf_reorder()) << ", B: " << to_string(allow_b_grf_reorder()) << std::endl;
    oss << " Subtiles: " << "A: " << subtiles().a() << ", B: " << subtiles().b() << std::endl;
    oss << " Estimated GRF usage: " << estimated_peak_grf_usage << std::endl;
    // clang-format on
    return oss.str();
}

std::string pad_str(std::string s, int pad) {
    auto pos = (pad >= 0 ? 0 : s.length());
    s.insert(pos, std::abs(pad) - s.length(), ' ');
    return s;
}

std::string pad_int(int i, int pad) {
    return pad_str(std::to_string(i), pad);
}

std::string conv_config_t::blocking_brief_str() const {
    std::ostringstream oss;
    std::vector<std::string> names;
    for (auto &kv : dims().get()) {
        names.push_back(kv.first);
    }
    std::sort(names.begin(), names.end());
    for (auto &name : names) {
        int iter = iter_dim(name);
        int tg = thread_group_dim(name);
        int loop = loop_dim(name);
        int grid = grid_dim(name);
        if (iter == 1 && loop == 1 && tg == 1) continue;
        oss << " Dimension " << name << pad_str(":", -18 + (int)name.length());
        oss << "(grid:" << pad_int(grid, 5) << ") x ";
        oss << "(tg:" << pad_int(tg, 5) << ") x ";
        oss << "(loop:" << pad_int(loop, 5) << ") x ";
        oss << "(iter:" << pad_int(iter, 5) << ")\n";
    }
    return oss.str();
}

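// Zero-initialization of the weights can be skipped when a single thread
// group covers the whole (padded) reduction dimension, i.e. there is no
// accumulation across thread groups.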
bool conv_config_t::can_skip_wei_zero_out() const {
    if (!prb().is_bwd_w) return true;
    bmnk_dim_helper_t h(*this);
    int k_iter_dim = h.iter_dim('k');
    int k_loop_dim = h.loop_dim('k');
    int k_tg_dim = h.thread_group_dim('k');
    int k_tg_block = k_iter_dim * k_loop_dim * k_tg_dim;
    int k_padded = padded_dim("mb") * padded_dim("od") * padded_dim("oh")
            * padded_dim("ow");
    return k_tg_block >= k_padded;
}

bool conv_config_t::can_skip_bia_zero_out() const {
    if (!prb().is_bwd_w || !prb().with_bias) return true;
    return can_skip_wei_zero_out() && !slm().b();
}

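// Registers extra tensors required by the attributes: runtime zero points,
// scales and binary/prelu post-op right-hand sides.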
void init_extra_tensors(const conv_config_t &cfg, tensor_config_t &tensor_cfg) {
    const auto &prb = cfg.prb();
    auto &zp_cfg = prb.zp_cfg;
    auto *pd = prb.conv_pd;
    auto *attr = prb.attr;
    if (zp_cfg.do_src_compensation && zp_cfg.is_runtime_src_zero_points) {
        int zp_ic = (zp_cfg.is_common_src_zero_point) ? 1 : prb.ic;
        std::vector<dim_t> dims = {zp_ic};
        layout_t zp_layout(type_t::s32(), 0, dims);
        int arg_key = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC;
        tensor_cfg.add_tensor("src_zero_points", arg_key,
                /*is_input=*/true, /*is_output=*/false, zp_layout);
    }
    if (zp_cfg.do_dst_compensation && zp_cfg.is_runtime_dst_zero_points) {
        std::vector<dim_t> dims = {prb.oc};
        layout_t zp_layout(type_t::s32(), 0, dims);
        int arg_key = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST;
        tensor_cfg.add_tensor("dst_zero_points", arg_key,
                /*is_input=*/true, /*is_output=*/false, zp_layout);
    }
    auto scale_args = get_scale_args(prb);
    const char *scale_names[] = {"src_scales", "wei_scales", "dst_scales"};
    const int scale_names_len = sizeof(scale_names) / sizeof(scale_names[0]);
    ir_assert((int)scale_args.size() == scale_names_len);
    for (int i = 0; i < (int)scale_args.size(); i++) {
        int arg = scale_args[i];
        auto &s = attr->scales_.get(arg);
        if (s.has_default_values()) continue;
        int dim = s.mask_ == 0 ? 1 : (prb.is_fwd ? prb.oc : prb.ic);
        std::vector<dim_t> dims = {dim};
        layout_t layout(type_t::f32(), 0, dims);
        int arg_key = DNNL_ARG_ATTR_SCALES | arg;
        tensor_cfg.add_tensor(scale_names[i], arg_key, /*is_input=*/true,
                /*is_output=*/false, layout);
    }
    for (int i = 0; i < attr->post_ops_.len(); i++) {
        auto &po = attr->post_ops_.entry_[i];
        if (po.is_eltwise()
                || po.is_sum(/*require_scale_one=*/false,
                        /*require_zp_zero=*/false)) {
            // No extra tensors.
        } else if (po.is_binary()) {
            auto layout = make_layout(po.binary.src1_desc);
            int arg_key = DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1;
            tensor_cfg.add_tensor("binary_rhs_" + std::to_string(i), arg_key,
                    /*is_input=*/true,
                    /*is_output=*/false, layout);
        } else if (po.is_prelu()) {
            layout_t layout(type_t::f32(), 0,
                    get_prelu_weights_dims(
                            po.prelu.mask, *pd->invariant_dst_md()));
            int arg_key = DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_WEIGHTS;
            tensor_cfg.add_tensor("prelu_rhs_" + std::to_string(i), arg_key,
                    /*is_input=*/true, /*is_output=*/false, layout);
        } else {
            ir_error_not_expected();
        }
    }
}

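// Collects all kernel tensors (src/wei/bia/dst plus attribute tensors) with
// their compute and user layouts. BWD_W outputs require zero initialization
// unless a sum post-op is present.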
tensor_config_t get_tensor_config(const conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    tensor_config_t tensor_cfg;
    conv_arg_helper_t h(prb);
    auto &src = cfg.src_layout();
    auto &wei = cfg.wei_layout();
    auto &bia = cfg.bia_layout();
    auto &dst = cfg.dst_layout();
    tensor_cfg.add_tensor("src", h.src_arg_key(), h.is_src_input(),
            h.is_src_output(), src.compute(), src.user());
    tensor_cfg.add_tensor("wei", h.wei_arg_key(), h.is_wei_input(),
            h.is_wei_output(), wei.compute(), wei.user());
    if (prb.with_bias)
        tensor_cfg.add_tensor("bia", h.bia_arg_key(), h.is_bia_input(),
                h.is_bia_output(), bia.compute(), bia.user());
    tensor_cfg.add_tensor("dst", h.dst_arg_key(), h.is_dst_input(),
            h.is_dst_output(), dst.compute(), dst.user());
    if (prb.is_bwd_w && !prb.with_sum) {
        tensor_cfg.require_zero_out("wei");
        if (prb.with_bias) tensor_cfg.require_zero_out("bia");
    }
    init_extra_tensors(cfg, tensor_cfg);
    return tensor_cfg;
}

int estimate_register_count(const conv_config_t &cfg) {
    return estimate_grf_usage(cfg).total();
}

bool can_use_a_2d_send(const conv_config_t &cfg) {
    return can_use_2d_send(cfg, cfg.a_layout().compute_unnormalized(), true);
}

bool can_use_b_2d_send(const conv_config_t &cfg) {
    return can_use_2d_send(cfg, cfg.b_layout().compute_unnormalized(), false);
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl