/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/conv/block_helper.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Helper class to assign block sizes for problem dimensions according to BMNK
// block sizes.
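// The assigner walks the tile levels and the problem dimensions in order,
// skipping trivial (size 1) entries; each assign_block() call assigns one
// block until has_blocks() returns false.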
class block_assigner_t {
public:
    block_assigner_t(const dim_info_t &bmnk_dim,
            const std::vector<dim_info_t *> &prb_dims)
        : bmnk_dim_(bmnk_dim), prb_dims_(prb_dims) {
        move_to_next_bmnk_level();
        move_to_next_prb_dim();
    }

    bool has_blocks() const {
        if (utils::one_of(level_, tile_level_t::unknown, tile_level_t::_last))
            return false;
        if (prb_dim_idx_ >= (int)prb_dims_.size()) return false;
        return true;
    }

    void assign_block() {
        ir_assert(has_blocks());
        ir_assert(rem_bmnk_dim_.is_unlimited() || rem_bmnk_dim_ > 1);
        ir_assert(rem_prb_dim_ > 1);

        // Try a shortcut path to assign all dimensions to the current tile
        // level at once.
        if (try_assign_multi_blocks()) return;

        dim_value_t target_dim
                = min(rem_bmnk_dim_, prb_dims_[prb_dim_idx_]->max_dim(level_));
        int dim = compute_next_block(level_, prb_level_dim(), target_dim,
                bmnk_dim_.base_iter_block(), rem_prb_dim_,
                prb_dims_[prb_dim_idx_]->base_iter_block(), is_last_prb_dim(),
                prb_dims_[prb_dim_idx_]->pad_block());
        if (level_ == tile_level_t::iter) {
            ir_assert(dim % prb_dims_[prb_dim_idx_]->base_iter_block() == 0);
        }

        // Assign the computed block size to the current problem dimension.
        prb_dims_[prb_dim_idx_]->set_dim(level_, dim);

        // Update the remaining dimensions.
        if (!rem_bmnk_dim_.is_unlimited()) rem_bmnk_dim_ = rem_bmnk_dim_ / dim;
        rem_prb_dim_ = utils::div_up(rem_prb_dim_, dim);

        ir_assert(rem_bmnk_dim_.is_unlimited() || rem_bmnk_dim_ >= 1);
        ir_assert(rem_prb_dim_ >= 1);

        // Move to the next BMNK tile or next problem dimension (depending on
        // split/fuse settings and remaining sizes).
        if (rem_bmnk_dim_ != 1 && rem_prb_dim_ != 1) {
            if (allow_fuse()) {
                rem_prb_dim_ = 1;
            } else if (prb_dims_[prb_dim_idx_]->allow_split()) {
                rem_bmnk_dim_ = 1;
            } else {
                rem_prb_dim_ = 1;
                rem_bmnk_dim_ = 1;
            }
        }

        if (rem_bmnk_dim_ == 1 && rem_prb_dim_ == 1) {
            move_to_next_bmnk_level();
            move_to_next_prb_dim();
            return;
        }

        if (rem_bmnk_dim_ == 1) {
            move_to_next_bmnk_level();
            if (!prb_dims_[prb_dim_idx_]->allow_split()) {
                move_to_next_prb_dim();
            }
            return;
        }
        if (rem_prb_dim_ == 1) {
            move_to_next_prb_dim();
            if (!allow_fuse()) move_to_next_bmnk_level();
            return;
        }
        ir_error_not_expected();
    }

private:
    bool allow_fuse() const {
        for (auto *d : prb_dims_)
            if (!d->allow_fuse()) return false;
        return true;
    }

    void move_to_next_bmnk_level() {
        bool found = false;
        int l_beg = (int)level_ + 1;
        int l_end = (int)tile_level_t::_last;
        for (int l = l_beg; l < l_end; l++) {
            tile_level_t level = (tile_level_t)l;
            if (bmnk_dim_.dim(level) != 1) {
                found = true;
                level_ = level;
                rem_bmnk_dim_ = bmnk_dim_.dim(level);
                break;
            }
        }
        if (!found) level_ = tile_level_t::_last;
    }

    void move_to_next_prb_dim() {
        bool found = false;
        for (int i = prb_dim_idx_ + 1; i < (int)prb_dims_.size(); i++) {
            if (prb_dims_[i]->size() != 1) {
                found = true;
                prb_dim_idx_ = i;
                rem_prb_dim_ = prb_dims_[i]->size();
                break;
            }
        }
        if (!found) prb_dim_idx_ = (int)prb_dims_.size();
    }

    bool is_last_prb_dim() const {
        if (!allow_fuse()) return true;
        for (int i = prb_dim_idx_ + 1; i < (int)prb_dims_.size(); i++) {
            auto *d = prb_dims_[i];
            int max_dim = min(d->size(), d->max_dim(level_));
            if (max_dim != 1) return false;
        }
        return true;
    }

    int prb_level_dim() const {
        int ret = 1;
        for (auto *d : prb_dims_)
            ret *= d->dim(level_);
        return ret;
    }

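    // Computes the block to assign to the current problem dimension at the
    // given tile level. Illustrative example (assuming pad_block = 1): at the
    // thread group level with target_dim = 32 and a remaining problem dim of
    // 7, the candidate bound is rnd_up_pow2(7) = 8 with efficiency
    // 7 / 8 = 0.875 >= 0.75, so 8 is returned.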
    int compute_next_block(tile_level_t level, int level_dim,
            dim_value_t target_dim, int target_base_blk, int dim,
            int base_iter_block, bool is_last_dim, int pad_block,
            double target_eff = 0.75) const {
        if (target_dim.is_unlimited()) return dim;

        bool require_pow_2 = false;
        if (level == tile_level_t::tg) require_pow_2 = true;
        if (level == tile_level_t::iter
                && (bmnk_dim_.bmnk() == 'N'
                        || (bmnk_dim_.bmnk() == 'M'
                                && bmnk_dim_.inner_dims() == 1)))
            require_pow_2 = true;

        int step = 1;
        int rem_target_base_blk = 1;
        if (level == tile_level_t::iter) {
            rem_target_base_blk
                    = target_base_blk / math::gcd(level_dim, target_base_blk);
            step = base_iter_block;
            ir_assert(rem_target_base_blk % base_iter_block == 0);
            if (is_last_dim) step = rem_target_base_blk;
        }

        int dimension_bound = [&]() {
            // Prefer powers of 2 for small dimensions as they generally
            // result in better load instructions.
            int bound = utils::rnd_up_pow2(dim);
            if (bound % step == 0 && (double)dim / bound >= target_eff)
                return bound;
            return utils::rnd_up(dim, step);
        }();
        int target_bound = utils::rnd_dn(target_dim, step);
        int ret = std::min(dimension_bound, target_bound);

        while (ret >= step) {
            bool ok = true;
            if (require_pow_2 && !math::is_pow2(ret)) ok = false;
            if (!is_last_dim) {
                if (ret % rem_target_base_blk != 0
                        && rem_target_base_blk % ret != 0)
                    ok = false;
            }
            if (ok) {
                int dim_padded = utils::rnd_up(dim, ret);
                double eff = (double)dim / dim_padded;
                if (eff >= target_eff) break;
            }
            ret -= step;
        }
        if (ret == 0) ret = step;
        if (require_pow_2) ir_assert(math::is_pow2(ret));
        if (level == tile_level_t::iter) ir_assert(ret % base_iter_block == 0);
        if (pad_block > ret && pad_block % ret != 0)
            ret = math::gcd(pad_block, ret);
        return ret;
    }

    bool is_loop_unlimited() const {
        if (level_ != tile_level_t::loop) return false;
        if (rem_bmnk_dim_.is_unlimited()) return false;
        if (bmnk_dim_.tg_dim() != 1) return false;
        return true;
    }

    bool is_iter_full_match() const {
        if (level_ != tile_level_t::iter) return false;

        int prb_total = 1;
        for (auto *d : prb_dims_) {
            if (d->iter_dim() != 1) return false;
            int max_iter_dim = min(d->size(), d->max_dim(tile_level_t::iter));
            prb_total *= max_iter_dim;
        }

        if (rem_bmnk_dim_ != prb_total) return false;
        return true;
    }

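    // Shortcut path: try to assign blocks for all remaining problem
    // dimensions at the current tile level in one step. Applies only when all
    // dimensions may be fused and either is_loop_unlimited() or
    // is_iter_full_match() holds.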
    bool try_assign_multi_blocks() {
        // Check restrictions to apply the heuristics.
        if (!allow_fuse()) return false;
        if (!is_loop_unlimited() && !is_iter_full_match()) return false;

        int nprb_dims = (int)prb_dims_.size();
        std::vector<int> rem_dims(nprb_dims, 1);
        std::vector<int> dims(nprb_dims, 1);

        int max_total_dim = 1;
        for (int i = prb_dim_idx_; i < nprb_dims; i++) {
            int dim = prb_dims_[i]->size();
            int rem_dim = (i == prb_dim_idx_) ? (int)rem_prb_dim_ : dim;
            rem_dim = min(rem_dim, prb_dims_[i]->max_dim(level_));
            rem_dims[i] = rem_dim;
            max_total_dim *= rem_dim;
        }

        bool found = false;
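        // Recursive search: pick a block for each remaining problem dim,
        // tracking the combined block size and the padding efficiency, and
        // stop at the first combination that satisfies target_eff.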
        std::function<void(int, int, double, double)> step;
        step = [&](int idx, int total_dim, double eff, double target_eff) {
            if (total_dim > rem_bmnk_dim_) return;
            if (eff < target_eff) return;
            if (idx == nprb_dims) {
                double min_dim_ratio = 0.5;
                double dim_ratio = total_dim / (double)rem_bmnk_dim_;
                // If all available dimensions are assigned, skip any checks.
                if (total_dim != max_total_dim) {
                    // Skip if the full dimension is too small relative to the
                    // target size.
                    if (dim_ratio < min_dim_ratio) return;
                    // Skip if the padding due to blocking is too large.
                    if (eff < target_eff) return;
                }
                // Found good blocking, set the flag.
                found = true;
                return;
            }
            int dim = prb_dims_[idx]->size();
            int rem_dim = rem_dims[idx];
            for (int blk = rem_dim; blk >= 1; blk--) {
                int dim_padded = utils::rnd_up(dim, blk);
                double dim_eff = (double)dim / dim_padded;
                dims[idx] = blk;
                step(idx + 1, total_dim * blk, eff * dim_eff, target_eff);
                if (found) break;
            }
        };

        if (level_ == tile_level_t::iter) {
            // is_iter_full_match() returned true, so all dimensions can be
            // assigned as-is.
            ir_assert(rem_bmnk_dim_ == max_total_dim);
            for (int i = prb_dim_idx_; i < nprb_dims; i++) {
                dims[i] = rem_dims[i];
            }
            found = true;
        } else {
            // Try to target different efficiencies until a good blocking is
            // found.
            for (double eff = 1.0; eff >= 0.5; eff -= 0.05) {
                step(prb_dim_idx_, 1, 1.0, eff);
                if (found) break;
            }
        }

        ir_assert(found) << "Can't assign blocks.";
        for (int i = prb_dim_idx_; i < nprb_dims; i++) {
            prb_dims_[i]->set_dim(level_, dims[i]);
        }

        prb_dim_idx_ = nprb_dims;
        return true;
    }

    tile_level_t level_ = tile_level_t::unknown;

    dim_info_t bmnk_dim_;
    dim_value_t rem_bmnk_dim_;

    std::vector<dim_info_t *> prb_dims_;
    int prb_dim_idx_ = -1;
    dim_value_t rem_prb_dim_ = 0;
};

void block_helper_t::compute() {
    is_frozen_ = true;
    ir_assert(vectorize_by_b() || vectorize_by_n());

    init_bmnk_dims();
    init_bmnk_blocks();
    init_prb_blocks();

    // XXX: Fix up loop dims for when K grid slicing is disabled.
    if (!allow_k_grid_slicing_) {
        for (auto &kv : dims_) {
            auto &d = kv.second;
            if (d.bmnk() != 'K') continue;
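            // Note: grid_dim() appears to be derived from the part of the
            // size not yet assigned to other levels, so resetting the loop
            // dim and then re-reading grid_dim() folds the former K grid
            // factor into the loop.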
            d.set_loop_dim(1);
            d.set_loop_dim(d.grid_dim());
        }
    }

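    // Debug-only overrides: environment variables of the form
    // <dim_name>_{iter,tg,loop}_dim can force per-dimension block sizes.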
#ifdef GEN_CONV_DEBUG
    for (auto &kv : dims_) {
        auto &d = kv.second;
        const char *tags[] = {"iter", "tg", "loop"};
        for (int i = min_tile_level_idx; i <= max_tile_level_idx; i++) {
            auto level = (tile_level_t)i;
            std::string env_name
                    = d.name() + "_" + tags[i - min_tile_level_idx] + "_dim";
            int env_dim = getenv_int(env_name.c_str(), -1);
            if (env_dim != -1) d.set_dim(level, env_dim);
        }
    }
#endif

    // Verify blocks.
    for (auto &kv : dims_) {
        auto &d = kv.second;
        ir_assert(d.iter_dim() % d.base_iter_block() == 0);
        for (int i = min_tile_level_idx; i <= max_tile_level_idx; i++) {
            auto level = (tile_level_t)i;
            auto max_dim = d.max_dim(level);
            ir_assert(max_dim.is_unlimited() || d.dim(level) <= max_dim);
        }
    }

    for (char bmnk : {'B', 'M', 'N', 'K'}) {
        int iter_blk = 1;
        for (auto &kv : dims_) {
            auto &d = kv.second;
            if (d.bmnk() != bmnk) continue;
            iter_blk *= d.iter_dim();
        }
        ir_assert(iter_blk % bmnk_dim(bmnk).base_iter_block() == 0);
    }
}

void block_helper_t::init_bmnk_blocks() {
    int m_blk = 0;
    int k_blk = 0;
    int bn_blk = 0;
    int m_inst_blk = 0;
    int k_inst_blk = 0;
    int bn_inst_blk = 0;
    bool is_ge_hpc = (hw_cfg_.hw() >= ngen::HW::XeHPC);
    bool reduce_m_block = false;
    if (reduce_m_block_hint_set_) {
        reduce_m_block = reduce_m_block_hint_;
    } else {
        if (m_dim().base_iter_block() == 1) reduce_m_block = true;
        if (k_dim().base_iter_block() == 1) reduce_m_block = true;
    }
    if (is_tf32() && fma_kind_ != fma_kind_t::mad) reduce_m_block = true;
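    // eu_thr_mul is the minimum number of B/M/N threads per EU to aim for;
    // the mad path below shrinks the M block until this target is reached
    // (or the block reaches 1).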
    int eu_thr_mul = (!is_ge_hpc && reduce_m_block) ? 2 : 4;
#ifdef GEN_CONV_DEBUG
    eu_thr_mul = getenv_int("eu_thr_mul", eu_thr_mul);
#endif
    auto &bn_dim = (vectorize_by_b() ? b_dim() : n_dim());
    switch (fma_kind_) {
        case fma_kind_t::mad: {
            int max_m_iter_dim = prb_max_dim('M', tile_level_t::iter);
            m_inst_blk = std::min(8, utils::rnd_down_pow2(max_m_iter_dim));
            bn_inst_blk = vec_size_;
            k_inst_blk = 1;
            bool use_small_m_block = hw_cfg_.hw() <= ngen::HW::XeHP
                    && m_dim().base_iter_block() == 1;
            m_blk = (is_x8x8s32() || use_small_m_block ? 8 : 16);
            bool small_m_tg = m_dim().base_iter_block() == 1
                    && hw_cfg_.hw() == ngen::HW::XeHPG
                    && !m_dim().pref_tg_block();
            if (!m_dim().pref_tg_block())
                m_dim().set_max_dim(tile_level_t::tg, small_m_tg ? 1 : 4);
            bn_blk = vec_size_;
            k_blk = compute_mad_k_block();
            if (!allow_k_grid_slicing_ && !allow_k_tg_slicing_) {
                do {
                    int est_bmn_threads = 1;
                    est_bmn_threads *= utils::div_up(m_dim().size(), m_blk);
                    est_bmn_threads *= utils::div_up(bn_dim.size(), bn_blk);
                    if (est_bmn_threads >= eu_thr_mul * hw_cfg_.eu_count())
                        break;
                    m_blk /= 2;
                    m_inst_blk = std::min(m_inst_blk, m_blk);
                } while (m_blk > 1);
            }
            break;
        }
        case fma_kind_t::dp4a:
        case fma_kind_t::dpas:
        case fma_kind_t::dpasw: {
            ir_assert(vectorize_by_n())
                    << "dpas can support N vectorization only.";
            int max_iter_dim = prb_max_dim('M', tile_level_t::iter);
            int target_m_blk = reduce_m_block ? 16 : 32;
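            // When the preferred M block does not evenly divide the maximal
            // iteration dim, search nearby multiples of 4 for a block with
            // better utilization. Illustrative example: max_iter_dim = 40
            // with an initial target of 16 settles on 20
            // (40 / rnd_up(40, 20) = 100% utilization).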
            if (max_iter_dim % target_m_blk != 0 && max_iter_dim > 32) {
                float max_utilization_rate = 0.;
                for (int i
                        = min(32, utils::rnd_dn((int)(1.5 * target_m_blk), 4));
                        i > target_m_blk; i -= 4) {
                    float utilization_rate = (float)max_iter_dim
                            / utils::rnd_up(max_iter_dim, i);
                    // Heuristic constant preferring larger blocks,
                    // experimentally determined.
                    const float threshold = 1.05f;
                    if (utilization_rate > threshold * max_utilization_rate) {
                        max_utilization_rate = utilization_rate;
                        target_m_blk = i;
                    }
                }
            }
            m_blk = target_m_blk;
            m_inst_blk = m_blk % 8 == 0 ? 8 : m_blk;
            bn_inst_blk = 8;
            k_inst_blk = is_x8x8s32() ? 32 : 16;
            bn_blk = is_ge_hpc ? 64 : 32;
            int est_bmn_threads = 1;
            est_bmn_threads *= utils::div_up(m_dim().size(), m_blk);
            est_bmn_threads *= utils::div_up(bn_dim.size(), bn_blk);
            auto thread_factor
                    = utils::div_up(hw_cfg_.eu_count(), est_bmn_threads);
            if (thread_factor > (is_ge_hpc ? 6 : 2) && !allow_k_grid_slicing_) {
                if (m_inst_blk % 2 == 0 && bn_blk != bn_dim.size()) {
                    m_blk /= 2;
                    if (m_blk % m_inst_blk != 0) m_inst_blk /= 2;
                }
            }

            k_blk = k_inst_blk;
            break;
        }
        default: ir_error_not_expected();
    }

    m_blk = math::lcm(m_blk, m_dim().base_iter_block());
    k_blk = math::lcm(k_blk, k_dim().base_iter_block());
    bn_blk = math::lcm(bn_blk, bn_dim.base_iter_block());

    // Shrink block sizes to leverage is_iter_full_match() when applicable.
    const char bmnks[] = {'M', 'K', vectorize_by_b() ? 'B' : 'N'};
    int *blocks[] = {&m_blk, &k_blk, &bn_blk};
    int inst_blocks[] = {m_inst_blk, k_inst_blk, bn_inst_blk};
    for (int i = 0; i < 3; i++) {
        int max_iter_dim = prb_max_dim(bmnks[i], tile_level_t::iter);
        int base_blk = bmnk_dim(bmnks[i]).base_iter_block();
        int &blk = *blocks[i];
        int inst_blk = inst_blocks[i];
        if (max_iter_dim % base_blk == 0 && max_iter_dim % inst_blk == 0) {
            blk = std::min(blk, max_iter_dim);
        }
        ir_assert(blk % inst_blk == 0);
    }

    // Pad base iteration blocks according to instruction blocks.
    for (char bmnk : {'B', 'M', 'N', 'K'}) {
        bool is_bn = utils::one_of(bmnk, 'B', 'N');
        auto &d = bmnk_dim(bmnk);
        if (is_bn && !vectorize_by_bmnk(bmnk)) continue;

        int blk = d.base_iter_block();
        int inst_blk
                = is_bn ? bn_inst_blk : (bmnk == 'M') ? m_inst_blk : k_inst_blk;
        d.set_base_iter_block(math::lcm(blk, inst_blk));
    }

    m_blk = compute_block(m_dim().size(), m_blk, m_dim().base_iter_block());
    // Require pow2 when only one m dim is non-trivial
    if (m_dim().inner_dims() == 1) m_blk = utils::rnd_down_pow2(m_blk);
    k_blk = compute_block(k_dim().size(), k_blk, k_dim().base_iter_block());
    bn_blk = compute_block(bn_dim.size(), bn_blk, bn_dim.base_iter_block());

#ifdef GEN_CONV_DEBUG
    m_blk = getenv_int("m_iter_blk", m_blk);
    k_blk = getenv_int("k_iter_blk", k_blk);
    bn_blk = getenv_int("bn_iter_blk", bn_blk);
    if (vectorize_by_b()) {
        bn_blk = getenv_int("b_iter_blk", bn_blk);
    } else {
        bn_blk = getenv_int("n_iter_blk", bn_blk);
    }
#endif
    m_dim().set_iter_dim(m_blk);
    bn_dim.set_iter_dim(bn_blk);
    k_dim().set_iter_dim(k_blk);

    init_k_blocking();

    for (char bmnk : {'B', 'M', 'N', 'K'}) {
        auto &d = bmnk_dim(bmnk);
        ir_assert(d.iter_dim() % d.base_iter_block() == 0
                || d.base_iter_block() % d.iter_dim() == 0);
    }

    // Set thread group blocks.
    bool with_k_tg_slicing = (k_dim().tg_dim() > 1);
    if (!with_k_tg_slicing) {
        int target_tg_size = max_tg_size_;
        int est_threads = 1;
        for (char bmnk : {'B', 'M', 'N'})
            est_threads *= bmnk_dim(bmnk).grid_dim();
        if (est_threads < 2 * hw_cfg_.eu_count()
                || hw_cfg_.hw() >= ngen::HW::XeHPC) {
            target_tg_size = std::min(target_tg_size, 16);
        }

        auto init_tg_dim = [&](dim_info_t &d) -> int {
            int i_max_tg_dim = min(target_tg_size, d.max_dim(tile_level_t::tg));
            int target_blk = i_max_tg_dim * d.iter_dim();
            int base_tg_dim = compute_block(d.size(), target_blk, d.iter_dim())
                    / d.iter_dim();
            // Restrict the maximum single TG dim when the max TG size is
            // reduced.
            return std::min(utils::rnd_down_pow2(base_tg_dim),
                    max_tg_overridden_ ? target_tg_size
                                    / (hw_cfg_.hw() >= ngen::HW::XeHPC ? 2 : 4)
                                       : simd_size_);
        };

        // Compute max thread group blocks, independently for each dimension.
        std::vector<char> tg_bmnks = {vectorize_by_b() ? 'B' : 'N', 'M'};
        std::vector<int> tg_dims(tg_bmnks.size(), 1);
        bool any_pref_dim = any_pref_tg_block();
        int *split_dim_idx = nullptr, *pref_dim_idx = nullptr;
        char split_dim_bmnk = 0;
        for (size_t i = 0; i < tg_bmnks.size(); i++) {
            auto &d = bmnk_dim(tg_bmnks[i]);
            int tg_dim = init_tg_dim(d);
            tg_dims[i] = tg_dim;
            if (!d.pref_tg_block()) {
                split_dim_idx = &tg_dims.at(i);
                split_dim_bmnk = tg_bmnks[i];
            } else {
                pref_dim_idx = &tg_dims.at(i);
            }
        }

        auto total_tg_dim = [&]() {
            return std::accumulate(
                    tg_dims.begin(), tg_dims.end(), 1, std::multiplies<int>());
        };
        auto max_tg_dim = [&]() -> int & {
            if (any_pref_dim) {
                auto split_dim = bmnk_dim(split_dim_bmnk);
                int split_rem_grid = split_dim.size() / split_dim.iter_dim();
                // Special case: preserve a non-trivial TG dim for the
                // non-preferred BMNK dimension.
                bool preserve_min_dim
                        = (split_rem_grid > 1 && *split_dim_idx < 4
                                && *pref_dim_idx > *split_dim_idx);
                if (*split_dim_idx == 1
                        || (preserve_min_dim
                                && (*pref_dim_idx > *split_dim_idx)))
                    return *pref_dim_idx;
                else
                    return *split_dim_idx;
            }
            return *std::max_element(tg_dims.begin(), tg_dims.end());
        };

        // Reduce thread group size until it fits the target size.
        while (total_tg_dim() > target_tg_size) {
            max_tg_dim() /= 2;
        }

        for (size_t i = 0; i < tg_bmnks.size(); i++) {
            auto &d = bmnk_dim(tg_bmnks[i]);
            d.set_tg_dim(tg_dims[i]);
        }
    }
}

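// Decides how the K dimension is split across loop, thread group and grid
// levels:
// - with K grid slicing, the K loop dim targets roughly est_threads /
//   (2 * EU count), clamped to [1, 100];
// - with K TG slicing, K blocks are split between the loop and the thread
//   group;
// - otherwise the whole remaining K dimension is left to the loop
//   (unlimited).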
void block_helper_t::init_k_blocking() {
    // Loop and thread group dims must not be set yet.
    for (char bmnk : {'B', 'M', 'N', 'K'}) {
        auto &d = bmnk_dim(bmnk);
        ir_assert(d.loop_dim() == 1);
        ir_assert(d.tg_dim() == 1);
    }

    if (allow_k_grid_slicing_) {
        int est_threads = 1;
        for (char bmnk : {'B', 'M', 'N', 'K'})
            est_threads *= bmnk_dim(bmnk).grid_dim();
        int def_k_loop_dim = utils::div_up(est_threads, 2 * hw_cfg_.eu_count());
        def_k_loop_dim = std::min(100, def_k_loop_dim);
        def_k_loop_dim = std::max(1, def_k_loop_dim);
        int k_loop_dim = def_k_loop_dim;
#ifdef GEN_CONV_DEBUG
        k_loop_dim = getenv_int("k_loop_dim", k_loop_dim);
#endif
        k_dim().set_loop_dim(k_loop_dim);
        return;
    }

    if (!enable_k_tg_slicing()) {
        k_dim().set_loop_dim(dim_value_t::unlimited());
        return;
    }

    int k_nblks = utils::div_up(
            prb_blocked_dim('K').size(), k_dim().base_iter_block());
    int tg_dim0 = min(max_tg_size_, k_dim().max_dim(tile_level_t::tg));
    for (int tg_dim = tg_dim0; tg_dim >= 1; tg_dim /= 2) {
        if (k_nblks % tg_dim == 0) {
            k_dim().set_loop_dim(k_nblks / tg_dim);
            k_dim().set_tg_dim(tg_dim);
            return;
        }
    }

    // Couldn't enable TG slicing.
    k_dim().set_loop_dim(dim_value_t::unlimited());
}

bool block_helper_t::enable_k_tg_slicing() const {
#ifdef GEN_CONV_DEBUG
    int env_value = getenv_int("enable_k_tg_slicing", -1);
    if (env_value != -1) return (bool)env_value;
#endif
    if (!allow_k_tg_slicing_) return false;

    if (m_dim().iter_dim() > 16) return false;

    // TG slicing is supported only when there is only one k dimension.
    if (prb_blocked_ndims('K') > 1) return false;

    // Do not enable TG slicing if there are enough non-K threads.
    int non_k_threads = 1;
    for (char bmnk : {'B', 'M', 'N'}) {
        auto &d = bmnk_dim(bmnk);
        non_k_threads *= d.grid_dim() * d.tg_dim();
    }
    if (non_k_threads >= hw_cfg_.eu_count()) return false;

    // Do not enable TG slicing if reduction is small.
    int k_nblks = utils::div_up(k_dim().size(), k_dim().base_iter_block());
    if (k_nblks < 16) return false;

    return true;
}

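// Distributes blocks over the individual problem dimensions: pads every
// blocked dimension to its base iteration block, orders the blocked
// dimensions by their order keys, and then runs block_assigner_t for each
// BMNK dimension.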
void block_helper_t::init_prb_blocks() {
    // Pad sizes to base block multiples.
    for (auto &kv : dims_) {
        auto &d = kv.second;
        d.set_size(utils::rnd_up(d.size(), d.base_iter_block()));
    }

    // Filter blocked dimensions and sort them according to their keys.
    std::vector<dim_info_t *> sorted_dims;
    for (auto &kv : dims_) {
        auto &d = kv.second;
        if (!d.is_blocked()) continue;
        sorted_dims.push_back(&d);
    }
    std::sort(sorted_dims.begin(), sorted_dims.end(),
            [](const dim_info_t *a, const dim_info_t *b) {
                if (a->order_key() == b->order_key()) {
                    return a->name().compare(b->name()) < 0;
                }
                return a->order_key() < b->order_key();
            });

    for (char bmnk : {'B', 'N', 'M', 'K'}) {
        std::vector<dim_info_t *> cur_dims;
        for (auto *d : sorted_dims) {
            if (d->bmnk() != bmnk) continue;
            cur_dims.push_back(d);
        }

        ir_assert(!cur_dims.empty());

        // Pad dimensions according to BMNK base block requirements.
        int max_iter_dim = prb_max_dim(bmnk, tile_level_t::iter);
        int base_blk = bmnk_dim(bmnk).base_iter_block();
        if (max_iter_dim == 1 && base_blk > 1) {
            ir_assert(cur_dims[0]->base_iter_block() == 1);
            cur_dims[0]->set_size(base_blk);
        }

        block_assigner_t assigner(bmnk_dim(bmnk), cur_dims);
        while (assigner.has_blocks()) {
            assigner.assign_block();
        }
    }
}

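// K iteration block used with mad: the base iteration block if it is already
// >= 16; the fused K size if all blocked K dims can be fused and their
// product is at most 32 and divisible by the base block; 16 otherwise.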
int block_helper_t::compute_mad_k_block() const {
    int k_base_blk = k_dim().base_iter_block();
    if (k_base_blk >= 16) return k_base_blk;

    bool is_fused = true;
    int k_blocked_size = 1;
    for (auto &kv : dims_) {
        auto &d = kv.second;
        if (!d.is_blocked()) continue;
        if (d.bmnk() != 'K') continue;
        k_blocked_size *= d.size();
        if (!d.allow_fuse()) is_fused = false;
    }

    if (!is_fused) return 16;

    int max_k_blk = 32;
    if ((k_blocked_size <= max_k_blk)
            && (k_blocked_size % k_dim().base_iter_block() == 0))
        return k_blocked_size;

    return 16;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl