1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_IR_GEMM_SCHEDULE_HPP |
18 | #define GPU_JIT_IR_GEMM_SCHEDULE_HPP |
19 | |
20 | #include <functional> |
21 | #include <limits> |
22 | #include <sstream> |
23 | #include <string> |
24 | #include <utility> |
25 | #include <vector> |
26 | #include <initializer_list> |
27 | |
28 | #include "gpu/jit/ir/ir.hpp" |
29 | #include "gpu/jit/ir/tensor.hpp" |
30 | #include "gpu/jit/utils/utils.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace gpu { |
35 | namespace jit { |
36 | |
// Used to describe semantics of a dimension in the GEMM context.
// GEMM operation is defined as C = A x B
// GEMM dimension kinds:
// - B: shared by all tensors A, B, C (batch dimension)
// - M: shared only by A and C
// - N: shared only by B and C
// - K: shared only by A and B (reduction dimension)
enum class bmnk_kind_t { undef = -1, b = 0, m = 1, n = 2, k = 3 };

// Identifies one of the GEMM tensors: A, B (inputs) or C (output).
enum class abc_kind_t { undef, a, b, c };
47 | |
48 | class bmnk_mapper_t { |
49 | public: |
50 | bmnk_mapper_t() = default; |
51 | |
52 | bmnk_mapper_t(const object_map_t<expr_t, bmnk_kind_t> &bmnk_kinds) |
53 | : bmnk_kinds_(bmnk_kinds) {} |
54 | |
55 | bmnk_kind_t bmnk_kind(const expr_t &var) const { |
56 | auto it = bmnk_kinds_.find(var); |
57 | if (it == bmnk_kinds_.end()) return bmnk_kind_t::undef; |
58 | return it->second; |
59 | } |
60 | |
61 | bmnk_kind_t bmnk_kind(abc_kind_t abc_kind, int dim_idx) const { |
62 | return bmnk_kind(var(abc_kind, dim_idx)); |
63 | } |
64 | |
65 | int ndims(abc_kind_t abc_kind) const { |
66 | return int(get_vars(abc_kind).size()); |
67 | } |
68 | |
69 | void set_a_vars(const std::vector<expr_t> &vars) { a_vars_ = vars; } |
70 | void set_b_vars(const std::vector<expr_t> &vars) { b_vars_ = vars; } |
71 | void set_c_vars(const std::vector<expr_t> &vars) { c_vars_ = vars; } |
72 | |
73 | void set_bmnk_kind(const expr_t &var, bmnk_kind_t bmnk_kind) { |
74 | auto ret = bmnk_kinds_.insert({var, bmnk_kind}); |
75 | ir_assert(ret.second) << "Can't set variable twice: " << var; |
76 | } |
77 | |
78 | const expr_t &var(abc_kind_t abc_kind, int dim_idx) const { |
79 | return get_vars(abc_kind)[dim_idx]; |
80 | } |
81 | |
82 | int dim_idx(abc_kind_t abc_kind, const expr_t &var) const { |
83 | auto &vars = get_vars(abc_kind); |
84 | for (int i = 0; i < int(vars.size()); i++) { |
85 | if (vars[i].is_same(var)) return i; |
86 | } |
87 | return -1; |
88 | } |
89 | |
90 | layout_t map_to_bmnk(abc_kind_t abc_kind, |
91 | const std::vector<bmnk_kind_t> &bmnk_kinds, |
92 | const view_t &view) const; |
93 | |
94 | layout_t map_to_bmnk(abc_kind_t abc_kind, |
95 | const std::vector<bmnk_kind_t> &bmnk_kinds, |
96 | const layout_t &layout) const; |
97 | |
98 | private: |
99 | const std::vector<expr_t> &get_vars(abc_kind_t abc_kind) const { |
100 | switch (abc_kind) { |
101 | case abc_kind_t::a: return a_vars_; |
102 | case abc_kind_t::b: return b_vars_; |
103 | case abc_kind_t::c: return c_vars_; |
104 | default: ir_error_not_expected() << "Unknown ABC kind." ; |
105 | } |
106 | return a_vars_; |
107 | } |
108 | |
109 | std::vector<expr_t> &get_vars(abc_kind_t abc_kind) { |
110 | auto &vars |
111 | = const_cast<const bmnk_mapper_t *>(this)->get_vars(abc_kind); |
112 | return const_cast<std::vector<expr_t> &>(vars); |
113 | } |
114 | |
115 | std::vector<expr_t> a_vars_; |
116 | std::vector<expr_t> b_vars_; |
117 | std::vector<expr_t> c_vars_; |
118 | object_map_t<expr_t, bmnk_kind_t> bmnk_kinds_; |
119 | }; |
120 | |
121 | class bmnk_block_mapper_t { |
122 | public: |
123 | bmnk_block_mapper_t(const bmnk_mapper_t &bmnk_mapper) |
124 | : bmnk_mapper_(bmnk_mapper) {} |
125 | |
126 | void push_blocks(abc_kind_t abc_kind, const std::vector<block_t> &blocks) { |
127 | for (auto &b : blocks) |
128 | push_block(abc_kind, b); |
129 | } |
130 | |
131 | void push_block(abc_kind_t abc_kind, const block_t &b); |
132 | |
133 | layout_t map_from_bmnk(abc_kind_t abc_kind, |
134 | const std::vector<bmnk_kind_t> &bmnk_kinds, |
135 | const layout_t &bmnk_layout) const; |
136 | |
137 | private: |
138 | static void pop_size_1_blocks(std::vector<block_t> &blocks) { |
139 | while (!blocks.empty() && blocks.front().block == 1) { |
140 | blocks.erase(blocks.begin()); |
141 | } |
142 | } |
143 | |
144 | std::vector<block_t> create_prb_blocks(abc_kind_t abc_kind, |
145 | const std::vector<std::pair<abc_kind_t, block_t>> &mn_blocks) |
146 | const { |
147 | std::vector<block_t> ret; |
148 | ret.reserve(mn_blocks.size()); |
149 | for (auto &p : mn_blocks) { |
150 | auto b = p.second; |
151 | const auto &var = bmnk_mapper_.var(p.first, b.dim_idx); |
152 | b.dim_idx = bmnk_mapper_.dim_idx(abc_kind, var); |
153 | ret.push_back(b); |
154 | } |
155 | return ret; |
156 | } |
157 | |
158 | bool pop_block(std::vector<block_t> &bmnk_blocks, |
159 | std::vector<block_t> &prb_blocks, const block_t &bmnk_block) const; |
160 | |
161 | bmnk_mapper_t bmnk_mapper_; |
162 | |
163 | // Ordered from innermost to outermost. |
164 | std::vector<std::pair<abc_kind_t, block_t>> m_blocks_; |
165 | std::vector<std::pair<abc_kind_t, block_t>> n_blocks_; |
166 | std::vector<std::pair<abc_kind_t, block_t>> k_blocks_; |
167 | }; |
168 | |
// Describes how a loop of the schedule is executed in the generated kernel.
enum class loop_kind_t : int {
    undef,
    kernel_grid, // Loop is bound to the kernel grid.
    serial, // Loop is inside a thread (may be unrolled or just a regular loop).
    tg_grid, // Loop is bound to the thread group grid.
    tensorized, // Such loops are fully unrolled/vectorized and converted to blocked multiplication.
};
176 | |
177 | static std::string to_string(loop_kind_t kind) { |
178 | switch (kind) { |
179 | case loop_kind_t::undef: return "undef" ; |
180 | case loop_kind_t::kernel_grid: return "kernel_grid" ; |
181 | case loop_kind_t::serial: return "serial" ; |
182 | case loop_kind_t::tg_grid: return "tg_grid" ; |
183 | case loop_kind_t::tensorized: return "tensorized" ; |
184 | default: ir_error_not_expected(); |
185 | } |
186 | return "unknown" ; |
187 | } |
188 | |
189 | inline std::ostream &operator<<(std::ostream &out, loop_kind_t kind) { |
190 | out << to_string(kind); |
191 | return out; |
192 | } |
193 | |
// A single loop of the schedule's loop nest. Tracks the loop variable, its
// bound, its execution kind and split/fuse relationships with other loops.
class loop_t {
public:
    loop_t() : kind_(loop_kind_t::undef) {}

    // New loops start as serial; the kind is refined later via set_kind().
    loop_t(const expr_t &var, const expr_t &bound, bool is_root)
        : var_(var)
        , kind_(loop_kind_t::serial)
        , bound_(bound)
        , is_root_(is_root) {}

    const expr_t &var() const { return var_; }

    loop_kind_t kind() const { return kind_; }

    void set_kind(loop_kind_t kind) { kind_ = kind; }

    int unroll_factor() const { return unroll_factor_; }

    void set_unroll_factor(int factor) { unroll_factor_ = factor; }

    bool is_kernel_grid() const { return kind() == loop_kind_t::kernel_grid; }

    bool is_serial() const { return kind() == loop_kind_t::serial; }

    bool is_tg_grid() const { return kind() == loop_kind_t::tg_grid; }

    bool is_tensorized() const { return kind() == loop_kind_t::tensorized; }

    // Loop bound (exclusive upper limit; iteration starts at 0).
    const expr_t &bound() const { return bound_; }

    void set_bound(const expr_t &bound) { bound_ = bound; }

    // Returns true if the loop is bound to an external (grid) variable.
    bool is_bound() const { return !bound_var().is_empty(); }

    const expr_t &bound_var() const { return bound_var_; }

    void set_bound_var(const expr_t &v) { bound_var_ = v; }

    // Returns true for loops created directly from problem dimensions.
    bool is_root() const { return is_root_; }

    // Returns true for loops that were neither split, nor fused with other loops.
    bool is_leaf() const { return is_leaf_; }

    // Returns true if this loop was split into outer/inner loops.
    bool is_split_parent() const { return is_split_parent_; }

    // Returns true if this loop was the result of a split.
    bool is_split_child() const { return is_split_child_; }

    // Returns true if this loop was fused with other loops.
    bool is_fused_parent() const { return is_fused_parent_; }

    // Returns true if this loop was the result of a fusion.
    bool is_fused_child() const { return is_fused_child_; }

    const std::vector<expr_t> &parent_vars() const { return parent_vars_; }
    const std::vector<expr_t> &child_vars() const { return child_vars_; }

    // Records that this loop was split into outer_loop x inner_loop. This
    // loop becomes a non-leaf split parent; child_vars_ keeps [outer, inner]
    // in that order (relied upon by expand_var()).
    void set_split(loop_t &outer_loop, loop_t &inner_loop) {
        outer_loop.parent_vars_.push_back(var());
        child_vars_.push_back(outer_loop.var());
        outer_loop.is_split_child_ = true;

        inner_loop.parent_vars_.push_back(var());
        child_vars_.push_back(inner_loop.var());
        inner_loop.is_split_child_ = true;

        is_split_parent_ = true;
        is_leaf_ = false;
    }

    // Records that this loop is the fusion of `loops` (expand_var() treats
    // parent_vars_ order as outermost first). The fused loops become
    // non-leaf fuse parents.
    void set_fuse(std::vector<std::reference_wrapper<loop_t>> &loops) {
        for (auto &l_ref : loops) {
            auto &l = l_ref.get();
            parent_vars_.push_back(l.var());
            l.child_vars_.push_back(var());
            l.is_fused_parent_ = true;
            l.is_leaf_ = false;
        }
        is_fused_child_ = true;
    }

    // Returns a loop variable expressed in the variables of the leaf loops.
    // With `skip_fused` set, fused parents are not unpacked and their own
    // variable is returned as-is.
    expr_t expand_var(const object_map_t<expr_t, loop_t> &all_loops,
            bool skip_fused = false) const {
        if (is_leaf()) return var();
        if (is_split_parent()) {
            // var = outer * inner_bound + inner.
            ir_assert(child_vars_.size() == 2);
            auto &outer_loop = all_loops.at(child_vars_[0]);
            auto &inner_loop = all_loops.at(child_vars_[1]);
            auto outer_var = outer_loop.expand_var(all_loops, skip_fused);
            auto inner_var = inner_loop.expand_var(all_loops, skip_fused);
            return outer_var * inner_loop.bound() + inner_var;
        }
        if (is_fused_parent()) {
            if (skip_fused) return var();
            // Example of "unpacking":
            //   fused_var = (a * b * c * d)
            //   b = (fused_var / (D * C)) % B
            ir_assert(child_vars_.size() == 1);
            auto &fused_loop = all_loops.at(child_vars_[0]);
            int nvars = int(fused_loop.parent_vars_.size());
            expr_t denom = 1;
            // Walk the fused variables from innermost to outermost,
            // accumulating in `denom` the product of bounds of the loops
            // inner to this one (the divisor in the example above).
            for (int i = nvars - 1; i >= 0; i--) {
                auto &v = fused_loop.parent_vars_[i];
                // NOTE: despite the name, `child_loop` is the loop of a
                // fused *parent* variable of fused_loop.
                auto &child_loop = all_loops.at(v);
                auto &bound = child_loop.bound();
                if (v.is_same(var())) {
                    auto e = fused_loop.expand_var(all_loops, skip_fused)
                            / denom;
                    // The outermost variable needs no modulus.
                    return (i == 0 ? e : e % bound);
                }
                denom *= bound;
            }
        }

        ir_error_not_expected();
        return expr_t();
    }

    std::string str() const {
        using namespace ir_utils;

        std::ostringstream oss;
        oss << "var: " << var_;
        oss << " bound: " << bound_;
        oss << " kind: " << kind_;
        if (unroll_factor_ != 1) oss << " unroll: " << unroll_factor_;
        std::vector<std::string> props;
        if (is_root()) props.push_back("root");
        if (is_fused_child()) props.push_back("fused");
        if (is_split_parent()) props.push_back("split");
        oss << "(" << make_seq_print_helper(props, ", ") << ")";
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    expr_t var_; // Loop index variable.
    loop_kind_t kind_; // Loop kind.
    expr_t bound_; // Loop bound (exclusive).

    expr_t bound_var_; // External variable this loop bound to.

    int unroll_factor_ = 1;

    bool is_root_ = false;
    bool is_leaf_ = true;

    bool is_split_parent_ = false;
    bool is_split_child_ = false;

    bool is_fused_parent_ = false;
    bool is_fused_child_ = false;

    // Parent/child variable links for loops that were split or fused.
    // Fusion: i x j -> k
    //     i.child_vars_ = [k]
    //     j.child_vars_ = [k]
    //     k.parent_vars_ = [i, j]
    // Split: i -> j x k
    //     i.child_vars_ = [j, k]
    //     j.parent_vars_ = [i]
    //     k.parent_vars_ = [i]
    std::vector<expr_t> parent_vars_;
    std::vector<expr_t> child_vars_;
};
362 | |
363 | // Defines GEMM computation including: |
364 | // - Blocking scheme (order of loops, tiles per thread group/iteration) |
365 | // - Mapping of problem dimensions to GEMM dimensions (BMNK) |
366 | class gemm_schedule_t { |
367 | public: |
    gemm_schedule_t() = default;

    // `cset` is captured by pointer, presumably to be populated by
    // init_constraint_set() during finalize() — see finalize().
    gemm_schedule_t(constraint_set_t &cset, const grid_info_t &kernel_grid,
            const grid_info_t &tg_grid)
        : cset_(&cset), kernel_grid_(kernel_grid), tg_grid_(tg_grid) {}

    const grid_info_t &kernel_grid() const { return kernel_grid_; }
    const grid_info_t &tg_grid() const { return tg_grid_; }

    // Returns the BMNK kind of a single loop variable.
    bmnk_kind_t bmnk_kind(const expr_t &var) const {
        return bmnk_kind(std::vector<expr_t>({var}));
    }

    const bmnk_mapper_t &bmnk_mapper() const { return bmnk_mapper_; }

    // Tags the given problem variables as batch (B) dimensions.
    void set_b_vars(const std::vector<expr_t> &vars) {
        for (auto &v : vars)
            set_bmnk_kind(v, bmnk_kind_t::b);
    }

    // Tags the given problem variables as M dimensions.
    void set_m_vars(const std::vector<expr_t> &vars) {
        for (auto &v : vars)
            set_bmnk_kind(v, bmnk_kind_t::m);
    }

    // Tags the given problem variables as N dimensions.
    void set_n_vars(const std::vector<expr_t> &vars) {
        for (auto &v : vars)
            set_bmnk_kind(v, bmnk_kind_t::n);
    }

    // Tags the given problem variables as K (reduction) dimensions.
    void set_k_vars(const std::vector<expr_t> &vars) {
        for (auto &v : vars)
            set_bmnk_kind(v, bmnk_kind_t::k);
    }

    // A/B/C views in the problem notation.
    const view_t &a_view() const { return a_view_; }
    const view_t &b_view() const { return b_view_; }
    const view_t &c_view() const { return c_view_; }

    // Setting a view also creates root loops for its variables (see
    // set_view()) and registers the variables with the BMNK mapper.
    void set_a_view(const view_t &v) {
        set_abc_view(v, a_view_);
        bmnk_mapper_.set_a_vars(a_view_.vvars());
    }

    void set_b_view(const view_t &v) {
        set_abc_view(v, b_view_);
        bmnk_mapper_.set_b_vars(b_view_.vvars());
    }

    void set_c_view(const view_t &v) {
        set_abc_view(v, c_view_);
        bmnk_mapper_.set_c_vars(c_view_.vvars());
    }

    // Creates a root loop for every view variable that doesn't have one yet.
    // For variables shared with previously set views, verifies that the
    // dimension sizes are consistent.
    void set_view(const view_t &view) {
        // Create missing loops.
        for (int i = 0; i < view.nvdims(); i++) {
            auto &v = view.vvars()[i];
            dim_t bound = view.vdims()[i];
            if (has_loop(v)) {
                auto &loop = find_loop(v);
                ir_assert(bound == to_cpp<dim_t>(loop.bound()))
                        << "Inconsistent sizes.";
                continue;
            }
            create_loop(v, bound, /*is_root=*/true);
        }
    }
437 | |
    // Returns the thread group tile of `view`.
    tensor_t tg_view_tile(const view_t &view) const {
        return view_tile(view, tile_level_t::thread_group);
    }

    // Returns the per-thread tile of `view`. When `is_relative` is true the
    // tile is expressed relative to the thread group tile, otherwise it is
    // converted to a sub-tensor of the full view.
    tensor_t thr_view_tile(const view_t &view, bool is_relative = true) const {
        auto thr_tile = view_tile(view, tile_level_t::iter);
        if (is_relative) return thr_tile;
        return tg_view_tile(view).create_sub_tensor(thr_tile);
    }

    // A/B/C views restricted to their thread group tiles. Valid only after
    // finalize().
    view_t a_tg_view() const {
        ir_assert(is_finalized_);
        return a_view_.create_sub_view(a_tg_tile_);
    }

    view_t b_tg_view() const {
        ir_assert(is_finalized_);
        return b_view_.create_sub_view(b_tg_tile_);
    }

    view_t c_tg_view() const {
        ir_assert(is_finalized_);
        return c_view_.create_sub_view(c_tg_tile_);
    }

    // Thread group tiles for A, B, C.
    const tensor_t &a_tg_tile() const { return a_tg_tile_; }
    const tensor_t &b_tg_tile() const { return b_tg_tile_; }
    const tensor_t &c_tg_tile() const { return c_tg_tile_; }

    // Thread tiles for A, B, C. When `is_relative` is true the tile is
    // expressed relative to the corresponding thread group tile.
    tensor_t a_thr_tile(bool is_relative = true) const {
        if (is_relative) return a_thr_tile_;
        return a_tg_tile_.create_sub_tensor(a_thr_tile_);
    }

    tensor_t b_thr_tile(bool is_relative = true) const {
        if (is_relative) return b_thr_tile_;
        return b_tg_tile_.create_sub_tensor(b_thr_tile_);
    }

    tensor_t c_thr_tile(bool is_relative = true) const {
        if (is_relative) return c_thr_tile_;
        return c_tg_tile_.create_sub_tensor(c_thr_tile_);
    }

    // Returns the (constant) bound of the loop defined by `var`.
    int var_bound(const expr_t &var) const {
        return to_cpp<int>(find_loop(var).bound());
    }

    void set_var_bound(const expr_t &var, int bound) {
        return find_loop(var).set_bound(bound);
    }
491 | |
    // Splits loop defined by `var` into two new loops based on `factor`.
    // Before:
    //     for (int var = 0; var < I; var++) { ... }
    // After:
    //     for (int outer_var = 0; outer_var < I / factor; outer_var++) {
    //         for (int inner_var = 0; inner_var < factor; inner_var++) {
    //             ...
    //         }
    //     }
    // The new variables are returned via `outer_var`/`inner_var`. Root loop
    // bounds are automatically rounded up to a multiple of `factor`;
    // non-root bounds must already be divisible by it.
    void split(const expr_t &var, int factor, expr_t &outer_var,
            expr_t &inner_var, const std::string &outer_name = {},
            const std::string &inner_name = {}) {
        auto &loop = find_loop(var);
        ir_assert(loop.is_leaf()) << "Can't split, non-leaf loop.";

        int bound = to_cpp<int>(loop.bound());
        if (loop.is_root() && (bound % factor != 0)) {
            // Auto round-up bounds for the root loops.
            bound = utils::rnd_up(bound, factor);
            loop.set_bound(bound);
        }

        ir_assert(bound % factor == 0) << "Can't split.";

        // Use caller-provided names when given, otherwise derive names from
        // the split variable.
        if (outer_name.empty()) {
            outer_var = create_var({var}, "outer");
        } else {
            outer_var = var_t::make(type_t::s32(), outer_name);
        }
        if (inner_name.empty()) {
            inner_var = create_var({var}, "inner");
        } else {
            inner_var = var_t::make(type_t::s32(), inner_name);
        }
        auto &outer_loop = create_loop(outer_var, bound / factor);
        auto &inner_loop = create_loop(inner_var, factor);
        loop.set_split(outer_loop, inner_loop);
        // Both sub-loops inherit the BMNK kind of the original variable.
        set_bmnk_kind(outer_var, bmnk_kind(var));
        set_bmnk_kind(inner_var, bmnk_kind(var));
    }

    // Double split: var -> outer_var0 x outer_var1 x inner_var.
    void split(const expr_t &var, int factor0, int factor1, expr_t &outer_var0,
            expr_t &outer_var1, expr_t &inner_var) {
        expr_t dummy_inner_var;
        split(var, factor0, outer_var0, dummy_inner_var);
        split(dummy_inner_var, factor1, outer_var1, inner_var);
    }
540 | |
    // Fuses loops defined by `v0` and `v1` variables, v0 - outer variable, v1
    // - inner variable.
    // Before:
    //     for (int v0 = 0; v0 < V0; v0++) {
    //         for (int v1 = 0; v1 < V1; v1++) { ... }
    //     }
    // After:
    //     for (int v = 0; v < V0 * V1; v++) {
    //         int v0 = v / V1;
    //         int v1 = v % V1;
    //         ...
    //     }
    expr_t fuse(const expr_t &v0, const expr_t &v1) { return fuse({v0, v1}); }

    // Double fuse, v0 - outermost variable, v2 - innermost variable.
    expr_t fuse(const expr_t &v0, const expr_t &v1, const expr_t &v2) {
        return fuse({v0, v1, v2});
    }

    // Fusion of multiple loops, `vars` ordered from outermost to innermost.
    // Returns the new fused variable; its loop bound is the product of the
    // fused bounds and its BMNK kind is the common kind of `vars` (undef if
    // the kinds are mixed).
    expr_t fuse(const std::vector<expr_t> &vars) {
        auto fused_var = create_var(vars, "fused");
        expr_t fused_bound = find_loop(vars[0]).bound();
        for (int i = 1; i < int(vars.size()); i++) {
            auto &loop = find_loop(vars[i]);
            fused_bound *= loop.bound();
        }
        auto &fused_loop = create_loop(fused_var, fused_bound);
        std::vector<std::reference_wrapper<loop_t>> loop_refs;
        for (auto &v : vars) {
            loop_refs.push_back(find_loop(v));
        }
        fused_loop.set_fuse(loop_refs);
        set_bmnk_kind(fused_var, bmnk_kind(vars));
        return fused_var;
    }

    // Sets unrolling factor for the given loop.
    void unroll(const expr_t &v, int factor) {
        auto &loop = find_loop(v);
        loop.set_unroll_factor(factor);
    }
583 | |
    // Marks the loop defined by `v` as tensorized.
    void tensorize(const expr_t &v) {
        auto &loop = find_loop(v);
        loop.set_kind(loop_kind_t::tensorized);
    }

    // Binds the loop defined by `v` to an external variable (a kernel grid
    // or thread group grid index). The loop bound must match the grid
    // dimension size.
    void bind(const expr_t &v, const expr_t &bound_var) {
        auto &loop = find_loop(v);
        ir_assert(loop.is_leaf()) << "Can't bind non-leaf loop: " << v;
        loop.set_bound_var(bound_var);
        loop.set_kind(bound_var_to_loop_kind(bound_var));

        int var_dim = bound_var_to_dim(bound_var);
        ir_assert(to_cpp<int>(loop.bound()) == var_dim)
                << "Dimension size doesn't match.";
    }

    // Reorders loops defined by given variables. Only the positions
    // currently occupied by `ordered_vars` are rewritten; other loops keep
    // their places. All `ordered_vars` must be leaf loops.
    void reorder(const std::vector<expr_t> &ordered_vars) {
        for (auto &v : ordered_vars) {
            auto &loop = find_loop(v);
            ir_assert(loop.is_leaf()) << "Can't reorder non-leaf loop: " << v;
        }
        // Mark positions in vars_ occupied by the reordered variables.
        std::vector<bool> found(vars_.size());
        for (size_t i = 0; i < vars_.size(); i++) {
            for (size_t j = 0; j < ordered_vars.size(); j++) {
                if (ordered_vars[j].is_same(vars_[i])) {
                    found[i] = true;
                    break;
                }
            }
        }

        // Rewrite those positions with the requested order.
        for (size_t i = 0, j = 0; i < vars_.size(); i++) {
            if (!found[i]) continue;
            vars_[i] = ordered_vars[j++];
        }
    }

    // Adds a skip condition to the loop defined by `var`:
    //     for (var = 0; var < bound; var++) {
    //         if (cond) continue;
    //         ...
    //     }
    // The condition is expanded into leaf variables before being stored.
    void set_skip_condition(const expr_t &var, const expr_t &cond) {
        ir_assert(find_loop(var).is_leaf()) << "Variable is non-leaf: " << var;
        skip_conditions_[var] = expand(cond);
    }
633 | |
    // Returns true if the reduction (K) dimension is split between threads
    // of a thread group, i.e. the per-thread K extent is smaller than the
    // thread group K extent. Valid only after finalize().
    bool with_thread_group_k_slicing() const {
        ir_assert(is_finalized_);
        dim_t k_thr = 1;
        dim_t k_tg = 1;
        for (int i = 0; i < bmnk_mapper_.ndims(abc_kind_t::a); i++) {
            if (bmnk_mapper_.bmnk_kind(abc_kind_t::a, i) != bmnk_kind_t::k)
                continue;
            k_thr *= a_thr_tile_(i);
            k_tg *= a_tg_tile_(i);
        }
        ir_assert(k_tg % k_thr == 0);
        return k_thr < k_tg;
    }

    // Returns true if the K extent covered at the serial-loop level is
    // smaller than the full K extent (presumably meaning K is sliced across
    // the kernel grid — confirm against split_info_t::dim()). Valid only
    // after finalize().
    bool with_kernel_grid_k_slicing() const {
        ir_assert(is_finalized_);
        dim_t k_loop = 1;
        dim_t k = 1;
        for (int i = 0; i < bmnk_mapper_.ndims(abc_kind_t::a); i++) {
            if (bmnk_mapper_.bmnk_kind(abc_kind_t::a, i) != bmnk_kind_t::k)
                continue;
            auto info = get_split_info(a_view_.vvars()[i]);
            k_loop *= info.dim(tile_level_t::loop);
            k *= var_bound(a_view_.vvars()[i]);
        }
        return k_loop < k;
    }

    // Finalizes the schedule: computes the problem tiles and fills the
    // constraint set. Must be called before querying tiles/views.
    void finalize() {
        init_problem_tiles();
        init_constraint_set();
        is_finalized_ = true;
    }
667 | |
    // Calls `f` for every loop variable in the schedule.
    template <typename F>
    void for_each_var(const F &f) const {
        for (auto &kv : loops_) {
            f(kv.first);
        }
    }

    // Rewrites `e` substituting every non-leaf loop variable with its
    // expression in leaf variables (fused parents are kept as-is). With
    // `expand_trivial_vars` set, variables of loops with bound 1 are
    // replaced with 0.
    expr_t expand(const expr_t &e, bool expand_trivial_vars = true) const {
        auto found_vars = find_unique_objects<var_t>(e);
        auto ret = e;
        for (auto &v : found_vars) {
            if (!has_loop(v)) continue;
            auto &loop = find_loop(v);
            auto v_value = loop.expand_var(loops_, /*skip_fused=*/true);
            ret = substitute(ret, v, v_value);
        }
        if (expand_trivial_vars) {
            for (auto &kv : loops_) {
                int bound = to_cpp<int>(kv.second.bound());
                if (bound != 1) continue;
                if (!contains_object(ret, kv.first)) continue;
                ret = substitute(ret, kv.first, expr_t(0));
            }
        }
        return ret;
    }
694 | |
    // Returns a statement describing the loop nest of the schedule: `_body`
    // wrapped into the non-bound, non-tensorized leaf loops, with skip
    // conditions and let-statements for fused variables injected.
    stmt_t create_loop_nest(const stmt_t &_body = stmt_t()) const {
        stmt_t body = _body;
        auto found_vars = find_unique_objects<var_t>(body);
        auto skip_conds = skip_conditions_;
        // Iterate vars_ in reverse: the last variable is wrapped first and
        // hence becomes the innermost loop.
        for (auto it = vars_.rbegin(); it != vars_.rend(); it++) {
            auto &var = *it;
            auto &loop = find_loop(var);
            if (!loop.is_leaf() || loop.is_tensorized() || loop.is_bound())
                continue;
            body = maybe_inject_let_for_fused_vars(body, loop);
            auto cond_it = skip_conds.find(var);
            if (cond_it != skip_conds.end()) {
                auto skip_cond = cond_it->second;
                // Mark the condition as injected (verified below).
                cond_it->second = expr_t();
                auto if_stmt = if_t::make(skip_cond, funcs::_continue());
                body = if_stmt.append(body);
            } else {
                // Loops with bound 1 whose variable is unused in the body
                // can be dropped entirely.
                if (found_vars.count(var) == 0
                        && to_cpp<int>(loop.bound()) == 1)
                    continue;
            }
            body = for_t::make(
                    var, 0, loop.bound(), body, loop.unroll_factor());
        }

        for (auto &kv : skip_conds) {
            auto &c = kv.second;
            ir_assert(c.is_empty()) << "Skip condition is not injected: " << c;
        }

        return body;
    }

    // Returns `_body` wrapped into let-statements binding the grid-bound
    // leaf loop variables to their external (grid index) variables.
    stmt_t create_bind_stmt(const stmt_t &_body = stmt_t()) const {
        stmt_t body = _body;
        for (auto it = vars_.rbegin(); it != vars_.rend(); it++) {
            auto &var = *it;
            auto &loop = find_loop(var);
            if (!loop.is_leaf() || !loop.is_bound()) continue;
            body = maybe_inject_let_for_fused_vars(body, loop);
            body = let_t::make(var, loop.bound_var(), body);
        }
        return body;
    }
740 | |
private:
    // Tiling levels, ordered from outermost (kernel grid) to innermost
    // (single iteration).
    enum class tile_level_t { kernel_grid, loop, thread_group, iter };

    // Nesting level of a tile level; 0 is outermost.
    static int nesting_level(tile_level_t level) {
        switch (level) {
            case tile_level_t::kernel_grid: return 0;
            case tile_level_t::loop: return 1;
            case tile_level_t::thread_group: return 2;
            case tile_level_t::iter: return 3;
            default: ir_error_not_expected();
        }
        return -1;
    }

    // Nesting level of a loop kind; numbering matches tile_level_t
    // (kernel_grid -> kernel_grid, serial -> loop, tg_grid -> thread_group,
    // tensorized -> iter).
    static int nesting_level(loop_kind_t kind) {
        switch (kind) {
            case loop_kind_t::kernel_grid: return 0;
            case loop_kind_t::serial: return 1;
            case loop_kind_t::tg_grid: return 2;
            case loop_kind_t::tensorized: return 3;
            default: ir_error_not_expected();
        }
        return -1;
    }
765 | |
    // Describes split of a root loop into sub-loops.
    class split_info_t {
    public:
        int nloops() const { return int(loops_.size()); }

        // Appends a sub-loop; the expected outermost-to-innermost ordering
        // is checked by is_valid().
        void add_sub_loop(
                const loop_t *loop, loop_kind_t loop_kind, int loop_level) {
            loops_.push_back(loop);
            loop_kinds_.push_back(loop_kind);
            loop_levels_.push_back(loop_level);
        }

        // Verifies that sub-loops are ordered from outermost to innermost
        // according to the schedule conventions. There are three set of loops:
        // 1) Loops bound to kernel grid
        // 2) Loops bound to thread group grid and serial loops
        // 3) Tensorized loops
        // Sets of loops must be ordered from outermost to innermost going from
        // 1 to 3. Inside a set loops can be ordered arbitrarily.
        bool is_valid() const {
            // Map a loop to its ordering key (set number above).
            auto get_loop_key = [&](int loop_idx) {
                switch (loop_kinds_[loop_idx]) {
                    case loop_kind_t::kernel_grid: return -1;
                    case loop_kind_t::tg_grid:
                        // FIXME
                    case loop_kind_t::serial: return 0;
                    case loop_kind_t::tensorized:
                        return std::numeric_limits<int>::max();
                    default: ir_error_not_expected();
                }
                return -1;
            };
            // Keys must be non-decreasing along the loop list.
            int prev_key = -1;
            for (int i = 0; i < nloops(); i++) {
                int key = get_loop_key(i);
                if (key < prev_key) return false;
                prev_key = key;
            }
            return true;
        }

        // Returns total extent of all loops at a given tile level (the
        // level itself and everything nested inside it).
        dim_t dim(tile_level_t tile_level) const {
            dim_t ret = 1;
            int t_level = nesting_level(tile_level);
            for (int i = 0; i < nloops(); i++) {
                int i_level = nesting_level(loop_kinds_[i]);
                if (i_level < t_level) continue;
                ret *= to_cpp<dim_t>(loops_[i]->bound());
            }
            return ret;
        }

        // Returns initial offset expressed in the outer variables at a given
        // tile level. Variables of loops at or inside `tile_level` are
        // zeroed out of `var_expanded`; with `with_outer` unset, only the
        // variables of the level immediately outside are kept.
        expr_t start(const expr_t &var_expanded, tile_level_t tile_level,
                bool with_outer = true) const {
            auto ret = var_expanded;
            int t_level = nesting_level(tile_level);
            for (int i = 0; i < nloops(); i++) {
                int i_level = nesting_level(loop_kinds_[i]);
                if (with_outer) {
                    if (i_level < t_level) continue;
                } else {
                    if (i_level + 1 == t_level) continue;
                }
                ret = substitute(ret, loops_[i]->var(), expr_t(0));
            }
            return simplify(ret);
        }

    private:
        std::vector<const loop_t *> loops_;
        std::vector<loop_kind_t> loop_kinds_;
        std::vector<int> loop_levels_;
    };
842 | |
    // Returns the common BMNK kind of a group of variables, or undef if the
    // group is empty or the kinds are mixed.
    bmnk_kind_t bmnk_kind(const std::vector<expr_t> &vars) const {
        if (vars.empty()) return bmnk_kind_t::undef;
        if (vars.size() == 1) return bmnk_mapper_.bmnk_kind(vars[0]);
        bmnk_kind_t ret = bmnk_kind(vars[0]);
        for (size_t i = 1; i < vars.size(); i++) {
            if (bmnk_kind(vars[i]) != ret) return bmnk_kind_t::undef;
        }
        return ret;
    }

    void set_bmnk_kind(const expr_t &var, bmnk_kind_t kind) {
        bmnk_mapper_.set_bmnk_kind(var, kind);
    }

    // Stores `view` into the given A/B/C slot and creates root loops for
    // its variables (see set_view()).
    void set_abc_view(const view_t &view, view_t &abc_view) {
        abc_view = view;
        set_view(view);
    }
861 | |
    // Maps an external grid variable to its loop kind: a kernel grid index
    // yields kernel_grid, a thread group grid index yields tg_grid.
    loop_kind_t bound_var_to_loop_kind(const expr_t &v) const {
        for (int i = 0; i < kernel_grid_.ndims(); i++) {
            if (kernel_grid_.idx(i).is_same(v)) return loop_kind_t::kernel_grid;
        }
        for (int i = 0; i < tg_grid_.ndims(); i++) {
            if (tg_grid_.idx(i).is_same(v)) return loop_kind_t::tg_grid;
        }
        ir_error_not_expected() << "Unknown external variable: " << v;
        return loop_kind_t::undef;
    }

    // Returns the grid dimension size corresponding to an external grid
    // variable, -1 (after the error) if the variable is unknown.
    int bound_var_to_dim(const expr_t &v) const {
        for (int i = 0; i < kernel_grid_.ndims(); i++) {
            if (kernel_grid_.idx(i).is_same(v)) return kernel_grid_.dim(i);
        }
        for (int i = 0; i < tg_grid_.ndims(); i++) {
            if (tg_grid_.idx(i).is_same(v)) return tg_grid_.dim(i);
        }
        ir_error_not_expected() << "Unknown external variable: " << v;
        return -1;
    }
883 | |
    // Returns true if a loop is registered for `var`.
    bool has_loop(const expr_t &var) const {
        auto it = loops_.find(var);
        return it != loops_.end();
    }

    const loop_t &find_loop(const expr_t &var) const {
        ir_assert(has_loop(var)) << "Var not found: " << var;
        return loops_.at(var);
    }

    loop_t &find_loop(const expr_t &var) {
        ir_assert(has_loop(var)) << "Var not found: " << var;
        return loops_[var];
    }

    // Position of `var` in the schedule order (vars_), -1 if not found.
    int loop_level(const expr_t &var) const {
        for (int i = 0; i < int(vars_.size()); i++) {
            if (vars_[i].is_same(var)) return i;
        }
        return -1;
    }

    // Registers a new loop for `var` (must not be registered yet) and
    // appends the variable to the schedule order. Returns the stored loop.
    loop_t &create_loop(
            const expr_t &var, const expr_t &bound, bool is_root = false) {
        loop_t loop(var, bound, is_root);
        auto ret = loops_.insert({var, loop});
        ir_assert(ret.second) << "Variable already exists: " << var;
        vars_.push_back(var);
        return ret.first->second;
    }
914 | |
915 | static std::string strip_suffix( |
916 | const std::string &s, const std::string &suffix) { |
917 | auto pos = s.find(suffix); |
918 | if (pos == std::string::npos) return s; |
919 | if (pos + suffix.length() != s.length()) return s; |
920 | return s.substr(0, pos); |
921 | } |
922 | |
923 | static expr_t create_var( |
924 | const std::vector<expr_t> &vars, const std::string &suffix) { |
925 | std::string var_name; |
926 | for (auto &v : vars) { |
927 | auto name = strip_suffix(v.as<var_t>().name, "_idx" ); |
928 | var_name += name + "_" ; |
929 | } |
930 | var_name += suffix; |
931 | return var_t::make(type_t::s32(), var_name); |
932 | } |
933 | |
934 | void init_problem_tiles() { |
935 | object_map_t<expr_t, split_info_t> split_infos; |
936 | for (auto *view : {&a_view_, &b_view_, &c_view_}) { |
937 | for (auto &v : view->vvars()) { |
938 | if (split_infos.count(v) > 0) continue; |
939 | split_infos.insert({v, get_split_info(v)}); |
940 | } |
941 | } |
942 | a_tg_tile_ = compute_problem_tile( |
943 | a_view_.vvars(), split_infos, tile_level_t::thread_group); |
944 | b_tg_tile_ = compute_problem_tile( |
945 | b_view_.vvars(), split_infos, tile_level_t::thread_group); |
946 | c_tg_tile_ = compute_problem_tile( |
947 | c_view_.vvars(), split_infos, tile_level_t::thread_group); |
948 | a_thr_tile_ = compute_problem_tile( |
949 | a_view_.vvars(), split_infos, tile_level_t::iter); |
950 | b_thr_tile_ = compute_problem_tile( |
951 | b_view_.vvars(), split_infos, tile_level_t::iter); |
952 | c_thr_tile_ = compute_problem_tile( |
953 | c_view_.vvars(), split_infos, tile_level_t::iter); |
954 | } |
955 | |
    // Registers constraints for all loop variables with the constraint set:
    // - fused parent vars get a [0, bound) range;
    // - leaf vars bound to an external (grid) var get an equality;
    // - remaining leaf vars get a [0, bound) range;
    // - non-leaf, non-fused-parent vars and fused children are skipped.
    // NOTE: the branch order matters — fused parents are handled before the
    // leaf check, so they are constrained even though they are not leaves.
    void init_constraint_set() {
        for (auto &v : vars_) {
            auto &loop = find_loop(v);
            if (loop.is_fused_parent()) {
                cset_->add_constraint(v >= 0);
                cset_->add_constraint(v < loop.bound());
                continue;
            }
            if (!loop.is_leaf()) continue;

            // Fused variables are used only to initialize fused parents.
            if (loop.is_fused_child()) continue;

            if (loop.is_bound()) {
                // Bound loops are pinned to their external variable.
                cset_->add_constraint(v == loop.bound_var());
                continue;
            }

            cset_->add_constraint(v >= 0);
            cset_->add_constraint(v < loop.bound());
        }
    }
978 | |
979 | tensor_t view_tile(const view_t &view, tile_level_t level) const { |
980 | object_map_t<expr_t, split_info_t> split_infos; |
981 | for (auto &v : view.vvars()) { |
982 | if (split_infos.count(v) > 0) continue; |
983 | split_infos.insert({v, get_split_info(v)}); |
984 | } |
985 | return compute_problem_tile(view.vvars(), split_infos, level); |
986 | } |
987 | |
    // Describes how `root_var` is split across loop levels: recursively
    // walks the split tree rooted at `root_var` and records each terminal
    // sub-loop (a leaf or a fused parent) together with its loop kind and
    // nesting level.
    split_info_t get_split_info(const expr_t &root_var) const {
        split_info_t ret;
        std::function<void(const expr_t &)> walk_down;
        walk_down = [&](const expr_t &v) {
            auto &loop = find_loop(v);
            if (loop.is_leaf() || loop.is_fused_parent()) {
                // Treat a fused var as leaf as it can't be split into other
                // vars.
                loop_kind_t kind = loop.kind();
                int level;
                if (loop.is_fused_parent()) {
                    // A fused parent takes the kind and level of its first
                    // child, which must itself be a leaf.
                    auto &child_var = loop.child_vars()[0];
                    ir_assert(find_loop(child_var).is_leaf());
                    kind = find_loop(child_var).kind();
                    level = loop_level(child_var);
                } else {
                    level = loop_level(v);
                }
                ret.add_sub_loop(&loop, kind, level);
            } else if (loop.is_split_parent()) {
                // Recurse into both halves of the split.
                walk_down(loop.child_vars()[0]);
                walk_down(loop.child_vars()[1]);
            } else {
                // Any other loop shape indicates a malformed loop nest.
                ir_error_not_expected();
            }
        };
        walk_down(root_var);
        ir_assert(ret.is_valid()) << "Invalid loop nest." ;
        return ret;
    }
1018 | |
1019 | tensor_t compute_problem_tile(const std::vector<expr_t> &vars, |
1020 | const object_map_t<expr_t, split_info_t> &split_infos, |
1021 | tile_level_t tile_level) const { |
1022 | std::vector<dim_t> tile_dims; |
1023 | std::vector<expr_t> tile_start; |
1024 | bool with_outer = (tile_level == tile_level_t::thread_group); |
1025 | for (auto &v : vars) { |
1026 | auto &split_info = split_infos.at(v); |
1027 | tile_dims.push_back(split_info.dim(tile_level)); |
1028 | |
1029 | auto v_expanded = expand(v); |
1030 | tile_start.push_back( |
1031 | split_info.start(v_expanded, tile_level, with_outer)); |
1032 | } |
1033 | return tensor_t(tile_dims, tile_start); |
1034 | } |
1035 | |
1036 | stmt_t maybe_inject_let_for_fused_vars( |
1037 | const stmt_t &_body, const loop_t &loop) const { |
1038 | auto body = _body; |
1039 | if (!loop.is_leaf() || !loop.is_fused_child()) return body; |
1040 | auto &pvars = loop.parent_vars(); |
1041 | for (auto it = pvars.rbegin(); it != pvars.rend(); it++) { |
1042 | auto &ploop = find_loop(*it); |
1043 | body = let_t::make(*it, ploop.expand_var(loops_), body); |
1044 | } |
1045 | return body; |
1046 | } |
1047 | |
1048 | bool is_finalized_ = false; |
1049 | |
1050 | constraint_set_t *cset_; |
1051 | grid_info_t kernel_grid_; |
1052 | grid_info_t tg_grid_; |
1053 | |
1054 | // Loop indices, ordered from outermost to innermost. |
1055 | std::vector<expr_t> vars_; |
1056 | |
1057 | object_map_t<expr_t, loop_t> loops_; |
1058 | |
1059 | object_map_t<expr_t, expr_t> skip_conditions_; |
1060 | |
1061 | bmnk_mapper_t bmnk_mapper_; |
1062 | |
1063 | // Full views for A, B, C. |
1064 | view_t a_view_; |
1065 | view_t b_view_; |
1066 | view_t c_view_; |
1067 | |
1068 | // Thread group tiles for A, B, C. |
1069 | tensor_t a_tg_tile_; |
1070 | tensor_t b_tg_tile_; |
1071 | tensor_t c_tg_tile_; |
1072 | |
1073 | // Thread tiles for A, B, C (relative to thread group tiles). |
1074 | tensor_t a_thr_tile_; |
1075 | tensor_t b_thr_tile_; |
1076 | tensor_t c_thr_tile_; |
1077 | }; |
1078 | |
1079 | } // namespace jit |
1080 | } // namespace gpu |
1081 | } // namespace impl |
1082 | } // namespace dnnl |
1083 | |
1084 | #endif |
1085 | |