1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_IR_GEMM_SCHEDULE_HPP
18#define GPU_JIT_IR_GEMM_SCHEDULE_HPP
19
20#include <functional>
21#include <limits>
22#include <sstream>
23#include <string>
24#include <utility>
25#include <vector>
26#include <initializer_list>
27
28#include "gpu/jit/ir/ir.hpp"
29#include "gpu/jit/ir/tensor.hpp"
30#include "gpu/jit/utils/utils.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace gpu {
35namespace jit {
36
37// Used to describe semantics of a dimension in the GEMM context.
38// GEMM operation is defined as C = A x B
39// GEMM dimension kinds:
40// - B: shared by all tensors A, B, C (batch dimension)
41// - M: shared only by A and C
42// - N: shared only by B and C
43// - K: shared only by A and B (reduction dimension)
// See the comment above for the semantics of the B/M/N/K kinds.
enum class bmnk_kind_t { undef = -1, b = 0, m = 1, n = 2, k = 3 };

// Identifies one of the GEMM tensors: C = A x B.
enum class abc_kind_t { undef, a, b, c };
47
48class bmnk_mapper_t {
49public:
50 bmnk_mapper_t() = default;
51
52 bmnk_mapper_t(const object_map_t<expr_t, bmnk_kind_t> &bmnk_kinds)
53 : bmnk_kinds_(bmnk_kinds) {}
54
55 bmnk_kind_t bmnk_kind(const expr_t &var) const {
56 auto it = bmnk_kinds_.find(var);
57 if (it == bmnk_kinds_.end()) return bmnk_kind_t::undef;
58 return it->second;
59 }
60
61 bmnk_kind_t bmnk_kind(abc_kind_t abc_kind, int dim_idx) const {
62 return bmnk_kind(var(abc_kind, dim_idx));
63 }
64
65 int ndims(abc_kind_t abc_kind) const {
66 return int(get_vars(abc_kind).size());
67 }
68
69 void set_a_vars(const std::vector<expr_t> &vars) { a_vars_ = vars; }
70 void set_b_vars(const std::vector<expr_t> &vars) { b_vars_ = vars; }
71 void set_c_vars(const std::vector<expr_t> &vars) { c_vars_ = vars; }
72
73 void set_bmnk_kind(const expr_t &var, bmnk_kind_t bmnk_kind) {
74 auto ret = bmnk_kinds_.insert({var, bmnk_kind});
75 ir_assert(ret.second) << "Can't set variable twice: " << var;
76 }
77
78 const expr_t &var(abc_kind_t abc_kind, int dim_idx) const {
79 return get_vars(abc_kind)[dim_idx];
80 }
81
82 int dim_idx(abc_kind_t abc_kind, const expr_t &var) const {
83 auto &vars = get_vars(abc_kind);
84 for (int i = 0; i < int(vars.size()); i++) {
85 if (vars[i].is_same(var)) return i;
86 }
87 return -1;
88 }
89
90 layout_t map_to_bmnk(abc_kind_t abc_kind,
91 const std::vector<bmnk_kind_t> &bmnk_kinds,
92 const view_t &view) const;
93
94 layout_t map_to_bmnk(abc_kind_t abc_kind,
95 const std::vector<bmnk_kind_t> &bmnk_kinds,
96 const layout_t &layout) const;
97
98private:
99 const std::vector<expr_t> &get_vars(abc_kind_t abc_kind) const {
100 switch (abc_kind) {
101 case abc_kind_t::a: return a_vars_;
102 case abc_kind_t::b: return b_vars_;
103 case abc_kind_t::c: return c_vars_;
104 default: ir_error_not_expected() << "Unknown ABC kind.";
105 }
106 return a_vars_;
107 }
108
109 std::vector<expr_t> &get_vars(abc_kind_t abc_kind) {
110 auto &vars
111 = const_cast<const bmnk_mapper_t *>(this)->get_vars(abc_kind);
112 return const_cast<std::vector<expr_t> &>(vars);
113 }
114
115 std::vector<expr_t> a_vars_;
116 std::vector<expr_t> b_vars_;
117 std::vector<expr_t> c_vars_;
118 object_map_t<expr_t, bmnk_kind_t> bmnk_kinds_;
119};
120
// Accumulates layout blocks of the A/B/C tensors (classified through a
// bmnk_mapper_t) and maps BMNK-space layouts back to problem-space layouts.
class bmnk_block_mapper_t {
public:
    bmnk_block_mapper_t(const bmnk_mapper_t &bmnk_mapper)
        : bmnk_mapper_(bmnk_mapper) {}

    // Registers all `blocks` of tensor `abc_kind`, innermost first.
    void push_blocks(abc_kind_t abc_kind, const std::vector<block_t> &blocks) {
        for (auto &b : blocks)
            push_block(abc_kind, b);
    }

    void push_block(abc_kind_t abc_kind, const block_t &b);

    // Converts `bmnk_layout` (dimensions in BMNK space, per `bmnk_kinds`)
    // into a layout over the problem dimensions of tensor `abc_kind`.
    layout_t map_from_bmnk(abc_kind_t abc_kind,
            const std::vector<bmnk_kind_t> &bmnk_kinds,
            const layout_t &bmnk_layout) const;

private:
    // Drops leading unit blocks; they carry no indexing information.
    static void pop_size_1_blocks(std::vector<block_t> &blocks) {
        while (!blocks.empty() && blocks.front().block == 1) {
            blocks.erase(blocks.begin());
        }
    }

    // Re-indexes blocks recorded for one tensor (`p.first`) in terms of the
    // dimension indices of tensor `abc_kind`, matching by shared variable.
    std::vector<block_t> create_prb_blocks(abc_kind_t abc_kind,
            const std::vector<std::pair<abc_kind_t, block_t>> &mn_blocks)
            const {
        std::vector<block_t> ret;
        ret.reserve(mn_blocks.size());
        for (auto &p : mn_blocks) {
            auto b = p.second;
            const auto &var = bmnk_mapper_.var(p.first, b.dim_idx);
            b.dim_idx = bmnk_mapper_.dim_idx(abc_kind, var);
            ret.push_back(b);
        }
        return ret;
    }

    bool pop_block(std::vector<block_t> &bmnk_blocks,
            std::vector<block_t> &prb_blocks, const block_t &bmnk_block) const;

    bmnk_mapper_t bmnk_mapper_;

    // Ordered from innermost to outermost.
    std::vector<std::pair<abc_kind_t, block_t>> m_blocks_;
    std::vector<std::pair<abc_kind_t, block_t>> n_blocks_;
    std::vector<std::pair<abc_kind_t, block_t>> k_blocks_;
};
168
// Describes how a schedule loop is executed/mapped onto the GPU.
enum class loop_kind_t : int {
    undef,
    kernel_grid, // Loop is bound to the kernel grid.
    serial, // Loop is inside a thread (may be unrolled or just a regular loop).
    tg_grid, // Loop is bound to the thread group grid.
    tensorized, // Such loops are fully unrolled/vectorized and converted to blocked multiplication.
};
176
177static std::string to_string(loop_kind_t kind) {
178 switch (kind) {
179 case loop_kind_t::undef: return "undef";
180 case loop_kind_t::kernel_grid: return "kernel_grid";
181 case loop_kind_t::serial: return "serial";
182 case loop_kind_t::tg_grid: return "tg_grid";
183 case loop_kind_t::tensorized: return "tensorized";
184 default: ir_error_not_expected();
185 }
186 return "unknown";
187}
188
189inline std::ostream &operator<<(std::ostream &out, loop_kind_t kind) {
190 out << to_string(kind);
191 return out;
192}
193
// Represents a single loop of a GEMM schedule: its index variable, bound,
// execution kind and its split/fuse relationships with other loops.
class loop_t {
public:
    loop_t() : kind_(loop_kind_t::undef) {}

    // New loops start as serial; the kind is adjusted later (e.g. on bind()).
    loop_t(const expr_t &var, const expr_t &bound, bool is_root)
        : var_(var)
        , kind_(loop_kind_t::serial)
        , bound_(bound)
        , is_root_(is_root) {}

    const expr_t &var() const { return var_; }

    loop_kind_t kind() const { return kind_; }

    void set_kind(loop_kind_t kind) { kind_ = kind; }

    int unroll_factor() const { return unroll_factor_; }

    void set_unroll_factor(int factor) { unroll_factor_ = factor; }

    bool is_kernel_grid() const { return kind() == loop_kind_t::kernel_grid; }

    bool is_serial() const { return kind() == loop_kind_t::serial; }

    bool is_tg_grid() const { return kind() == loop_kind_t::tg_grid; }

    bool is_tensorized() const { return kind() == loop_kind_t::tensorized; }

    // Loop bound (exclusive upper limit of the index variable).
    const expr_t &bound() const { return bound_; }

    void set_bound(const expr_t &bound) { bound_ = bound; }

    // Returns true if the loop index is bound to an external grid variable.
    bool is_bound() const { return !bound_var().is_empty(); }

    const expr_t &bound_var() const { return bound_var_; }

    void set_bound_var(const expr_t &v) { bound_var_ = v; }

    // Returns true for loops created directly from a problem dimension.
    bool is_root() const { return is_root_; }

    // Returns true for loops that were neither split, nor fused with other loops.
    bool is_leaf() const { return is_leaf_; }

    // Returns true if this loop was split into outer/inner loops.
    bool is_split_parent() const { return is_split_parent_; }

    // Returns true if this loop was the result of a split.
    bool is_split_child() const { return is_split_child_; }

    // Returns true if this loop was fused with other loops.
    bool is_fused_parent() const { return is_fused_parent_; }

    // Returns true if this loop was the result of a fusion.
    bool is_fused_child() const { return is_fused_child_; }

    const std::vector<expr_t> &parent_vars() const { return parent_vars_; }
    const std::vector<expr_t> &child_vars() const { return child_vars_; }

    // Records that this loop was split into `outer_loop` x `inner_loop`;
    // this loop stops being a leaf, the children become split children.
    void set_split(loop_t &outer_loop, loop_t &inner_loop) {
        outer_loop.parent_vars_.push_back(var());
        child_vars_.push_back(outer_loop.var());
        outer_loop.is_split_child_ = true;

        inner_loop.parent_vars_.push_back(var());
        child_vars_.push_back(inner_loop.var());
        inner_loop.is_split_child_ = true;

        is_split_parent_ = true;
        is_leaf_ = false;
    }

    // Records that this loop is the fusion of `loops` (ordered from
    // outermost to innermost); the fused loops stop being leaves.
    void set_fuse(std::vector<std::reference_wrapper<loop_t>> &loops) {
        for (auto &l_ref : loops) {
            auto &l = l_ref.get();
            parent_vars_.push_back(l.var());
            l.child_vars_.push_back(var());
            l.is_fused_parent_ = true;
            l.is_leaf_ = false;
        }
        is_fused_child_ = true;
    }

    // Returns a loop variable expressed in the variables of the leaf loops.
    expr_t expand_var(const object_map_t<expr_t, loop_t> &all_loops,
            bool skip_fused = false) const {
        if (is_leaf()) return var();
        if (is_split_parent()) {
            // var = outer * inner_bound + inner.
            ir_assert(child_vars_.size() == 2);
            auto &outer_loop = all_loops.at(child_vars_[0]);
            auto &inner_loop = all_loops.at(child_vars_[1]);
            auto outer_var = outer_loop.expand_var(all_loops, skip_fused);
            auto inner_var = inner_loop.expand_var(all_loops, skip_fused);
            return outer_var * inner_loop.bound() + inner_var;
        }
        if (is_fused_parent()) {
            if (skip_fused) return var();
            // Example of "unpacking":
            // fused_var = (a * b * c * d)
            // b = (fused_var / (D * C)) % B
            ir_assert(child_vars_.size() == 1);
            auto &fused_loop = all_loops.at(child_vars_[0]);
            int nvars = int(fused_loop.parent_vars_.size());
            expr_t denom = 1;
            // Walk parents innermost-first, accumulating the divisor until
            // this loop's position in the fused ordering is found.
            for (int i = nvars - 1; i >= 0; i--) {
                auto &v = fused_loop.parent_vars_[i];
                auto &child_loop = all_loops.at(v);
                auto &bound = child_loop.bound();
                if (v.is_same(var())) {
                    auto e = fused_loop.expand_var(all_loops, skip_fused)
                            / denom;
                    // Outermost parent needs no modulus.
                    return (i == 0 ? e : e % bound);
                }
                denom *= bound;
            }
        }

        ir_error_not_expected();
        return expr_t();
    }

    std::string str() const {
        using namespace ir_utils;

        std::ostringstream oss;
        oss << "var: " << var_;
        oss << " bound: " << bound_;
        oss << " kind: " << kind_;
        if (unroll_factor_ != 1) oss << " unroll: " << unroll_factor_;
        std::vector<std::string> props;
        if (is_root()) props.push_back("root");
        if (is_fused_child()) props.push_back("fused");
        if (is_split_parent()) props.push_back("split");
        oss << "(" << make_seq_print_helper(props, ", ") << ")";
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    expr_t var_; // Loop index variable.
    loop_kind_t kind_; // Loop kind.
    expr_t bound_; // Loop bound (exclusive).

    expr_t bound_var_; // External variable this loop bound to.

    int unroll_factor_ = 1;

    bool is_root_ = false;
    bool is_leaf_ = true;

    bool is_split_parent_ = false;
    bool is_split_child_ = false;

    bool is_fused_parent_ = false;
    bool is_fused_child_ = false;

    // For variables that were split or fused.
    // Fusion: i x j -> k
    //     i.child_vars_ = [k]
    //     j.child_vars_ = [k]
    //     k.parent_vars_ = [i, j]
    // Split: i -> j x k
    //     i.child_vars_ = [j, k]
    //     j.parent_vars_ = [i]
    //     k.parent_vars_ = [i]
    std::vector<expr_t> parent_vars_;
    std::vector<expr_t> child_vars_;
};
362
363// Defines GEMM computation including:
364// - Blocking scheme (order of loops, tiles per thread group/iteration)
365// - Mapping of problem dimensions to GEMM dimensions (BMNK)
366class gemm_schedule_t {
367public:
368 gemm_schedule_t() = default;
369
370 gemm_schedule_t(constraint_set_t &cset, const grid_info_t &kernel_grid,
371 const grid_info_t &tg_grid)
372 : cset_(&cset), kernel_grid_(kernel_grid), tg_grid_(tg_grid) {}
373
374 const grid_info_t &kernel_grid() const { return kernel_grid_; }
375 const grid_info_t &tg_grid() const { return tg_grid_; }
376
377 bmnk_kind_t bmnk_kind(const expr_t &var) const {
378 return bmnk_kind(std::vector<expr_t>({var}));
379 }
380
381 const bmnk_mapper_t &bmnk_mapper() const { return bmnk_mapper_; }
382
383 void set_b_vars(const std::vector<expr_t> &vars) {
384 for (auto &v : vars)
385 set_bmnk_kind(v, bmnk_kind_t::b);
386 }
387
388 void set_m_vars(const std::vector<expr_t> &vars) {
389 for (auto &v : vars)
390 set_bmnk_kind(v, bmnk_kind_t::m);
391 }
392
393 void set_n_vars(const std::vector<expr_t> &vars) {
394 for (auto &v : vars)
395 set_bmnk_kind(v, bmnk_kind_t::n);
396 }
397
398 void set_k_vars(const std::vector<expr_t> &vars) {
399 for (auto &v : vars)
400 set_bmnk_kind(v, bmnk_kind_t::k);
401 }
402
403 // A/B/C views in the problem notation.
404 const view_t &a_view() const { return a_view_; }
405 const view_t &b_view() const { return b_view_; }
406 const view_t &c_view() const { return c_view_; }
407
408 void set_a_view(const view_t &v) {
409 set_abc_view(v, a_view_);
410 bmnk_mapper_.set_a_vars(a_view_.vvars());
411 }
412
413 void set_b_view(const view_t &v) {
414 set_abc_view(v, b_view_);
415 bmnk_mapper_.set_b_vars(b_view_.vvars());
416 }
417
418 void set_c_view(const view_t &v) {
419 set_abc_view(v, c_view_);
420 bmnk_mapper_.set_c_vars(c_view_.vvars());
421 }
422
423 void set_view(const view_t &view) {
424 // Create missing loops.
425 for (int i = 0; i < view.nvdims(); i++) {
426 auto &v = view.vvars()[i];
427 dim_t bound = view.vdims()[i];
428 if (has_loop(v)) {
429 auto &loop = find_loop(v);
430 ir_assert(bound == to_cpp<dim_t>(loop.bound()))
431 << "Inconsistent sizes.";
432 continue;
433 }
434 create_loop(v, bound, /*is_root=*/true);
435 }
436 }
437
438 tensor_t tg_view_tile(const view_t &view) const {
439 return view_tile(view, tile_level_t::thread_group);
440 }
441
442 tensor_t thr_view_tile(const view_t &view, bool is_relative = true) const {
443 auto thr_tile = view_tile(view, tile_level_t::iter);
444 if (is_relative) return thr_tile;
445 return tg_view_tile(view).create_sub_tensor(thr_tile);
446 }
447
448 view_t a_tg_view() const {
449 ir_assert(is_finalized_);
450 return a_view_.create_sub_view(a_tg_tile_);
451 }
452
453 view_t b_tg_view() const {
454 ir_assert(is_finalized_);
455 return b_view_.create_sub_view(b_tg_tile_);
456 }
457
458 view_t c_tg_view() const {
459 ir_assert(is_finalized_);
460 return c_view_.create_sub_view(c_tg_tile_);
461 }
462
463 // Thread group tiles for A, B, C.
464 const tensor_t &a_tg_tile() const { return a_tg_tile_; }
465 const tensor_t &b_tg_tile() const { return b_tg_tile_; }
466 const tensor_t &c_tg_tile() const { return c_tg_tile_; }
467
468 // Thread tiles for A, B, C.
469 tensor_t a_thr_tile(bool is_relative = true) const {
470 if (is_relative) return a_thr_tile_;
471 return a_tg_tile_.create_sub_tensor(a_thr_tile_);
472 }
473
474 tensor_t b_thr_tile(bool is_relative = true) const {
475 if (is_relative) return b_thr_tile_;
476 return b_tg_tile_.create_sub_tensor(b_thr_tile_);
477 }
478
479 tensor_t c_thr_tile(bool is_relative = true) const {
480 if (is_relative) return c_thr_tile_;
481 return c_tg_tile_.create_sub_tensor(c_thr_tile_);
482 }
483
484 int var_bound(const expr_t &var) const {
485 return to_cpp<int>(find_loop(var).bound());
486 }
487
488 void set_var_bound(const expr_t &var, int bound) {
489 return find_loop(var).set_bound(bound);
490 }
491
492 // Splits loop defined by `var` into two new loops based on `factor`.
493 // Before:
494 // for (int var = 0; var < I; var++) { ... }
495 // After:
496 // for (int outer_var = 0; outer_var < I / factor; outer_var++) {
497 // for (int inner_var = 0; inner_var < factor; inner_var++) {
498 // ...
499 // }
500 // }
501 void split(const expr_t &var, int factor, expr_t &outer_var,
502 expr_t &inner_var, const std::string &outer_name = {},
503 const std::string &inner_name = {}) {
504 auto &loop = find_loop(var);
505 ir_assert(loop.is_leaf()) << "Can't split, non-leaf loop.";
506
507 int bound = to_cpp<int>(loop.bound());
508 if (loop.is_root() && (bound % factor != 0)) {
509 // Auto round-up bounds for the root loops.
510 bound = utils::rnd_up(bound, factor);
511 loop.set_bound(bound);
512 }
513
514 ir_assert(bound % factor == 0) << "Can't split.";
515
516 if (outer_name.empty()) {
517 outer_var = create_var({var}, "outer");
518 } else {
519 outer_var = var_t::make(type_t::s32(), outer_name);
520 }
521 if (inner_name.empty()) {
522 inner_var = create_var({var}, "inner");
523 } else {
524 inner_var = var_t::make(type_t::s32(), inner_name);
525 }
526 auto &outer_loop = create_loop(outer_var, bound / factor);
527 auto &inner_loop = create_loop(inner_var, factor);
528 loop.set_split(outer_loop, inner_loop);
529 set_bmnk_kind(outer_var, bmnk_kind(var));
530 set_bmnk_kind(inner_var, bmnk_kind(var));
531 }
532
533 // Double split.
534 void split(const expr_t &var, int factor0, int factor1, expr_t &outer_var0,
535 expr_t &outer_var1, expr_t &inner_var) {
536 expr_t dummy_inner_var;
537 split(var, factor0, outer_var0, dummy_inner_var);
538 split(dummy_inner_var, factor1, outer_var1, inner_var);
539 }
540
541 // Fuses loops defined by `v0` and `v1` variables, v0 - outer variable, v1
542 // - inner variable.
543 // Before:
544 // for (int v0 = 0; v0 < V0; v0++) {
545 // for (int v1 = 0; v1 < V1; v1++) { ... }
546 // }
547 // After:
548 // for (int v = 0; v < V0 * V1; v++) {
549 // int v0 = v / V1;
550 // int v1 = v % V1;
551 // ...
552 // }
553 expr_t fuse(const expr_t &v0, const expr_t &v1) { return fuse({v0, v1}); }
554
555 // Double fuse, v0 - outermost variable, v2 - innermost variable.
556 expr_t fuse(const expr_t &v0, const expr_t &v1, const expr_t &v2) {
557 return fuse({v0, v1, v2});
558 }
559
560 // Fusion of multiple loops.
561 expr_t fuse(const std::vector<expr_t> &vars) {
562 auto fused_var = create_var(vars, "fused");
563 expr_t fused_bound = find_loop(vars[0]).bound();
564 for (int i = 1; i < int(vars.size()); i++) {
565 auto &loop = find_loop(vars[i]);
566 fused_bound *= loop.bound();
567 }
568 auto &fused_loop = create_loop(fused_var, fused_bound);
569 std::vector<std::reference_wrapper<loop_t>> loop_refs;
570 for (auto &v : vars) {
571 loop_refs.push_back(find_loop(v));
572 }
573 fused_loop.set_fuse(loop_refs);
574 set_bmnk_kind(fused_var, bmnk_kind(vars));
575 return fused_var;
576 }
577
578 // Sets unrolling factor for the given loop.
579 void unroll(const expr_t &v, int factor) {
580 auto &loop = find_loop(v);
581 loop.set_unroll_factor(factor);
582 }
583
584 // Marks the loop defined by `v` as tensorized.
585 void tensorize(const expr_t &v) {
586 auto &loop = find_loop(v);
587 loop.set_kind(loop_kind_t::tensorized);
588 }
589
590 // Binds the loop defined by `v` to an external variable.
591 void bind(const expr_t &v, const expr_t &bound_var) {
592 auto &loop = find_loop(v);
593 ir_assert(loop.is_leaf()) << "Can't bind non-leaf loop: " << v;
594 loop.set_bound_var(bound_var);
595 loop.set_kind(bound_var_to_loop_kind(bound_var));
596
597 int var_dim = bound_var_to_dim(bound_var);
598 ir_assert(to_cpp<int>(loop.bound()) == var_dim)
599 << "Dimension size doesn't match.";
600 }
601
602 // Reorders loops defined by given variables.
603 void reorder(const std::vector<expr_t> &ordered_vars) {
604 for (auto &v : ordered_vars) {
605 auto &loop = find_loop(v);
606 ir_assert(loop.is_leaf()) << "Can't reorder non-leaf loop: " << v;
607 }
608 std::vector<bool> found(vars_.size());
609 for (size_t i = 0; i < vars_.size(); i++) {
610 for (size_t j = 0; j < ordered_vars.size(); j++) {
611 if (ordered_vars[j].is_same(vars_[i])) {
612 found[i] = true;
613 break;
614 }
615 }
616 }
617
618 for (size_t i = 0, j = 0; i < vars_.size(); i++) {
619 if (!found[i]) continue;
620 vars_[i] = ordered_vars[j++];
621 }
622 }
623
624 // Adds a skip condition to the loop defined by `var`:
625 // for (var = 0; var < bound; var++) {
626 // if (cond) continue;
627 // ...
628 // }
629 void set_skip_condition(const expr_t &var, const expr_t &cond) {
630 ir_assert(find_loop(var).is_leaf()) << "Variable is non-leaf: " << var;
631 skip_conditions_[var] = expand(cond);
632 }
633
634 bool with_thread_group_k_slicing() const {
635 ir_assert(is_finalized_);
636 dim_t k_thr = 1;
637 dim_t k_tg = 1;
638 for (int i = 0; i < bmnk_mapper_.ndims(abc_kind_t::a); i++) {
639 if (bmnk_mapper_.bmnk_kind(abc_kind_t::a, i) != bmnk_kind_t::k)
640 continue;
641 k_thr *= a_thr_tile_(i);
642 k_tg *= a_tg_tile_(i);
643 }
644 ir_assert(k_tg % k_thr == 0);
645 return k_thr < k_tg;
646 }
647
648 bool with_kernel_grid_k_slicing() const {
649 ir_assert(is_finalized_);
650 dim_t k_loop = 1;
651 dim_t k = 1;
652 for (int i = 0; i < bmnk_mapper_.ndims(abc_kind_t::a); i++) {
653 if (bmnk_mapper_.bmnk_kind(abc_kind_t::a, i) != bmnk_kind_t::k)
654 continue;
655 auto info = get_split_info(a_view_.vvars()[i]);
656 k_loop *= info.dim(tile_level_t::loop);
657 k *= var_bound(a_view_.vvars()[i]);
658 }
659 return k_loop < k;
660 }
661
662 void finalize() {
663 init_problem_tiles();
664 init_constraint_set();
665 is_finalized_ = true;
666 }
667
668 template <typename F>
669 void for_each_var(const F &f) const {
670 for (auto &kv : loops_) {
671 f(kv.first);
672 }
673 }
674
675 expr_t expand(const expr_t &e, bool expand_trivial_vars = true) const {
676 auto found_vars = find_unique_objects<var_t>(e);
677 auto ret = e;
678 for (auto &v : found_vars) {
679 if (!has_loop(v)) continue;
680 auto &loop = find_loop(v);
681 auto v_value = loop.expand_var(loops_, /*skip_fused=*/true);
682 ret = substitute(ret, v, v_value);
683 }
684 if (expand_trivial_vars) {
685 for (auto &kv : loops_) {
686 int bound = to_cpp<int>(kv.second.bound());
687 if (bound != 1) continue;
688 if (!contains_object(ret, kv.first)) continue;
689 ret = substitute(ret, kv.first, expr_t(0));
690 }
691 }
692 return ret;
693 }
694
695 // Returns a statement describing the loop nest of the schedule.
696 stmt_t create_loop_nest(const stmt_t &_body = stmt_t()) const {
697 stmt_t body = _body;
698 auto found_vars = find_unique_objects<var_t>(body);
699 auto skip_conds = skip_conditions_;
700 for (auto it = vars_.rbegin(); it != vars_.rend(); it++) {
701 auto &var = *it;
702 auto &loop = find_loop(var);
703 if (!loop.is_leaf() || loop.is_tensorized() || loop.is_bound())
704 continue;
705 body = maybe_inject_let_for_fused_vars(body, loop);
706 auto cond_it = skip_conds.find(var);
707 if (cond_it != skip_conds.end()) {
708 auto skip_cond = cond_it->second;
709 cond_it->second = expr_t();
710 auto if_stmt = if_t::make(skip_cond, funcs::_continue());
711 body = if_stmt.append(body);
712 } else {
713 if (found_vars.count(var) == 0
714 && to_cpp<int>(loop.bound()) == 1)
715 continue;
716 }
717 body = for_t::make(
718 var, 0, loop.bound(), body, loop.unroll_factor());
719 }
720
721 for (auto &kv : skip_conds) {
722 auto &c = kv.second;
723 ir_assert(c.is_empty()) << "Skip condition is not injected: " << c;
724 }
725
726 return body;
727 }
728
729 stmt_t create_bind_stmt(const stmt_t &_body = stmt_t()) const {
730 stmt_t body = _body;
731 for (auto it = vars_.rbegin(); it != vars_.rend(); it++) {
732 auto &var = *it;
733 auto &loop = find_loop(var);
734 if (!loop.is_leaf() || !loop.is_bound()) continue;
735 body = maybe_inject_let_for_fused_vars(body, loop);
736 body = let_t::make(var, loop.bound_var(), body);
737 }
738 return body;
739 }
740
741private:
742 enum class tile_level_t { kernel_grid, loop, thread_group, iter };
743
744 static int nesting_level(tile_level_t level) {
745 switch (level) {
746 case tile_level_t::kernel_grid: return 0;
747 case tile_level_t::loop: return 1;
748 case tile_level_t::thread_group: return 2;
749 case tile_level_t::iter: return 3;
750 default: ir_error_not_expected();
751 }
752 return -1;
753 }
754
755 static int nesting_level(loop_kind_t kind) {
756 switch (kind) {
757 case loop_kind_t::kernel_grid: return 0;
758 case loop_kind_t::serial: return 1;
759 case loop_kind_t::tg_grid: return 2;
760 case loop_kind_t::tensorized: return 3;
761 default: ir_error_not_expected();
762 }
763 return -1;
764 }
765
766 // Describes split of a root loop into sub-loops.
767 class split_info_t {
768 public:
769 int nloops() const { return int(loops_.size()); }
770
771 void add_sub_loop(
772 const loop_t *loop, loop_kind_t loop_kind, int loop_level) {
773 loops_.push_back(loop);
774 loop_kinds_.push_back(loop_kind);
775 loop_levels_.push_back(loop_level);
776 }
777
778 // Verifies that sub-loops are ordered from outermost to innermost
779 // according to the schedule conventions. There are three set of loops:
780 // 1) Loops bound to kernel grid
781 // 2) Loops bound to thread group grid and serial loops
782 // 3) Tensorized loops
783 // Sets of loops must be ordered from outermost to innermost going from
784 // 1 to 3. Inside a set loops can be ordered arbitrarily.
785 bool is_valid() const {
786 auto get_loop_key = [&](int loop_idx) {
787 switch (loop_kinds_[loop_idx]) {
788 case loop_kind_t::kernel_grid: return -1;
789 case loop_kind_t::tg_grid:
790 // FIXME
791 case loop_kind_t::serial: return 0;
792 case loop_kind_t::tensorized:
793 return std::numeric_limits<int>::max();
794 default: ir_error_not_expected();
795 }
796 return -1;
797 };
798 int prev_key = -1;
799 for (int i = 0; i < nloops(); i++) {
800 int key = get_loop_key(i);
801 if (key < prev_key) return false;
802 prev_key = key;
803 }
804 return true;
805 }
806
807 // Returns total extent of all loops at a given tile level.
808 dim_t dim(tile_level_t tile_level) const {
809 dim_t ret = 1;
810 int t_level = nesting_level(tile_level);
811 for (int i = 0; i < nloops(); i++) {
812 int i_level = nesting_level(loop_kinds_[i]);
813 if (i_level < t_level) continue;
814 ret *= to_cpp<dim_t>(loops_[i]->bound());
815 }
816 return ret;
817 }
818
819 // Returns initial offset expressed in the outer variables at a given
820 // tile level.
821 expr_t start(const expr_t &var_expanded, tile_level_t tile_level,
822 bool with_outer = true) const {
823 auto ret = var_expanded;
824 int t_level = nesting_level(tile_level);
825 for (int i = 0; i < nloops(); i++) {
826 int i_level = nesting_level(loop_kinds_[i]);
827 if (with_outer) {
828 if (i_level < t_level) continue;
829 } else {
830 if (i_level + 1 == t_level) continue;
831 }
832 ret = substitute(ret, loops_[i]->var(), expr_t(0));
833 }
834 return simplify(ret);
835 }
836
837 private:
838 std::vector<const loop_t *> loops_;
839 std::vector<loop_kind_t> loop_kinds_;
840 std::vector<int> loop_levels_;
841 };
842
843 bmnk_kind_t bmnk_kind(const std::vector<expr_t> &vars) const {
844 if (vars.empty()) return bmnk_kind_t::undef;
845 if (vars.size() == 1) return bmnk_mapper_.bmnk_kind(vars[0]);
846 bmnk_kind_t ret = bmnk_kind(vars[0]);
847 for (size_t i = 1; i < vars.size(); i++) {
848 if (bmnk_kind(vars[i]) != ret) return bmnk_kind_t::undef;
849 }
850 return ret;
851 }
852
853 void set_bmnk_kind(const expr_t &var, bmnk_kind_t kind) {
854 bmnk_mapper_.set_bmnk_kind(var, kind);
855 }
856
857 void set_abc_view(const view_t &view, view_t &abc_view) {
858 abc_view = view;
859 set_view(view);
860 }
861
862 loop_kind_t bound_var_to_loop_kind(const expr_t &v) const {
863 for (int i = 0; i < kernel_grid_.ndims(); i++) {
864 if (kernel_grid_.idx(i).is_same(v)) return loop_kind_t::kernel_grid;
865 }
866 for (int i = 0; i < tg_grid_.ndims(); i++) {
867 if (tg_grid_.idx(i).is_same(v)) return loop_kind_t::tg_grid;
868 }
869 ir_error_not_expected() << "Unknown external variable: " << v;
870 return loop_kind_t::undef;
871 }
872
873 int bound_var_to_dim(const expr_t &v) const {
874 for (int i = 0; i < kernel_grid_.ndims(); i++) {
875 if (kernel_grid_.idx(i).is_same(v)) return kernel_grid_.dim(i);
876 }
877 for (int i = 0; i < tg_grid_.ndims(); i++) {
878 if (tg_grid_.idx(i).is_same(v)) return tg_grid_.dim(i);
879 }
880 ir_error_not_expected() << "Unknown external variable: " << v;
881 return -1;
882 }
883
884 bool has_loop(const expr_t &var) const {
885 auto it = loops_.find(var);
886 return it != loops_.end();
887 }
888
889 const loop_t &find_loop(const expr_t &var) const {
890 ir_assert(has_loop(var)) << "Var not found: " << var;
891 return loops_.at(var);
892 }
893
894 loop_t &find_loop(const expr_t &var) {
895 ir_assert(has_loop(var)) << "Var not found: " << var;
896 return loops_[var];
897 }
898
899 int loop_level(const expr_t &var) const {
900 for (int i = 0; i < int(vars_.size()); i++) {
901 if (vars_[i].is_same(var)) return i;
902 }
903 return -1;
904 }
905
906 loop_t &create_loop(
907 const expr_t &var, const expr_t &bound, bool is_root = false) {
908 loop_t loop(var, bound, is_root);
909 auto ret = loops_.insert({var, loop});
910 ir_assert(ret.second) << "Variable already exists: " << var;
911 vars_.push_back(var);
912 return ret.first->second;
913 }
914
915 static std::string strip_suffix(
916 const std::string &s, const std::string &suffix) {
917 auto pos = s.find(suffix);
918 if (pos == std::string::npos) return s;
919 if (pos + suffix.length() != s.length()) return s;
920 return s.substr(0, pos);
921 }
922
923 static expr_t create_var(
924 const std::vector<expr_t> &vars, const std::string &suffix) {
925 std::string var_name;
926 for (auto &v : vars) {
927 auto name = strip_suffix(v.as<var_t>().name, "_idx");
928 var_name += name + "_";
929 }
930 var_name += suffix;
931 return var_t::make(type_t::s32(), var_name);
932 }
933
934 void init_problem_tiles() {
935 object_map_t<expr_t, split_info_t> split_infos;
936 for (auto *view : {&a_view_, &b_view_, &c_view_}) {
937 for (auto &v : view->vvars()) {
938 if (split_infos.count(v) > 0) continue;
939 split_infos.insert({v, get_split_info(v)});
940 }
941 }
942 a_tg_tile_ = compute_problem_tile(
943 a_view_.vvars(), split_infos, tile_level_t::thread_group);
944 b_tg_tile_ = compute_problem_tile(
945 b_view_.vvars(), split_infos, tile_level_t::thread_group);
946 c_tg_tile_ = compute_problem_tile(
947 c_view_.vvars(), split_infos, tile_level_t::thread_group);
948 a_thr_tile_ = compute_problem_tile(
949 a_view_.vvars(), split_infos, tile_level_t::iter);
950 b_thr_tile_ = compute_problem_tile(
951 b_view_.vvars(), split_infos, tile_level_t::iter);
952 c_thr_tile_ = compute_problem_tile(
953 c_view_.vvars(), split_infos, tile_level_t::iter);
954 }
955
956 void init_constraint_set() {
957 for (auto &v : vars_) {
958 auto &loop = find_loop(v);
959 if (loop.is_fused_parent()) {
960 cset_->add_constraint(v >= 0);
961 cset_->add_constraint(v < loop.bound());
962 continue;
963 }
964 if (!loop.is_leaf()) continue;
965
966 // Fused variables are used only to initialize fused parents.
967 if (loop.is_fused_child()) continue;
968
969 if (loop.is_bound()) {
970 cset_->add_constraint(v == loop.bound_var());
971 continue;
972 }
973
974 cset_->add_constraint(v >= 0);
975 cset_->add_constraint(v < loop.bound());
976 }
977 }
978
979 tensor_t view_tile(const view_t &view, tile_level_t level) const {
980 object_map_t<expr_t, split_info_t> split_infos;
981 for (auto &v : view.vvars()) {
982 if (split_infos.count(v) > 0) continue;
983 split_infos.insert({v, get_split_info(v)});
984 }
985 return compute_problem_tile(view.vvars(), split_infos, level);
986 }
987
988 split_info_t get_split_info(const expr_t &root_var) const {
989 split_info_t ret;
990 std::function<void(const expr_t &)> walk_down;
991 walk_down = [&](const expr_t &v) {
992 auto &loop = find_loop(v);
993 if (loop.is_leaf() || loop.is_fused_parent()) {
994 // Treat a fused var as leaf as it can't be split into other
995 // vars.
996 loop_kind_t kind = loop.kind();
997 int level;
998 if (loop.is_fused_parent()) {
999 auto &child_var = loop.child_vars()[0];
1000 ir_assert(find_loop(child_var).is_leaf());
1001 kind = find_loop(child_var).kind();
1002 level = loop_level(child_var);
1003 } else {
1004 level = loop_level(v);
1005 }
1006 ret.add_sub_loop(&loop, kind, level);
1007 } else if (loop.is_split_parent()) {
1008 walk_down(loop.child_vars()[0]);
1009 walk_down(loop.child_vars()[1]);
1010 } else {
1011 ir_error_not_expected();
1012 }
1013 };
1014 walk_down(root_var);
1015 ir_assert(ret.is_valid()) << "Invalid loop nest.";
1016 return ret;
1017 }
1018
1019 tensor_t compute_problem_tile(const std::vector<expr_t> &vars,
1020 const object_map_t<expr_t, split_info_t> &split_infos,
1021 tile_level_t tile_level) const {
1022 std::vector<dim_t> tile_dims;
1023 std::vector<expr_t> tile_start;
1024 bool with_outer = (tile_level == tile_level_t::thread_group);
1025 for (auto &v : vars) {
1026 auto &split_info = split_infos.at(v);
1027 tile_dims.push_back(split_info.dim(tile_level));
1028
1029 auto v_expanded = expand(v);
1030 tile_start.push_back(
1031 split_info.start(v_expanded, tile_level, with_outer));
1032 }
1033 return tensor_t(tile_dims, tile_start);
1034 }
1035
1036 stmt_t maybe_inject_let_for_fused_vars(
1037 const stmt_t &_body, const loop_t &loop) const {
1038 auto body = _body;
1039 if (!loop.is_leaf() || !loop.is_fused_child()) return body;
1040 auto &pvars = loop.parent_vars();
1041 for (auto it = pvars.rbegin(); it != pvars.rend(); it++) {
1042 auto &ploop = find_loop(*it);
1043 body = let_t::make(*it, ploop.expand_var(loops_), body);
1044 }
1045 return body;
1046 }
1047
1048 bool is_finalized_ = false;
1049
1050 constraint_set_t *cset_;
1051 grid_info_t kernel_grid_;
1052 grid_info_t tg_grid_;
1053
1054 // Loop indices, ordered from outermost to innermost.
1055 std::vector<expr_t> vars_;
1056
1057 object_map_t<expr_t, loop_t> loops_;
1058
1059 object_map_t<expr_t, expr_t> skip_conditions_;
1060
1061 bmnk_mapper_t bmnk_mapper_;
1062
1063 // Full views for A, B, C.
1064 view_t a_view_;
1065 view_t b_view_;
1066 view_t c_view_;
1067
1068 // Thread group tiles for A, B, C.
1069 tensor_t a_tg_tile_;
1070 tensor_t b_tg_tile_;
1071 tensor_t c_tg_tile_;
1072
1073 // Thread tiles for A, B, C (relative to thread group tiles).
1074 tensor_t a_thr_tile_;
1075 tensor_t b_thr_tile_;
1076 tensor_t c_thr_tile_;
1077};
1078
1079} // namespace jit
1080} // namespace gpu
1081} // namespace impl
1082} // namespace dnnl
1083
1084#endif
1085