1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/conv/block_helper.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace jit { |
23 | |
24 | // Helper class to assign block sizes for problem dimensions according to BMNK |
25 | // block sizes. |
// Helper class to assign block sizes for problem dimensions according to BMNK
// block sizes. Walks through the tile levels (iter/loop/tg) of one BMNK
// dimension and greedily distributes its blocking across the problem
// dimensions that map to it.
class block_assigner_t {
public:
    // bmnk_dim is the GEMM-view dimension (one of B/M/N/K) whose per-level
    // blocks are to be distributed; prb_dims are the problem dimensions
    // mapped to it (e.g. for M in convolution: mb, oh, ow).
    block_assigner_t(const dim_info_t &bmnk_dim,
            const std::vector<dim_info_t *> &prb_dims)
        : bmnk_dim_(bmnk_dim), prb_dims_(prb_dims) {
        move_to_next_bmnk_level();
        move_to_next_prb_dim();
    }

    // Returns true while there is both a tile level and a problem dimension
    // left with unassigned blocking.
    bool has_blocks() const {
        if (utils::one_of(level_, tile_level_t::unknown, tile_level_t::_last))
            return false;
        if (prb_dim_idx_ >= (int)prb_dims_.size()) return false;
        return true;
    }

    // Assigns a block size for the current (tile level, problem dimension)
    // pair and advances to the next level and/or dimension depending on how
    // much of each remains and on the split/fuse capabilities of the
    // dimensions.
    void assign_block() {
        ir_assert(has_blocks());
        ir_assert(rem_bmnk_dim_.is_unlimited() || rem_bmnk_dim_ > 1);
        ir_assert(rem_prb_dim_ > 1);

        // Try a shortcut path to assign all dimensions to the current tile
        // level at once.
        if (try_assign_multi_blocks()) return;

        // The block can't exceed either the remaining BMNK target or the
        // per-level cap of the current problem dimension.
        dim_value_t target_dim
                = min(rem_bmnk_dim_, prb_dims_[prb_dim_idx_]->max_dim(level_));
        int dim = compute_next_block(level_, prb_level_dim(), target_dim,
                bmnk_dim_.base_iter_block(), rem_prb_dim_,
                prb_dims_[prb_dim_idx_]->base_iter_block(), is_last_prb_dim(),
                prb_dims_[prb_dim_idx_]->pad_block());
        if (level_ == tile_level_t::iter) {
            ir_assert(dim % prb_dims_[prb_dim_idx_]->base_iter_block() == 0);
        }

        // Assign the computed block size to the current problem dimension.
        prb_dims_[prb_dim_idx_]->set_dim(level_, dim);

        // Update the remaining dimensions.
        if (!rem_bmnk_dim_.is_unlimited()) rem_bmnk_dim_ = rem_bmnk_dim_ / dim;
        rem_prb_dim_ = utils::div_up(rem_prb_dim_, dim);

        ir_assert(rem_bmnk_dim_.is_unlimited() || rem_bmnk_dim_ >= 1);
        ir_assert(rem_prb_dim_ >= 1);

        // Move to the next BMNK tile or next problem dimension (depending on
        // split/fuse settings and remaining sizes). If both remain and
        // neither fusing nor splitting is allowed, drop both remainders.
        if (rem_bmnk_dim_ != 1 && rem_prb_dim_ != 1) {
            if (allow_fuse()) {
                rem_prb_dim_ = 1;
            } else if (prb_dims_[prb_dim_idx_]->allow_split()) {
                rem_bmnk_dim_ = 1;
            } else {
                rem_prb_dim_ = 1;
                rem_bmnk_dim_ = 1;
            }
        }

        // Both fully consumed: advance level and dimension together.
        if (rem_bmnk_dim_ == 1 && rem_prb_dim_ == 1) {
            move_to_next_bmnk_level();
            move_to_next_prb_dim();
            return;
        }

        // BMNK level consumed, problem dimension has leftovers: keep the
        // dimension only if it may be split across levels.
        if (rem_bmnk_dim_ == 1) {
            move_to_next_bmnk_level();
            if (!prb_dims_[prb_dim_idx_]->allow_split()) {
                move_to_next_prb_dim();
            }
            return;
        }
        // Problem dimension consumed, BMNK level has leftovers: keep the
        // level only if the next dimension may be fused into it.
        if (rem_prb_dim_ == 1) {
            move_to_next_prb_dim();
            if (!allow_fuse()) move_to_next_bmnk_level();
            return;
        }
        ir_error_not_expected();
    }

private:
    // Fusing is allowed only when every participating problem dimension
    // allows it.
    bool allow_fuse() const {
        for (auto *d : prb_dims_)
            if (!d->allow_fuse()) return false;
        return true;
    }

    // Advances level_ to the next tile level with a non-trivial BMNK block;
    // sets level_ to _last (terminal state) when none remains.
    void move_to_next_bmnk_level() {
        bool found = false;
        int l_beg = (int)level_ + 1;
        int l_end = (int)tile_level_t::_last;
        for (int l = l_beg; l < l_end; l++) {
            tile_level_t level = (tile_level_t)l;
            if (bmnk_dim_.dim(level) != 1) {
                found = true;
                level_ = level;
                rem_bmnk_dim_ = bmnk_dim_.dim(level);
                break;
            }
        }
        if (!found) level_ = tile_level_t::_last;
    }

    // Advances prb_dim_idx_ to the next problem dimension of size > 1; sets
    // it past the end (terminal state) when none remains.
    void move_to_next_prb_dim() {
        bool found = false;
        for (int i = prb_dim_idx_ + 1; i < (int)prb_dims_.size(); i++) {
            if (prb_dims_[i]->size() != 1) {
                found = true;
                prb_dim_idx_ = i;
                rem_prb_dim_ = prb_dims_[i]->size();
                break;
            }
        }
        if (!found) prb_dim_idx_ = (int)prb_dims_.size();
    }

    // Returns true when no further problem dimension can contribute blocking
    // at the current level (either fusing is disallowed, or all remaining
    // dimensions are effectively trivial at this level).
    bool is_last_prb_dim() const {
        if (!allow_fuse()) return true;
        for (int i = prb_dim_idx_ + 1; i < (int)prb_dims_.size(); i++) {
            auto *d = prb_dims_[i];
            int max_dim = min(d->size(), d->max_dim(level_));
            if (max_dim != 1) return false;
        }
        return true;
    }

    // Product of the blocks already assigned at the current level across all
    // problem dimensions.
    int prb_level_dim() const {
        int ret = 1;
        for (auto *d : prb_dims_)
            ret *= d->dim(level_);
        return ret;
    }

    // Computes the next block size for one problem dimension, balancing
    // padding efficiency (>= target_eff of dim must be useful work) against
    // alignment constraints (base iteration blocks, power-of-2 requirements
    // for TG and some iter cases, pad_block divisibility).
    int compute_next_block(tile_level_t level, int level_dim,
            dim_value_t target_dim, int target_base_blk, int dim,
            int base_iter_block, bool is_last_dim, int pad_block,
            double target_eff = 0.75) const {
        if (target_dim.is_unlimited()) return dim;

        // TG blocks must be powers of 2; iteration blocks must be too for N
        // and for M when M maps to a single non-trivial inner dimension.
        bool require_pow_2 = false;
        if (level == tile_level_t::tg) require_pow_2 = true;
        if (level == tile_level_t::iter
                && (bmnk_dim_.bmnk() == 'N'
                        || (bmnk_dim_.bmnk() == 'M'
                                && bmnk_dim_.inner_dims() == 1)))
            require_pow_2 = true;

        // step is the granularity candidate blocks are searched at;
        // rem_target_base_blk is the part of the BMNK base block not yet
        // covered by previously assigned dimensions at this level.
        int step = 1;
        int rem_target_base_blk = 1;
        if (level == tile_level_t::iter) {
            rem_target_base_blk
                    = target_base_blk / math::gcd(level_dim, target_base_blk);
            step = base_iter_block;
            ir_assert(rem_target_base_blk % base_iter_block == 0);
            if (is_last_dim) step = rem_target_base_blk;
        }

        int dimension_bound = [&]() {
            // Prefer powers of 2 for small dimensions as they generally
            // result in better load instructions.
            int bound = utils::rnd_up_pow2(dim);
            if (bound % step == 0 && (double)dim / bound >= target_eff)
                return bound;
            return utils::rnd_up(dim, step);
        }();
        int target_bound = utils::rnd_dn(target_dim, step);
        int ret = std::min(dimension_bound, target_bound);

        // Walk candidates downwards until one satisfies all alignment and
        // efficiency constraints.
        while (ret >= step) {
            bool ok = true;
            if (require_pow_2 && !math::is_pow2(ret)) ok = false;
            if (!is_last_dim) {
                // Intermediate dimensions must compose cleanly with the
                // remaining base block.
                if (ret % rem_target_base_blk != 0
                        && rem_target_base_blk % ret != 0)
                    ok = false;
            }
            if (ok) {
                int dim_padded = utils::rnd_up(dim, ret);
                double eff = (double)dim / dim_padded;
                if (eff >= target_eff) break;
            }
            ret -= step;
        }
        if (ret == 0) ret = step;
        if (require_pow_2) ir_assert(math::is_pow2(ret));
        if (level == tile_level_t::iter) ir_assert(ret % base_iter_block == 0);
        // Keep the block compatible with the padding requirement.
        if (pad_block > ret && pad_block % ret != 0)
            ret = math::gcd(pad_block, ret);
        return ret;
    }

    // True when the loop level has no bounded target (no TG slicing and the
    // remaining loop extent is finite) — used to enable the multi-block
    // shortcut.
    bool is_loop_unlimited() const {
        if (level_ != tile_level_t::loop) return false;
        if (rem_bmnk_dim_.is_unlimited()) return false;
        if (bmnk_dim_.tg_dim() != 1) return false;
        return true;
    }

    // True when the remaining iteration target exactly equals the product of
    // all maximum per-dimension iteration blocks, so every dimension can be
    // taken whole.
    bool is_iter_full_match() const {
        if (level_ != tile_level_t::iter) return false;

        int prb_total = 1;
        for (auto *d : prb_dims_) {
            if (d->iter_dim() != 1) return false;
            int max_iter_dim = min(d->size(), d->max_dim(tile_level_t::iter));
            prb_total *= max_iter_dim;
        }

        if (rem_bmnk_dim_ != prb_total) return false;
        return true;
    }

    // Shortcut: assigns blocks for all remaining problem dimensions at the
    // current level at once via a recursive search over per-dimension block
    // candidates. Returns true (and consumes all dimensions) on success.
    bool try_assign_multi_blocks() {
        // Check restrictions to apply the heuristics.
        if (!allow_fuse()) return false;
        if (!is_loop_unlimited() && !is_iter_full_match()) return false;

        int nprb_dims = (int)prb_dims_.size();
        std::vector<int> rem_dims(nprb_dims, 1);
        std::vector<int> dims(nprb_dims, 1);

        // Upper bound on the total block available from the remaining
        // dimensions at this level.
        int max_total_dim = 1;
        for (int i = prb_dim_idx_; i < nprb_dims; i++) {
            int dim = prb_dims_[i]->size();
            int rem_dim = (i == prb_dim_idx_) ? (int)rem_prb_dim_ : dim;
            rem_dim = min(rem_dim, prb_dims_[i]->max_dim(level_));
            rem_dims[i] = rem_dim;
            max_total_dim *= rem_dim;
        }

        // Depth-first search: for each dimension try block sizes from
        // largest to smallest, tracking cumulative size and padding
        // efficiency; stop at the first combination meeting target_eff.
        bool found = false;
        std::function<void(int, int, double, double)> step;
        step = [&](int idx, int total_dim, double eff, double target_eff) {
            if (total_dim > rem_bmnk_dim_) return;
            if (eff < target_eff) return;
            if (idx == nprb_dims) {
                double min_dim_ratio = 0.5;
                double dim_ratio = total_dim / (double)rem_bmnk_dim_;
                // If all available dimensions are assigned, skip any checks.
                if (total_dim != max_total_dim) {
                    // Skip if the full dimension is too small relative to the
                    // target size.
                    if (dim_ratio < min_dim_ratio) return;
                    // Skip if the padding due to blocking is too large.
                    if (eff < target_eff) return;
                }
                // Found good blocking, set the flag.
                found = true;
                return;
            }
            int dim = prb_dims_[idx]->size();
            int rem_dim = rem_dims[idx];
            for (int blk = rem_dim; blk >= 1; blk--) {
                int dim_padded = utils::rnd_up(dim, blk);
                double dim_eff = (double)dim / dim_padded;
                dims[idx] = blk;
                step(idx + 1, total_dim * blk, eff * dim_eff, target_eff);
                if (found) break;
            }
        };

        if (level_ == tile_level_t::iter) {
            // is_iter_full_match() returned true so all dimensions can be
            // assigned as is.
            ir_assert(rem_bmnk_dim_ == max_total_dim);
            for (int i = prb_dim_idx_; i < nprb_dims; i++) {
                dims[i] = rem_dims[i];
            }
            found = true;
        } else {
            // Try to target different efficiencies until a good blocking is
            // found.
            for (double eff = 1.0; eff >= 0.5; eff -= 0.05) {
                step(prb_dim_idx_, 1, 1.0, eff);
                if (found) break;
            }
        }

        ir_assert(found) << "Can't assign blocks.";
        for (int i = prb_dim_idx_; i < nprb_dims; i++) {
            prb_dims_[i]->set_dim(level_, dims[i]);
        }

        // All dimensions are handled; mark the assigner exhausted.
        prb_dim_idx_ = nprb_dims;
        return true;
    }

    // Current tile level being assigned.
    tile_level_t level_ = tile_level_t::unknown;

    // BMNK dimension whose blocking is being distributed, and the remaining
    // (not yet covered) block at the current level.
    dim_info_t bmnk_dim_;
    dim_value_t rem_bmnk_dim_;

    // Problem dimensions mapped to bmnk_dim_; prb_dim_idx_ is the dimension
    // currently being assigned and rem_prb_dim_ its remaining extent.
    std::vector<dim_info_t *> prb_dims_;
    int prb_dim_idx_ = -1;
    dim_value_t rem_prb_dim_ = 0;
};
321 | |
// Computes the full blocking: BMNK-level blocks first, then their
// distribution over problem dimensions, followed by debug overrides and
// consistency verification.
void block_helper_t::compute() {
    is_frozen_ = true;
    ir_assert(vectorize_by_b() || vectorize_by_n());

    init_bmnk_dims();
    init_bmnk_blocks();
    init_prb_blocks();

    // XXX: Fix up loop dims for when K grid slicing is disabled.
    if (!allow_k_grid_slicing_) {
        for (auto &kv : dims_) {
            auto &d = kv.second;
            if (d.bmnk() != 'K') continue;
            // NOTE(review): loop dim is reset to 1 first, presumably so that
            // grid_dim() is recomputed without loop blocking; the whole grid
            // extent is then folded into the loop — confirm against
            // dim_info_t::grid_dim().
            d.set_loop_dim(1);
            d.set_loop_dim(d.grid_dim());
        }
    }

#ifdef GEN_CONV_DEBUG
    // Allow overriding any per-dimension block via environment variables
    // named "<dim>_<level>_dim" (e.g. "mb_iter_dim").
    for (auto &kv : dims_) {
        auto &d = kv.second;
        const char *tags[] = {"iter", "tg", "loop"};
        for (int i = min_tile_level_idx; i <= max_tile_level_idx; i++) {
            auto level = (tile_level_t)i;
            std::string env_name
                    = d.name() + "_" + tags[i - min_tile_level_idx] + "_dim";
            int env_dim = getenv_int(env_name.c_str(), -1);
            if (env_dim != -1) d.set_dim(level, env_dim);
        }
    }
#endif

    // Verify blocks: each dimension's iteration block must respect its base
    // block, and no level may exceed its maximum.
    for (auto &kv : dims_) {
        auto &d = kv.second;
        ir_assert(d.iter_dim() % d.base_iter_block() == 0);
        for (int i = min_tile_level_idx; i <= max_tile_level_idx; i++) {
            auto level = (tile_level_t)i;
            auto max_dim = d.max_dim(level);
            ir_assert(max_dim.is_unlimited() || d.dim(level) <= max_dim);
        }
    }

    // Verify that the combined per-BMNK iteration blocking still respects
    // the BMNK base iteration block.
    for (char bmnk : {'B', 'M', 'N', 'K'}) {
        int iter_blk = 1;
        for (auto &kv : dims_) {
            auto &d = kv.second;
            if (d.bmnk() != bmnk) continue;
            iter_blk *= d.iter_dim();
        }
        ir_assert(iter_blk % bmnk_dim(bmnk).base_iter_block() == 0);
    }
}
375 | |
376 | void block_helper_t::init_bmnk_blocks() { |
377 | int m_blk = 0; |
378 | int k_blk = 0; |
379 | int bn_blk = 0; |
380 | int m_inst_blk = 0; |
381 | int k_inst_blk = 0; |
382 | int bn_inst_blk = 0; |
383 | bool is_ge_hpc = (hw_cfg_.hw() >= ngen::HW::XeHPC); |
384 | bool reduce_m_block = false; |
385 | if (reduce_m_block_hint_set_) { |
386 | reduce_m_block = reduce_m_block_hint_; |
387 | } else { |
388 | if (m_dim().base_iter_block() == 1) reduce_m_block = true; |
389 | if (k_dim().base_iter_block() == 1) reduce_m_block = true; |
390 | } |
391 | if (is_tf32() && fma_kind_ != fma_kind_t::mad) reduce_m_block = true; |
392 | int eu_thr_mul = (!is_ge_hpc && reduce_m_block) ? 2 : 4; |
393 | #ifdef GEN_CONV_DEBUG |
394 | eu_thr_mul = getenv_int("eu_thr_mul" , eu_thr_mul); |
395 | #endif |
396 | auto &bn_dim = (vectorize_by_b() ? b_dim() : n_dim()); |
397 | switch (fma_kind_) { |
398 | case fma_kind_t::mad: { |
399 | int max_m_iter_dim = prb_max_dim('M', tile_level_t::iter); |
400 | m_inst_blk = std::min(8, utils::rnd_down_pow2(max_m_iter_dim)); |
401 | bn_inst_blk = vec_size_; |
402 | k_inst_blk = 1; |
403 | bool use_small_m_block = hw_cfg_.hw() <= ngen::HW::XeHP |
404 | && m_dim().base_iter_block() == 1; |
405 | m_blk = (is_x8x8s32() || use_small_m_block ? 8 : 16); |
406 | bool small_m_tg = m_dim().base_iter_block() == 1 |
407 | && hw_cfg_.hw() == ngen::HW::XeHPG |
408 | && !m_dim().pref_tg_block(); |
409 | if (!m_dim().pref_tg_block()) |
410 | m_dim().set_max_dim(tile_level_t::tg, small_m_tg ? 1 : 4); |
411 | bn_blk = vec_size_; |
412 | k_blk = compute_mad_k_block(); |
413 | if (!allow_k_grid_slicing_ && !allow_k_tg_slicing_) { |
414 | do { |
415 | int est_bmn_threads = 1; |
416 | est_bmn_threads *= utils::div_up(m_dim().size(), m_blk); |
417 | est_bmn_threads *= utils::div_up(bn_dim.size(), bn_blk); |
418 | if (est_bmn_threads >= eu_thr_mul * hw_cfg_.eu_count()) |
419 | break; |
420 | m_blk /= 2; |
421 | m_inst_blk = std::min(m_inst_blk, m_blk); |
422 | } while (m_blk > 1); |
423 | } |
424 | break; |
425 | } |
426 | case fma_kind_t::dp4a: |
427 | case fma_kind_t::dpas: |
428 | case fma_kind_t::dpasw: { |
429 | ir_assert(vectorize_by_n()) |
430 | << "dpas can support N vectorization only." ; |
431 | int max_iter_dim = prb_max_dim('M', tile_level_t::iter); |
432 | int target_m_blk = reduce_m_block ? 16 : 32; |
433 | if (max_iter_dim % target_m_blk != 0 && max_iter_dim > 32) { |
434 | float max_utilization_rate = 0.; |
435 | for (int i |
436 | = min(32, utils::rnd_dn((int)(1.5 * target_m_blk), 4)); |
437 | i > target_m_blk; i -= 4) { |
438 | float utilization_rate = (float)max_iter_dim |
439 | / utils::rnd_up(max_iter_dim, i); |
440 | // Heuristic constant preferring larger blocks, experimentally determined. |
441 | const float threshhold = 1.05; |
442 | if (utilization_rate > threshhold * max_utilization_rate) { |
443 | max_utilization_rate = utilization_rate; |
444 | target_m_blk = i; |
445 | } |
446 | } |
447 | } |
448 | m_blk = target_m_blk; |
449 | m_inst_blk = m_blk % 8 == 0 ? 8 : m_blk; |
450 | bn_inst_blk = 8; |
451 | k_inst_blk = is_x8x8s32() ? 32 : 16; |
452 | bn_blk = is_ge_hpc ? 64 : 32; |
453 | int est_bmn_threads = 1; |
454 | est_bmn_threads *= utils::div_up(m_dim().size(), m_blk); |
455 | est_bmn_threads *= utils::div_up(bn_dim.size(), bn_blk); |
456 | auto thread_factor |
457 | = utils::div_up(hw_cfg_.eu_count(), est_bmn_threads); |
458 | if (thread_factor > (is_ge_hpc ? 6 : 2) && !allow_k_grid_slicing_) { |
459 | if (m_inst_blk % 2 == 0 && bn_blk != bn_dim.size()) { |
460 | m_blk /= 2; |
461 | if (m_blk % m_inst_blk != 0) m_inst_blk /= 2; |
462 | } |
463 | } |
464 | |
465 | k_blk = k_inst_blk; |
466 | break; |
467 | } |
468 | default: ir_error_not_expected(); |
469 | } |
470 | |
471 | m_blk = math::lcm(m_blk, m_dim().base_iter_block()); |
472 | k_blk = math::lcm(k_blk, k_dim().base_iter_block()); |
473 | bn_blk = math::lcm(bn_blk, bn_dim.base_iter_block()); |
474 | |
475 | // Shrink block sizes to leverage is_iter_full_match() when applicable. |
476 | const char bmnks[] = {'M', 'K', vectorize_by_b() ? 'B' : 'N'}; |
477 | int *blocks[] = {&m_blk, &k_blk, &bn_blk}; |
478 | int inst_blocks[] = {m_inst_blk, k_inst_blk, bn_inst_blk}; |
479 | for (int i = 0; i < 3; i++) { |
480 | int max_iter_dim = prb_max_dim(bmnks[i], tile_level_t::iter); |
481 | int base_blk = bmnk_dim(bmnks[i]).base_iter_block(); |
482 | int &blk = *blocks[i]; |
483 | int inst_blk = inst_blocks[i]; |
484 | if (max_iter_dim % base_blk == 0 && max_iter_dim % inst_blk == 0) { |
485 | blk = std::min(blk, max_iter_dim); |
486 | } |
487 | ir_assert(blk % inst_blk == 0); |
488 | } |
489 | |
490 | // Pad base iteration blocks according to instruction blocks. |
491 | for (char bmnk : {'B', 'M', 'N', 'K'}) { |
492 | bool is_bn = utils::one_of(bmnk, 'B', 'N'); |
493 | auto &d = bmnk_dim(bmnk); |
494 | if (is_bn && !vectorize_by_bmnk(bmnk)) continue; |
495 | |
496 | int blk = d.base_iter_block(); |
497 | int inst_blk |
498 | = is_bn ? bn_inst_blk : (bmnk == 'M') ? m_inst_blk : k_inst_blk; |
499 | d.set_base_iter_block(math::lcm(blk, inst_blk)); |
500 | } |
501 | |
502 | m_blk = compute_block(m_dim().size(), m_blk, m_dim().base_iter_block()); |
503 | // Require pow2 when only one m dim is non-trivial |
504 | if (m_dim().inner_dims() == 1) m_blk = utils::rnd_down_pow2(m_blk); |
505 | k_blk = compute_block(k_dim().size(), k_blk, k_dim().base_iter_block()); |
506 | bn_blk = compute_block(bn_dim.size(), bn_blk, bn_dim.base_iter_block()); |
507 | |
508 | #ifdef GEN_CONV_DEBUG |
509 | m_blk = getenv_int("m_iter_blk" , m_blk); |
510 | k_blk = getenv_int("k_iter_blk" , k_blk); |
511 | bn_blk = getenv_int("bn_iter_blk" , bn_blk); |
512 | if (vectorize_by_b()) { |
513 | bn_blk = getenv_int("b_iter_blk" , bn_blk); |
514 | } else { |
515 | bn_blk = getenv_int("n_iter_blk" , bn_blk); |
516 | } |
517 | #endif |
518 | m_dim().set_iter_dim(m_blk); |
519 | bn_dim.set_iter_dim(bn_blk); |
520 | k_dim().set_iter_dim(k_blk); |
521 | |
522 | init_k_blocking(); |
523 | |
524 | for (char bmnk : {'B', 'M', 'N', 'K'}) { |
525 | auto &d = bmnk_dim(bmnk); |
526 | ir_assert(d.iter_dim() % d.base_iter_block() == 0 |
527 | || d.base_iter_block() % d.iter_dim() == 0); |
528 | } |
529 | |
530 | // Set thread group blocks. |
531 | bool with_k_tg_slicing = (k_dim().tg_dim() > 1); |
532 | if (!with_k_tg_slicing) { |
533 | int target_tg_size = max_tg_size_; |
534 | int est_threads = 1; |
535 | for (char bmnk : {'B', 'M', 'N'}) |
536 | est_threads *= bmnk_dim(bmnk).grid_dim(); |
537 | if (est_threads < 2 * hw_cfg_.eu_count() |
538 | || hw_cfg_.hw() >= ngen::HW::XeHPC) { |
539 | target_tg_size = std::min(target_tg_size, 16); |
540 | } |
541 | |
542 | auto init_tg_dim = [&](dim_info_t &d) -> int { |
543 | int i_max_tg_dim = min(target_tg_size, d.max_dim(tile_level_t::tg)); |
544 | int target_blk = i_max_tg_dim * d.iter_dim(); |
545 | int base_tg_dim = compute_block(d.size(), target_blk, d.iter_dim()) |
546 | / d.iter_dim(); |
547 | //restrict maximum single tg dim as max_tg size is reduced |
548 | return std::min(utils::rnd_down_pow2(base_tg_dim), |
549 | max_tg_overridden_ ? target_tg_size |
550 | / (hw_cfg_.hw() >= ngen::HW::XeHPC ? 2 : 4) |
551 | : simd_size_); |
552 | }; |
553 | |
554 | // Compute max thread group blocks, independently for each dimension. |
555 | std::vector<char> tg_bmnks = {vectorize_by_b() ? 'B' : 'N', 'M'}; |
556 | std::vector<int> tg_dims(tg_bmnks.size(), 1); |
557 | bool any_pref_dim = any_pref_tg_block(); |
558 | int *split_dim_idx, *pref_dim_idx; |
559 | char split_dim_bmnk; |
560 | for (size_t i = 0; i < tg_bmnks.size(); i++) { |
561 | auto &d = bmnk_dim(tg_bmnks[i]); |
562 | int tg_dim = init_tg_dim(d); |
563 | tg_dims[i] = tg_dim; |
564 | if (!d.pref_tg_block()) { |
565 | split_dim_idx = &tg_dims.at(i); |
566 | split_dim_bmnk = tg_bmnks[i]; |
567 | } else { |
568 | pref_dim_idx = &tg_dims.at(i); |
569 | } |
570 | } |
571 | |
572 | auto total_tg_dim = [&]() { |
573 | return std::accumulate( |
574 | tg_dims.begin(), tg_dims.end(), 1, std::multiplies<int>()); |
575 | }; |
576 | auto max_tg_dim = [&]() -> int & { |
577 | if (any_pref_dim) { |
578 | auto split_dim = bmnk_dim(split_dim_bmnk); |
579 | int split_rem_grid = split_dim.size() / split_dim.iter_dim(); |
580 | // Special case preserve non-zero tg dim for non-pref bmnk |
581 | bool preserve_min_dim |
582 | = (split_rem_grid > 1 && *split_dim_idx < 4 |
583 | && *pref_dim_idx > *split_dim_idx); |
584 | if (*split_dim_idx == 1 |
585 | || (preserve_min_dim |
586 | && (*pref_dim_idx > *split_dim_idx))) |
587 | return *pref_dim_idx; |
588 | else |
589 | return *split_dim_idx; |
590 | } |
591 | return *std::max_element(tg_dims.begin(), tg_dims.end()); |
592 | }; |
593 | |
594 | // Reduce thread group size until it fits the target size. |
595 | while (total_tg_dim() > target_tg_size) { |
596 | max_tg_dim() /= 2; |
597 | } |
598 | |
599 | for (size_t i = 0; i < tg_bmnks.size(); i++) { |
600 | auto &d = bmnk_dim(tg_bmnks[i]); |
601 | d.set_tg_dim(tg_dims[i]); |
602 | } |
603 | } |
604 | } |
605 | |
606 | void block_helper_t::init_k_blocking() { |
607 | // Thread and thread group dims must not be set yet. |
608 | for (char bmnk : {'B', 'M', 'N', 'K'}) { |
609 | auto &d = bmnk_dim(bmnk); |
610 | ir_assert(d.loop_dim() == 1); |
611 | ir_assert(d.tg_dim() == 1); |
612 | } |
613 | |
614 | if (allow_k_grid_slicing_) { |
615 | int est_threads = 1; |
616 | for (char bmnk : {'B', 'M', 'N', 'K'}) |
617 | est_threads *= bmnk_dim(bmnk).grid_dim(); |
618 | int def_k_loop_dim = utils::div_up(est_threads, 2 * hw_cfg_.eu_count()); |
619 | def_k_loop_dim = std::min(100, def_k_loop_dim); |
620 | def_k_loop_dim = std::max(1, def_k_loop_dim); |
621 | int k_loop_dim = def_k_loop_dim; |
622 | #ifdef GEN_CONV_DEBUG |
623 | k_loop_dim = getenv_int("k_loop_dim" , k_loop_dim); |
624 | #endif |
625 | k_dim().set_loop_dim(k_loop_dim); |
626 | return; |
627 | } |
628 | |
629 | if (!enable_k_tg_slicing()) { |
630 | k_dim().set_loop_dim(dim_value_t::unlimited()); |
631 | return; |
632 | } |
633 | |
634 | int k_nblks = utils::div_up( |
635 | prb_blocked_dim('K').size(), k_dim().base_iter_block()); |
636 | int tg_dim0 = min(max_tg_size_, k_dim().max_dim(tile_level_t::tg)); |
637 | for (int tg_dim = tg_dim0; tg_dim >= 1; tg_dim /= 2) { |
638 | if (k_nblks % tg_dim == 0) { |
639 | k_dim().set_loop_dim(k_nblks / tg_dim); |
640 | k_dim().set_tg_dim(tg_dim); |
641 | return; |
642 | } |
643 | } |
644 | |
645 | // Couldn't enable TG slicing. |
646 | k_dim().set_loop_dim(dim_value_t::unlimited()); |
647 | } |
648 | |
649 | bool block_helper_t::enable_k_tg_slicing() const { |
650 | #ifdef GEN_CONV_DEBUG |
651 | int env_value = getenv_int("enable_k_tg_slicing" , -1); |
652 | if (env_value != -1) return (bool)env_value; |
653 | #endif |
654 | if (!allow_k_tg_slicing_) return false; |
655 | |
656 | if (m_dim().iter_dim() > 16) return false; |
657 | |
658 | // TG slicing is supported only when there is only one k dimension. |
659 | if (prb_blocked_ndims('K') > 1) return false; |
660 | |
661 | // Do not enable TG slicing if there are enough non-K threads. |
662 | int non_k_threads = 1; |
663 | for (char bmnk : {'B', 'M', 'N'}) { |
664 | auto &d = bmnk_dim(bmnk); |
665 | non_k_threads *= d.grid_dim() * d.tg_dim(); |
666 | } |
667 | if (non_k_threads >= hw_cfg_.eu_count()) return false; |
668 | |
669 | // Do not enable TG slicing if reduction is small. |
670 | int k_nblks = utils::div_up(k_dim().size(), k_dim().base_iter_block()); |
671 | if (k_nblks < 16) return false; |
672 | |
673 | return true; |
674 | } |
675 | |
676 | void block_helper_t::init_prb_blocks() { |
677 | // Pad sizes to base block multiples. |
678 | for (auto &kv : dims_) { |
679 | auto &d = kv.second; |
680 | d.set_size(utils::rnd_up(d.size(), d.base_iter_block())); |
681 | } |
682 | |
683 | // Filter blocked dimensions and sort them according to their keys. |
684 | std::vector<dim_info_t *> sorted_dims; |
685 | for (auto &kv : dims_) { |
686 | auto &d = kv.second; |
687 | if (!d.is_blocked()) continue; |
688 | sorted_dims.push_back(&d); |
689 | } |
690 | std::sort(sorted_dims.begin(), sorted_dims.end(), |
691 | [](const dim_info_t *a, const dim_info_t *b) { |
692 | if (a->order_key() == b->order_key()) { |
693 | return a->name().compare(b->name()) < 0; |
694 | } |
695 | return a->order_key() < b->order_key(); |
696 | }); |
697 | |
698 | for (char bmnk : {'B', 'N', 'M', 'K'}) { |
699 | std::vector<dim_info_t *> cur_dims; |
700 | for (auto *d : sorted_dims) { |
701 | if (d->bmnk() != bmnk) continue; |
702 | cur_dims.push_back(d); |
703 | } |
704 | |
705 | ir_assert(!cur_dims.empty()); |
706 | |
707 | // Pad dimensions according to BMNK base block requirements. |
708 | int max_iter_dim = prb_max_dim(bmnk, tile_level_t::iter); |
709 | int base_blk = bmnk_dim(bmnk).base_iter_block(); |
710 | if (max_iter_dim == 1 && base_blk > 1) { |
711 | ir_assert(cur_dims[0]->base_iter_block() == 1); |
712 | cur_dims[0]->set_size(base_blk); |
713 | } |
714 | |
715 | block_assigner_t assigner(bmnk_dim(bmnk), cur_dims); |
716 | while (assigner.has_blocks()) { |
717 | assigner.assign_block(); |
718 | } |
719 | } |
720 | } |
721 | |
722 | int block_helper_t::compute_mad_k_block() const { |
723 | int k_base_blk = k_dim().base_iter_block(); |
724 | if (k_base_blk >= 16) return k_base_blk; |
725 | |
726 | bool is_fused = true; |
727 | int k_blocked_size = 1; |
728 | for (auto &kv : dims_) { |
729 | auto &d = kv.second; |
730 | if (!d.is_blocked()) continue; |
731 | if (d.bmnk() != 'K') continue; |
732 | k_blocked_size *= d.size(); |
733 | if (!d.allow_fuse()) is_fused = false; |
734 | } |
735 | |
736 | if (!is_fused) return 16; |
737 | |
738 | int max_k_blk = 32; |
739 | if ((k_blocked_size <= max_k_blk) |
740 | && (k_blocked_size % k_dim().base_iter_block() == 0)) |
741 | return k_blocked_size; |
742 | |
743 | return 16; |
744 | } |
745 | |
746 | } // namespace jit |
747 | } // namespace gpu |
748 | } // namespace impl |
749 | } // namespace dnnl |
750 | |