/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 | |
17 | #include "gpu/jit/pass/slm.hpp" |
18 | |
19 | #include "gpu/jit/ir/message.hpp" |
20 | #include "gpu/jit/ir/reorder.hpp" |
21 | #include "gpu/jit/ir/tensor.hpp" |
22 | #include "gpu/jit/utils/trace.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace jit { |
28 | |
29 | class slm_buffer_merger_t : public ir_mutator_t { |
30 | public: |
31 | slm_buffer_merger_t() { |
32 | slm_base_ = make_buffer("slm" ); |
33 | slm_off_.push_back(0); |
34 | } |
35 | |
36 | const expr_t &slm_base() const { return slm_base_; } |
37 | |
38 | int slm_size() const { return slm_size_; } |
39 | |
40 | object_t _mutate(const alloc_t &obj) override { |
41 | if (obj.kind != alloc_kind_t::slm) return ir_mutator_t::_mutate(obj); |
42 | |
43 | auto new_buf = push(obj); |
44 | auto new_obj = ir_mutator_t::_mutate(obj); |
45 | pop(); |
46 | |
47 | auto &alloc = new_obj.as<alloc_t>(); |
48 | new_obj = substitute(alloc.body, alloc.buf, new_buf); |
49 | |
50 | return new_obj; |
51 | } |
52 | |
53 | private: |
54 | expr_t push(const alloc_t &obj) { |
55 | int cur_off = slm_off_.back(); |
56 | expr_t new_buf = slm_base_ + cur_off; |
57 | slm_off_.push_back(cur_off + obj.size); |
58 | slm_size_ = std::max(slm_size_, cur_off + obj.size); |
59 | return new_buf; |
60 | } |
61 | |
62 | void pop() { slm_off_.pop_back(); } |
63 | |
64 | expr_t slm_base_; |
65 | std::vector<int> slm_off_; |
66 | int slm_size_ = 0; |
67 | }; |
68 | |
69 | stmt_t merge_slm_buffers(const stmt_t &_stmt, ir_context_t &ir_ctx) { |
70 | trace_start(); |
71 | stmt_t stmt = _stmt; |
72 | slm_buffer_merger_t merger; |
73 | stmt = merger.mutate(stmt); |
74 | stmt = alloc_t::make( |
75 | merger.slm_base(), merger.slm_size(), alloc_kind_t::slm, stmt); |
76 | trace_pass("merge_slm_buffers" , stmt, ir_ctx); |
77 | return stmt; |
78 | } |
79 | |
// Replaces eligible GRF reorder calls with an SLM round trip: a scattered
// dword store into SLM followed by a dense hword load back. This is applied
// to dense 4-byte layouts whose largest common tile is a 16x8 "transposed"
// pair of blocks (see is_slm_reorder_ok()).
class slm_reorder_injector_t : public ir_mutator_t {
public:
    slm_reorder_injector_t(
            const stmt_t &root, ngen::HW hw, const grid_info_t &tg_grid)
        : hw_(hw), tg_grid_(tg_grid) {
        // The pass expects exactly one (already merged) SLM buffer in the IR.
        alloc_manager_t alloc_mgr(root);
        auto slm_buffers = alloc_mgr.find_buffers(alloc_kind_t::slm);
        ir_assert(slm_buffers.size() == 1);
        slm_base_ = slm_buffers[0];
        slm_size_ = alloc_mgr.total_size(alloc_kind_t::slm);
    }

    // Base of the SLM buffer the injected sends address.
    const expr_t &slm_base() const { return slm_base_; }

    // SLM bytes required after injection; may grow beyond the original
    // allocation to hold per-thread reorder scratch (see create_slm_reorder()).
    int slm_size() const { return slm_size_; }

    object_t _mutate(const func_call_t &obj) override {
        // Only reorder function calls are candidates; keep anything else as is.
        if (!is_func_call<reorder_t>(obj)) return obj;

        auto &call = obj.as<func_call_t>();

        auto stmt = create_slm_reorder(call.func.as<reorder_t>(),
                reorder_t::arg_src_buf(call), reorder_t::arg_dst_buf(call));
        // Empty statement means the layouts don't qualify; keep the original
        // GRF reorder call.
        if (stmt.is_empty()) return obj;
        return std::move(stmt);
    }

private:
    // Tries to express `reorder` (src_buf -> dst_buf) as an SLM store/load
    // pair. Returns an empty statement when the layouts do not qualify.
    stmt_t create_slm_reorder(const reorder_t &reorder, const expr_t &src_buf,
            const expr_t &dst_buf) {
        auto src = reorder.src_layout;
        auto dst = reorder.dst_layout;
        // Only dense layouts can be streamed through SLM contiguously.
        if (!src.is_dense() || !dst.is_dense()) return stmt_t();

        // Try to coalesce to a wider element type; the sends below are
        // generated for 4-byte elements only.
        layout_t::try_reinterpret_to_wider_type(src, dst);
        if (src.type() != dst.type()) return stmt_t();
        if (src.type().size() != 4) return stmt_t();

        // Walk both layouts in lockstep, growing candidate tiles, and keep
        // the largest tile where both mapped sub-layouts are dense, the outer
        // layouts agree, and the block structure is SLM-friendly.
        layout_iterator_t src_it(src);
        layout_iterator_t dst_it(dst);

        tensor_t max_tile;
        for (;;) {
            auto src_tile = src_it.tile();
            auto dst_tile = dst_it.tile();
            if (src_tile.is_equal(dst_tile)) {
                auto s = src.map(src_it.tile());
                auto d = dst.map(dst_it.tile());
                if (s.is_dense() && d.is_dense()
                        && src_it.outer_layout() == dst_it.outer_layout()) {
                    if (is_slm_reorder_ok(s, d)) { max_tile = src_tile; }
                }
                if (!src_it.has_next() || !dst_it.has_next()) break;
                ++src_it;
                ++dst_it;
            } else {
                // Tiles differ: advance the smaller side to catch up.
                if (src_tile.elems() <= dst_tile.elems()) {
                    if (!src_it.has_next()) break;
                    ++src_it;
                } else {
                    if (!dst_it.has_next()) break;
                    ++dst_it;
                }
            }
        }

        if (max_tile.is_empty()) return stmt_t();

        return create_slm_reorder(max_tile, src, dst, src_buf, dst_buf);
    }

    // Emits, for each instance of `tile` in `src`, a scattered dword store
    // into this thread's SLM region followed by a dense hword load into
    // dst_buf. All stores are appended before all loads.
    // NOTE(review): no SLM barrier is emitted between the stores and loads —
    // presumably safe because each thread reads back only its own region
    // (off0 is per-thread); confirm against the fence/barrier passes.
    stmt_t create_slm_reorder(const tensor_t &tile, const layout_t &src,
            const layout_t &dst, const expr_t &src_buf, const expr_t &dst_buf) {
        auto src_tile = src.map(tile);
        auto &src_tile_blocks = src_tile.blocks();
        // Innermost block -> SIMD lanes; second block -> per-lane vector.
        int simd = src_tile_blocks[0].block;
        int vect_size = src_tile_blocks[1].block;
        int tile_size = simd * vect_size * src.type().size();
        // Each thread gets a private scratch region of one full src layout.
        int slm_thr_size = (int)src.size();
        int dword_size = type_t::dword().size();
        int hword_size = type_t::hword().size();
        int hwords = tile_size / hword_size;

        // The dense load below reads whole hwords; a partial hword would be
        // unaddressable.
        ir_assert(tile_size % hword_size == 0);

        // Grow the SLM allocation to fit one scratch region per thread.
        slm_size_ = std::max(slm_size_, slm_thr_size * tg_grid_.elems());

        // Store: `simd` lanes, each writing `vect_size` dwords (scattered).
        auto store_send = send_t::make(hw_, send_op_t::store,
                send_address_t::slm, type_t::dword(vect_size), simd);
        // Load: one dense block of `hwords` hwords.
        auto load_send = send_t::make(hw_, send_op_t::load, send_address_t::slm,
                type_t::hword(hwords), 1);

        // Per-lane byte offsets: lane i starts at i * vect_size dwords.
        std::vector<expr_t> vec(simd);
        for (int i = 0; i < simd; i++)
            vec[i] = expr_t(i * vect_size * dword_size);
        auto vec_off = shuffle_t::make(vec);
        // Linear thread index within the thread group -> base of this
        // thread's scratch region.
        auto tid = tg_grid_.idx(1) * tg_grid_.dim(0) + tg_grid_.idx(0);
        expr_t off0 = tid * slm_thr_size;

        stmt_t store_stmt;
        stmt_t load_stmt;
        src.for_each_tile(tile, [&](const std::vector<dim_t> &start) {
            expr_t off = (int)src.offset_in_bytes(start);
            // Store arguments: SLM base, per-lane addresses, register data,
            // no mask.
            auto store = store_send.call({slm_base_,
                    shuffle_t::make_broadcast(off0 + off, simd) + vec_off,
                    src_buf + off, expr_t()});
            auto load = load_send.call(
                    {slm_base_, off0 + off, dst_buf + off, expr_t()});
            store_stmt = store_stmt.append(store);
            load_stmt = load_stmt.append(load);
        });

        auto ret = store_stmt.append(load_stmt);
        return ret;
    }

    // Accepts only two-block layouts where src and dst swap the same pair of
    // blocks (a transpose of the inner tile) with fixed sizes: 16 lanes by an
    // 8-element vector.
    bool is_slm_reorder_ok(const layout_t &src, const layout_t &dst) const {
        auto &src_blocks = src.blocks();
        auto &dst_blocks = dst.blocks();
        if (src_blocks.size() != 2 || dst_blocks.size() != 2) return false;
        auto &s0 = src_blocks[0];
        auto &s1 = src_blocks[1];
        auto &d0 = dst_blocks[0];
        auto &d1 = dst_blocks[1];

        // The two blocks must be swapped between src and dst (transposition).
        if (s0.dim_idx != d1.dim_idx || s1.dim_idx != d0.dim_idx) return false;
        ir_assert(s0.block == d1.block);
        ir_assert(s1.block == d0.block);

        int simd = s0.block;
        int vec_size = s1.block;
        // Only the 16x8 shape is supported by the send messages generated in
        // create_slm_reorder().
        if (!utils::one_of(simd, 16)) return false;
        if (!utils::one_of(vec_size, 8)) return false;

        return true;
    }

    ngen::HW hw_;
    grid_info_t tg_grid_;

    expr_t slm_base_;
    int slm_size_ = 0;
};
223 | |
224 | stmt_t inject_slm_reorder(const stmt_t &s, ir_context_t &ir_ctx, |
225 | const grid_info_t &tg_grid, bool has_slm_usage) { |
226 | trace_start(); |
227 | if (has_slm_usage) return s; |
228 | if (ir_ctx.hw() < ngen::HW::XeHPC) return s; |
229 | slm_reorder_injector_t injector(s, ir_ctx.hw(), tg_grid); |
230 | stmt_t ret = injector.mutate(s); |
231 | |
232 | auto &slm_buf = injector.slm_base(); |
233 | int slm_size = injector.slm_size(); |
234 | alloc_updater_t alloc_updater; |
235 | alloc_updater.resize(slm_buf, slm_size); |
236 | ret = alloc_updater.update(ret); |
237 | |
238 | trace_pass("inject_slm_reorder" , ret, ir_ctx); |
239 | return ret; |
240 | } |
241 | |
242 | } // namespace jit |
243 | } // namespace gpu |
244 | } // namespace impl |
245 | } // namespace dnnl |
246 | |