/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/reorder/ir_builder.hpp"

#include <algorithm>
#include <array>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
#include <unordered_map>

#include "gpu/jit/ir/gemm_schedule.hpp"
#include "gpu/jit/ir/ir.hpp"
#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/ir/reorder.hpp"
#include "gpu/jit/ir/tensor.hpp"
#include "gpu/jit/pass/pass.hpp"
#include "gpu/jit/utils/trace.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

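// Helper class to enumerate candidate subtiles of a layout. Starting from a
// 1x...x1 tile, next() grows the current layout block to its next divisor;
// accept() commits the growth once the caller has verified that the resulting
// tile is usable for both src and dst.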
class tile_helper_t {
public:
    tile_helper_t(const layout_t &l)
        : l_(l)
        , running_blocks_(l.blocks().size(), 1)
        , blocks_(l.blocks().size(), 1) {
        const auto &l_blocks = l.blocks();
        const auto size = l_blocks.size();
        while (block_idx_ < size && l_blocks[block_idx_].block == 1)
            block_idx_++;
    }

    bool has_next() const { return block_idx_ < running_blocks_.size(); }

    dim_t size() const {
        dim_t ret = l_.size();
        for (auto b : blocks_)
            ret *= b;
        return ret;
    }

    bool is_dense() const {
        bool is_end = false;
        for (size_t i = 0; i < blocks_.size(); i++) {
            if (blocks_[i] == l_.blocks()[i].block) continue;
            if (blocks_[i] != 1 && is_end) return false;
            is_end = true;
        }
        return true;
    }

    tensor_t tile() const {
        std::vector<dim_t> dims(l_.ndims(), 1);
        for (size_t i = 0; i < blocks_.size(); i++) {
            int dim_idx = l_.blocks()[i].dim_idx;
            dims[dim_idx] *= blocks_[i];
        }
        return tensor_t(dims);
    }

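    // Advances the running block of the current layout block to its next
    // divisor. When the current block is exhausted, moves on to the next one.
    // Returns an empty tensor when there is nothing left to enumerate.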
    tensor_t next() {
        dim_t l_block = l_.blocks()[block_idx_].block;
        for (dim_t b = running_blocks_[block_idx_] + 1; b <= l_block; b++) {
            if (l_block % b == 0) {
                running_blocks_[block_idx_] = b;
                return running_tile();
            }
        }
        block_idx_++;
        if (has_next()) return next();
        return tensor_t();
    }

    void accept() { blocks_[block_idx_] = running_blocks_[block_idx_]; }

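    // Returns true if tile t can be expressed in terms of the blocks of
    // layout l, i.e. every dimension of t is consumed by whole layout blocks
    // or by an even split of a single block.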
    static bool can_be_mapped(const layout_t &l, const tensor_t &t) {
        std::vector<dim_t> rem_dims = t.dims();
        for (auto &b : l.blocks()) {
            auto &rem_dim = rem_dims[b.dim_idx];
            if (rem_dim >= b.block) {
                if (rem_dim % b.block != 0) return false;
                rem_dim /= b.block;
                continue;
            }
            if (b.block % rem_dim != 0) return false;
            rem_dim = 1;
        }
        for (auto d : rem_dims)
            ir_assert(d == 1);
        return true;
    }

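    // Combines two tiles by taking the per-dimension maximum.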
    static tensor_t merge(const tensor_t &a, const tensor_t &b) {
        std::vector<dim_t> dims(a.ndims());
        for (int i = 0; i < a.ndims(); i++) {
            dims[i] = std::max(a(i), b(i));
        }
        return tensor_t(dims);
    }

private:
    tensor_t running_tile() const {
        std::vector<dim_t> dims(l_.ndims(), 1);
        for (size_t i = 0; i < block_idx_; i++) {
            int dim_idx = l_.blocks()[i].dim_idx;
            dims[dim_idx] *= blocks_[i];
        }
        int dim_idx = l_.blocks()[block_idx_].dim_idx;
        dims[dim_idx] *= running_blocks_[block_idx_];
        return tensor_t(dims);
    }

    const layout_t &l_;
    std::vector<dim_t> running_blocks_;
    std::vector<dim_t> blocks_;
    size_t block_idx_ = 0;
};

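// Computes the blocking for the reorder kernel:
// - iter_blocks: per-dimension tile handled by one iteration,
// - loop_blocks: per-dimension loop counts executed by one thread,
// - tg_blocks:   per-dimension split assigned to the thread group.
// Candidate tiles are grown greedily with tile_helper_t; the thread tile is
// capped by max_thr_tile_bytes and the iteration tile by max_iter_tile_bytes.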
void reorder_ir_builder_t::compute_blocks(const exec_config_t &exec_cfg,
        const layout_t &src, const layout_t &dst, std::vector<int> &iter_blocks,
        std::vector<int> &loop_blocks, std::vector<int> &tg_blocks,
        int max_iter_tile_bytes, int max_thr_tile_bytes) {
    if (max_iter_tile_bytes <= 0)
        max_iter_tile_bytes = max_tile_size(exec_cfg.hw_cfg(), dst, src);
    if (max_thr_tile_bytes <= 0)
        max_thr_tile_bytes = max_tile_size(exec_cfg.hw_cfg(), dst, src);

    ir_assert(src.ndims() == dst.ndims());
    int ndims = src.ndims();
    std::vector<dim_t> dims(ndims);
    for (int i = 0; i < ndims; i++) {
        dims[i] = std::max(src.dim(i), dst.dim(i));
    }

    // Pad src/dst layouts to match each other.
    auto pad_layout = [&](const layout_t &l) {
        std::vector<block_t> padded_blocks;
        for (auto &eb : l.enumerated_blocks()) {
            auto b = eb.second;
            if (l.is_outermost(eb)) {
                dim_t inner = l.dim(b.dim_idx) / b.block;
                b.block = ir_utils::safe_divide(dims[b.dim_idx], inner);
            }
            padded_blocks.push_back(b);
        }
        return layout_t(
                l.type(), ndims, 0, padded_blocks, /*do_normalize=*/false);
    };
    layout_t padded_src = pad_layout(src);
    layout_t padded_dst = pad_layout(dst);
    ir_assert(ir_utils::is_equal(padded_src.dims(), padded_dst.dims()));

    int elems = padded_src.elems();
    int max_type_size = std::max(src.type().size(), dst.type().size());
    dim_t max_iter_tile_elems
            = std::min(max_iter_tile_bytes / max_type_size, elems);
    dim_t max_thr_tile_elems
            = std::min(max_thr_tile_bytes / max_type_size, elems);

    tile_helper_t src_th(padded_src);
    tile_helper_t dst_th(padded_dst);

    // Incrementally increase subtiles in src and dst. The goal is to find the
    // maximum src/dst tiles so that the final combined tile covers dense
    // regions as big as possible in src/dst layouts.
    std::vector<tensor_t> candidate_tiles;
    // To ensure there is at least one candidate.
    candidate_tiles.emplace_back(std::vector<dim_t>(ndims, 1));
    for (;;) {
        if (!src_th.has_next() || !dst_th.has_next()) break;
        tile_helper_t *th = &src_th;
        bool src_dense = src_th.is_dense();
        bool dst_dense = dst_th.is_dense();
        // When both sublayouts are dense, try to increase the smaller tile.
        // Otherwise, if only one sublayout is dense, try to increase that one
        // (src is grown by default).
        if (src_dense && dst_dense && dst_th.size() < src_th.size()) {
            th = &dst_th;
        } else if (dst_dense && !src_dense) {
            th = &dst_th;
        }

        auto tile = th->next();
        auto &other_th = (th == &src_th ? dst_th : src_th);
        tile = tile_helper_t::merge(tile, other_th.tile());
        if (tile_helper_t::can_be_mapped(padded_src, tile)
                && tile_helper_t::can_be_mapped(padded_dst, tile)) {
            th->accept();
            candidate_tiles.push_back(tile);
        }
        if (tile.elems() >= max_thr_tile_elems) break;
    }

    std::sort(candidate_tiles.begin(), candidate_tiles.end(),
            [](const tensor_t &a, const tensor_t &b) {
                return a.elems() > b.elems();
            });

    const tensor_t *thr_tile = nullptr;
    const tensor_t *iter_tile = nullptr;
    for (size_t i = 0; i < candidate_tiles.size(); i++) {
        auto &t = candidate_tiles[i];
        if (!thr_tile && t.elems() <= max_thr_tile_elems) thr_tile = &t;
        if (thr_tile && !iter_tile && t.elems() <= max_iter_tile_elems
                && thr_tile->is_divisible(t)) {
            iter_tile = &t;
        }
        if (thr_tile && iter_tile) break;
    }

    ir_assert(thr_tile);
    ir_assert(iter_tile);
    std::vector<int> thr_blocks(
            thr_tile->dims().begin(), thr_tile->dims().end());
    iter_blocks.assign(iter_tile->dims().begin(), iter_tile->dims().end());

    ir_assert(utils::array_product(iter_blocks) <= max_iter_tile_elems);
    ir_assert(utils::array_product(thr_blocks) <= max_thr_tile_elems);

    // Initialize loop blocks.
    loop_blocks.resize(ndims, 1);
    for (int i = 0; i < ndims; i++) {
        loop_blocks[i] = ir_utils::safe_divide(thr_blocks[i], iter_blocks[i]);
    }

    // Initialize thread group blocks.
    // Heuristic: try to split outer dimension and assign its
    // inner part to the thread group. This may give better
    // bandwidth utilization on XeHP/XeHPG.
    tg_blocks.resize(ndims, 1);
    const int tg_factor = 2;
    for (int i = 0; i < ndims; i++) {
        int outer = utils::div_up(dims[i], thr_blocks[i]);
        if (outer % tg_factor == 0) {
            tg_blocks[i] = tg_factor;
            break;
        }
    }
}

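// Convenience overload: returns the combined per-thread tile blocks
// (iteration blocks x loop blocks) instead of the individual factors.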
void reorder_ir_builder_t::compute_blocks(const exec_config_t &exec_cfg,
        const layout_t &src, const layout_t &dst, std::vector<int> &tile_blocks,
        std::vector<int> &tg_blocks) {
    std::vector<int> iter_blocks;
    std::vector<int> loop_blocks;
    compute_blocks(exec_cfg, src, dst, iter_blocks, loop_blocks, tg_blocks);
    size_t n = iter_blocks.size();
    tile_blocks.resize(n);
    for (size_t i = 0; i < n; i++) {
        tile_blocks[i] = iter_blocks[i] * loop_blocks[i];
    }
}

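// Distributes the per-dimension outer extents across a 3D kernel grid and the
// thread-group splits across a matching 3D thread-group grid. If dim2grid is
// provided, it records which grid axis each tensor dimension was assigned to.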
void reorder_ir_builder_t::compute_grid(const layout_t &src,
        const layout_t &dst, const std::vector<int> &iter_blocks,
        const std::vector<int> &loop_blocks, const std::vector<int> &tg_blocks,
        grid_info_t &kernel_grid, grid_info_t &tg_grid,
        std::vector<int> *dim2grid) {
    int ndims = src.ndims();
    std::vector<dim_t> dims(ndims);
    for (int i = 0; i < ndims; i++) {
        dims[i] = std::max(src.dim(i), dst.dim(i));
    }

    if (dim2grid) dim2grid->resize(ndims, -1);

    const int grid_ndims = 3;
    std::vector<int> kernel_grid_dims(grid_ndims, 1);
    std::vector<int> tg_grid_dims(grid_ndims, 1);
    int grid_idx = 0;
    int max_grid_idx = grid_ndims - 1;
    for (int i = 0; i < ndims; i++) {
        if (dim2grid) (*dim2grid)[i] = grid_idx;
        int outer = utils::div_up(
                dims[i], iter_blocks[i] * loop_blocks[i] * tg_blocks[i]);
        tg_grid_dims[grid_idx] *= tg_blocks[i];
        kernel_grid_dims[grid_idx] *= outer;
        if (outer != 1 && grid_idx != max_grid_idx) grid_idx++;
    }
    kernel_grid = grid_info_t(kernel_grid_dims, "grid_idx");
    tg_grid = grid_info_t(tg_grid_dims, "grid_idx");
}

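// Computes the ND-range for the reorder kernel. The first dimension of both
// the global and local sizes is scaled by the SIMD width, so each logical
// thread occupies a full subgroup of work-items.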
compute::nd_range_t reorder_ir_builder_t::nd_range(
        const exec_config_t &exec_cfg, const layout_t &src,
        const layout_t &dst) {
    const int simd = exec_cfg.simd();
    std::vector<int> iter_blocks;
    std::vector<int> loop_blocks;
    std::vector<int> tg_blocks;
    compute_blocks(exec_cfg, src, dst, iter_blocks, loop_blocks, tg_blocks);
    grid_info_t kernel_grid;
    grid_info_t tg_grid;
    compute_grid(src, dst, iter_blocks, loop_blocks, tg_blocks, kernel_grid,
            tg_grid);
    std::array<size_t, 3> global;
    std::array<size_t, 3> local;
    for (int i = 0; i < kernel_grid.ndims(); i++) {
        global[i] = kernel_grid[i] * tg_grid[i];
        local[i] = tg_grid[i];
        if (i == 0) {
            global[i] *= simd;
            local[i] *= simd;
        }
    }
    return compute::nd_range_t(global.data(), local.data());
}

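// Builds the reorder kernel IR. If the initial blocking does not fit into the
// register file, the iteration tile byte budget is halved until the iteration
// blocks change, and the build is retried (up to max_iters attempts).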
void reorder_ir_builder_t::build() {
    std::vector<int> iter_blocks;
    std::vector<int> loop_blocks;
    std::vector<int> tg_blocks;
    compute_blocks(exec_cfg_, src_layout_, dst_layout_, iter_blocks,
            loop_blocks, tg_blocks);

    int max_iters = 10;
    int cur_iter_bytes
            = max_tile_size(exec_cfg_.hw_cfg(), dst_layout_, src_layout_);
    for (int i = 0; i < max_iters; i++) {
        if (try_build(iter_blocks, loop_blocks, tg_blocks)) {
            ir_info() << "Reorder configuration:" << std::endl;
            ir_info() << " Source layout: " << src_layout_ << std::endl;
            ir_info() << " Destination layout: " << dst_layout_ << std::endl;
            ir_info() << " Iteration blocks: "
                      << ir_utils::make_seq_print_helper(iter_blocks, " x ")
                      << std::endl;
            ir_info() << " Loop blocks: "
                      << ir_utils::make_seq_print_helper(loop_blocks, " x ")
                      << std::endl;
            ir_info() << " Thread group blocks: "
                      << ir_utils::make_seq_print_helper(tg_blocks, " x ")
                      << std::endl;
            return;
        }

        cur_iter_bytes /= 2;
        while (cur_iter_bytes >= 1) {
            std::vector<int> new_iter_blocks;
            compute_blocks(exec_cfg_, src_layout_, dst_layout_, new_iter_blocks,
                    loop_blocks, tg_blocks, cur_iter_bytes);
            if (!ir_utils::is_equal(new_iter_blocks, iter_blocks)) {
                iter_blocks = new_iter_blocks;
                break;
            }
            cur_iter_bytes /= 2;
        }
    }
    ir_error_not_expected();
}

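// Attempts to build the kernel body for the given blocking:
// - sets up the kernel and thread-group grids,
// - builds padded src/dst views with bounds masks for padded elements,
// - splits each dimension into iteration/loop/thread-group pieces and binds
//   the outer indices to the grids,
// - emits a load of the per-thread src sub-view into registers, an optional
//   in-register reorder, and a store to dst,
// - runs the lowering and optimization passes.
// Returns false if the estimated GRF usage exceeds the available registers.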
bool reorder_ir_builder_t::try_build(const std::vector<int> &iter_blocks,
        const std::vector<int> &loop_blocks,
        const std::vector<int> &tg_blocks) {
    constraint_set_t init_cset;

    int ndims = src_layout_.ndims();
    std::vector<expr_t> vars;
    for (int i = 0; i < ndims; i++) {
        char letter = 'a' + i;
        vars.push_back(var_t::make(type_t::s32(), std::string(1, letter)));
    }

    std::vector<int> dim2grid;
    compute_grid(src_layout_, dst_layout_, iter_blocks, loop_blocks, tg_blocks,
            kernel_grid_, tg_grid_, &dim2grid);

    std::vector<stmt_t> init_stmts;
    init_kernel_grid(
            kernel_grid_, tg_grid_, exec_cfg_.simd(), init_cset, init_stmts);

    auto &x = view_t::placeholder_var();

    std::vector<dim_t> vdims(ndims);
    for (int i = 0; i < ndims; i++) {
        vdims[i] = std::max(src_layout_.dim(i), dst_layout_.dim(i));
    }

    view_t src_view(vars, ndims);
    for (int i = 0; i < ndims; i++) {
        int dim = src_layout_.dim(i);
        src_view.set_vdim(vars[i], vdims[i]);
        expr_t mask(true);
        if (dim != vdims[i]) mask = x < dim;
        src_view.set_tdim(i, vars[i], mask);
    }
    src_view.set_tlayout(src_layout_);

    view_t dst_view(vars, ndims);
    for (int i = 0; i < ndims; i++) {
        int dim = dst_layout_.dim(i);
        dst_view.set_vdim(vars[i], vdims[i]);
        expr_t mask(true);
        if (dim != vdims[i]) mask = x < dim;
        dst_view.set_tdim(i, vars[i], mask);
    }
    dst_view.set_tlayout(dst_layout_);

    gemm_schedule_t schedule(init_cset, kernel_grid_, tg_grid_);

    schedule.set_view(src_view);
    schedule.set_view(dst_view);

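    // Split each dimension (innermost to outermost) into an iteration piece
    // that is tensorized, a loop piece that stays as a sequential loop, and a
    // thread-group piece bound to the thread-group grid. The remaining outer
    // indices are collected per grid axis and fused/bound to the kernel grid
    // below.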
    std::array<std::vector<expr_t>, 3> fused_idxs;
    for (int i = 0; i < ndims; i++) {
        std::vector<expr_t> ordered;
        auto v = vars[i];
        if (iter_blocks[i] != 1) {
            expr_t outer, inner;
            schedule.split(v, iter_blocks[i], outer, inner);
            schedule.tensorize(inner);
            v = outer;
            ordered.insert(ordered.begin(), outer);
        }
        if (loop_blocks[i] != 1) {
            if (!ordered.empty()) ordered.erase(ordered.begin());
            expr_t outer, inner;
            schedule.split(v, loop_blocks[i], outer, inner);
            v = outer;
            ordered.insert(ordered.begin(), inner);
            ordered.insert(ordered.begin(), outer);
        }
        if (tg_blocks[i] != 1) {
            if (!ordered.empty()) ordered.erase(ordered.begin());
            expr_t outer, inner;
            schedule.split(v, tg_blocks[i], outer, inner);
            schedule.bind(inner, tg_grid_.idx(dim2grid[i]));
            v = outer;
            ordered.insert(ordered.begin(), inner);
            ordered.insert(ordered.begin(), outer);
        }
        fused_idxs[dim2grid[i]].push_back(v);
        schedule.reorder(ordered);
    }

    for (int i = 0; i < (int)fused_idxs.size(); i++) {
        auto &vec = fused_idxs[i];
        if (vec.empty()) continue;
        auto var = (vec.size() == 1 ? vec[0] : schedule.fuse(vec));
        schedule.bind(var, kernel_grid_.idx(i));
    }

    schedule.finalize();

    auto thr_tile = schedule.thr_view_tile(src_view, /*is_relative=*/false);

    auto src_thr_view = src_view.create_sub_view(thr_tile);
    auto dst_thr_view = dst_view.create_sub_view(thr_tile);

    auto src_buf = kernel_info_.arg_var(0);
    auto dst_buf = kernel_info_.arg_var(1);

    ir_context_t ir_ctx(exec_cfg_, init_cset);
    auto reg_buf = ir_ctx.create_tmp_var(type_t::byte_ptr(), "reg");

    std::vector<stmt_t> allocs;
    for (int i = 0; i < kernel_info_.nargs(); i++) {
        auto &var = kernel_info_.arg_var(i);
        if (!var.type().is_ptr()) continue;
        allocs.push_back(alloc_t::make(var, 0, alloc_kind_t::global));
    }

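    // Emit a load of the per-thread src sub-view into the register buffer and
    // a store of the same registers to dst. If the two register layouts
    // differ, an in-register reorder into a temporary buffer is inserted
    // before the store.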
    auto read = make_access_builder(ir_ctx, src_thr_view, src_buf, reg_buf,
            send_op_t::load, send_address_t::a64);
    auto read_stmt = read.stmt();

    auto write = make_access_builder(ir_ctx, dst_thr_view, dst_buf, reg_buf,
            send_op_t::store, send_address_t::a64);
    auto write_stmt = write.stmt();

    auto read_layout = read.reg_layout();
    auto write_layout = write.reg_layout();
    allocs.push_back(
            alloc_t::make(reg_buf, read_layout.size(), alloc_kind_t::grf));

    if (read_layout != write_layout) {
        auto tmp_buf = ir_ctx.create_tmp_var(type_t::byte_ptr(), "tmp");
        allocs.push_back(
                alloc_t::make(tmp_buf, write_layout.size(), alloc_kind_t::grf));

        auto reorder_stmt = create_reorder_stmt(
                read_layout, write_layout, reg_buf, tmp_buf);
        write_stmt = substitute(write_stmt, reg_buf, tmp_buf);
        write_stmt = reorder_stmt.append(write_stmt);
    }

    stmt_ = stmt_t();
    stmt_ = stmt_.append(read_stmt);
    stmt_ = stmt_.append(write_stmt);

    stmt_ = schedule.create_loop_nest(stmt_);
    stmt_ = schedule.create_bind_stmt(stmt_);
    stmt_ = inject_let_stmts(stmt_, init_stmts);
    stmt_ = inject_alloc_stmts(stmt_, allocs);
    stmt_ = inject_external_var_let(stmt_, ir_ctx);

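    // Lower and optimize the IR: simplify, lift send buffer offsets, lower
    // accesses to send instructions, split wide stores, fix potential int32
    // overflows, eliminate common subexpressions, and clean up allocations
    // and let statements.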
    stmt_ = simplify(stmt_, ir_ctx);
    stmt_ = lift_buffer_offsets_in_send(stmt_, ir_ctx);
    stmt_ = inject_send(stmt_, ir_ctx);
    stmt_ = split_wide_stores(stmt_, ir_ctx);
    stmt_ = fix_int32_overflow(stmt_, ir_ctx);
    stmt_ = eliminate_common_subexprs(
            stmt_, ir_ctx, exec_cfg_.regs() * exec_cfg_.grf_size());
    stmt_ = simplify(stmt_, ir_ctx);
    stmt_ = optimize_alloc_let(stmt_, ir_ctx);
    stmt_ = stmt_group_t::make(stmt_label_t::kernel(), stmt_);

    int ir_usage = get_peak_grf_usage(stmt_, exec_cfg_.grf_size());
    int reserved_usage = 16;
    int grf_usage = ir_usage + reserved_usage;
    if (grf_usage > exec_cfg_.regs()) {
        ir_warning()
                << "Estimated GRF usage is " << grf_usage
                << " which exceeds available space, retry with a smaller tile."
                << std::endl;

        return false;
    }

    ir_trace() << "Reorder kernel body:\n" << stmt_ << std::endl;
    return true;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl