/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/pass/slm.hpp"

#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/ir/reorder.hpp"
#include "gpu/jit/ir/tensor.hpp"
#include "gpu/jit/utils/trace.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

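// Merges all nested SLM allocations into a single SLM buffer, replacing each
// original buffer with an offset into the merged one.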
class slm_buffer_merger_t : public ir_mutator_t {
public:
    slm_buffer_merger_t() {
        slm_base_ = make_buffer("slm");
        slm_off_.push_back(0);
    }

    const expr_t &slm_base() const { return slm_base_; }

    int slm_size() const { return slm_size_; }

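    // Rewrites an SLM allocation so that its body uses an offset into the
    // merged SLM buffer instead of a dedicated buffer.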
    object_t _mutate(const alloc_t &obj) override {
        if (obj.kind != alloc_kind_t::slm) return ir_mutator_t::_mutate(obj);

        auto new_buf = push(obj);
        auto new_obj = ir_mutator_t::_mutate(obj);
        pop();

        auto &alloc = new_obj.as<alloc_t>();
        new_obj = substitute(alloc.body, alloc.buf, new_buf);

        return new_obj;
    }

private:
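    // Reserves space for the allocation at the current stack offset and
    // returns its address within the merged buffer.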
    expr_t push(const alloc_t &obj) {
        int cur_off = slm_off_.back();
        expr_t new_buf = slm_base_ + cur_off;
        slm_off_.push_back(cur_off + obj.size);
        slm_size_ = std::max(slm_size_, cur_off + obj.size);
        return new_buf;
    }

    void pop() { slm_off_.pop_back(); }

    expr_t slm_base_;
    std::vector<int> slm_off_;
    int slm_size_ = 0;
};

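// Combines all SLM allocations in the kernel into one top-level SLM
// allocation sized to the maximum concurrent usage.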
stmt_t merge_slm_buffers(const stmt_t &_stmt, ir_context_t &ir_ctx) {
    trace_start();
    stmt_t stmt = _stmt;
    slm_buffer_merger_t merger;
    stmt = merger.mutate(stmt);
    stmt = alloc_t::make(
            merger.slm_base(), merger.slm_size(), alloc_kind_t::slm, stmt);
    trace_pass("merge_slm_buffers", stmt, ir_ctx);
    return stmt;
}

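// Replaces suitable GRF-to-GRF reorders with reorders performed through SLM
// (store to SLM, then load back in the new layout).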
class slm_reorder_injector_t : public ir_mutator_t {
public:
    slm_reorder_injector_t(
            const stmt_t &root, ngen::HW hw, const grid_info_t &tg_grid)
        : hw_(hw), tg_grid_(tg_grid) {
        alloc_manager_t alloc_mgr(root);
        auto slm_buffers = alloc_mgr.find_buffers(alloc_kind_t::slm);
        ir_assert(slm_buffers.size() == 1);
        slm_base_ = slm_buffers[0];
        slm_size_ = alloc_mgr.total_size(alloc_kind_t::slm);
    }

    const expr_t &slm_base() const { return slm_base_; }

    int slm_size() const { return slm_size_; }

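    // Intercepts reorder calls and replaces them with an SLM store/load
    // sequence when the layouts are supported; otherwise keeps the call.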
    object_t _mutate(const func_call_t &obj) override {
        if (!is_func_call<reorder_t>(obj)) return obj;

        auto &call = obj.as<func_call_t>();

        auto stmt = create_slm_reorder(call.func.as<reorder_t>(),
                reorder_t::arg_src_buf(call), reorder_t::arg_dst_buf(call));
        if (stmt.is_empty()) return obj;
        return std::move(stmt);
    }

private:
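    // Finds the largest common tile of the dense source and destination
    // layouts for which an SLM reorder can be emitted; returns an empty
    // statement if no such tile exists.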
    stmt_t create_slm_reorder(const reorder_t &reorder, const expr_t &src_buf,
            const expr_t &dst_buf) {
        auto src = reorder.src_layout;
        auto dst = reorder.dst_layout;
        if (!src.is_dense() || !dst.is_dense()) return stmt_t();

        layout_t::try_reinterpret_to_wider_type(src, dst);
        if (src.type() != dst.type()) return stmt_t();
        if (src.type().size() != 4) return stmt_t();

        layout_iterator_t src_it(src);
        layout_iterator_t dst_it(dst);

        tensor_t max_tile;
        for (;;) {
            auto src_tile = src_it.tile();
            auto dst_tile = dst_it.tile();
            if (src_tile.is_equal(dst_tile)) {
                auto s = src.map(src_it.tile());
                auto d = dst.map(dst_it.tile());
                if (s.is_dense() && d.is_dense()
                        && src_it.outer_layout() == dst_it.outer_layout()) {
                    if (is_slm_reorder_ok(s, d)) { max_tile = src_tile; }
                }
                if (!src_it.has_next() || !dst_it.has_next()) break;
                ++src_it;
                ++dst_it;
            } else {
                if (src_tile.elems() <= dst_tile.elems()) {
                    if (!src_it.has_next()) break;
                    ++src_it;
                } else {
                    if (!dst_it.has_next()) break;
                    ++dst_it;
                }
            }
        }

        if (max_tile.is_empty()) return stmt_t();

        return create_slm_reorder(max_tile, src, dst, src_buf, dst_buf);
    }

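    // Emits the SLM reorder for the given tile: each thread scatters its data
    // to its private SLM region with SIMD dword stores and reads it back with
    // hword block loads.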
    stmt_t create_slm_reorder(const tensor_t &tile, const layout_t &src,
            const layout_t &dst, const expr_t &src_buf, const expr_t &dst_buf) {
        auto src_tile = src.map(tile);
        auto &src_tile_blocks = src_tile.blocks();
        int simd = src_tile_blocks[0].block;
        int vect_size = src_tile_blocks[1].block;
        int tile_size = simd * vect_size * src.type().size();
        int slm_thr_size = (int)src.size();
        int dword_size = type_t::dword().size();
        int hword_size = type_t::hword().size();
        int hwords = tile_size / hword_size;

        ir_assert(tile_size % hword_size == 0);

        slm_size_ = std::max(slm_size_, slm_thr_size * tg_grid_.elems());

        auto store_send = send_t::make(hw_, send_op_t::store,
                send_address_t::slm, type_t::dword(vect_size), simd);
        auto load_send = send_t::make(hw_, send_op_t::load, send_address_t::slm,
                type_t::hword(hwords), 1);

        std::vector<expr_t> vec(simd);
        for (int i = 0; i < simd; i++)
            vec[i] = expr_t(i * vect_size * dword_size);
        auto vec_off = shuffle_t::make(vec);
        auto tid = tg_grid_.idx(1) * tg_grid_.dim(0) + tg_grid_.idx(0);
        expr_t off0 = tid * slm_thr_size;

        stmt_t store_stmt;
        stmt_t load_stmt;
        src.for_each_tile(tile, [&](const std::vector<dim_t> &start) {
            expr_t off = (int)src.offset_in_bytes(start);
            auto store = store_send.call({slm_base_,
                    shuffle_t::make_broadcast(off0 + off, simd) + vec_off,
                    src_buf + off, expr_t()});
            auto load = load_send.call(
                    {slm_base_, off0 + off, dst_buf + off, expr_t()});
            store_stmt = store_stmt.append(store);
            load_stmt = load_stmt.append(load);
        });

        auto ret = store_stmt.append(load_stmt);
        return ret;
    }

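    // Accepts only reorders that transpose a 16x8 inner tile (SIMD16 with
    // 8-element vectors), the single pattern handled by the SLM path.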
    bool is_slm_reorder_ok(const layout_t &src, const layout_t &dst) const {
        auto &src_blocks = src.blocks();
        auto &dst_blocks = dst.blocks();
        if (src_blocks.size() != 2 || dst_blocks.size() != 2) return false;
        auto &s0 = src_blocks[0];
        auto &s1 = src_blocks[1];
        auto &d0 = dst_blocks[0];
        auto &d1 = dst_blocks[1];

        if (s0.dim_idx != d1.dim_idx || s1.dim_idx != d0.dim_idx) return false;
        ir_assert(s0.block == d1.block);
        ir_assert(s1.block == d0.block);

        int simd = s0.block;
        int vec_size = s1.block;
        if (!utils::one_of(simd, 16)) return false;
        if (!utils::one_of(vec_size, 8)) return false;

        return true;
    }

    ngen::HW hw_;
    grid_info_t tg_grid_;

    expr_t slm_base_;
    int slm_size_ = 0;
};

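// Injects SLM-based reorders on XeHPC+ when SLM is not otherwise used and
// resizes the SLM allocation to cover them.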
stmt_t inject_slm_reorder(const stmt_t &s, ir_context_t &ir_ctx,
        const grid_info_t &tg_grid, bool has_slm_usage) {
    trace_start();
    if (has_slm_usage) return s;
    if (ir_ctx.hw() < ngen::HW::XeHPC) return s;
    slm_reorder_injector_t injector(s, ir_ctx.hw(), tg_grid);
    stmt_t ret = injector.mutate(s);

    auto &slm_buf = injector.slm_base();
    int slm_size = injector.slm_size();
    alloc_updater_t alloc_updater;
    alloc_updater.resize(slm_buf, slm_size);
    ret = alloc_updater.update(ret);

    trace_pass("inject_slm_reorder", ret, ir_ctx);
    return ret;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl