1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/conv/ir_builder.hpp" |
18 | |
19 | #include <algorithm> |
20 | #include <array> |
21 | #include <iostream> |
22 | #include <limits> |
23 | #include <memory> |
24 | #include <numeric> |
25 | #include <utility> |
26 | #include <vector> |
27 | #include <unordered_map> |
28 | |
29 | #include "gpu/jit/conv/config.hpp" |
30 | #include "gpu/jit/conv/epilogue.hpp" |
31 | #include "gpu/jit/conv/pipeline.hpp" |
32 | #include "gpu/jit/conv/post_ops.hpp" |
33 | #include "gpu/jit/conv/slm_reduce_builder.hpp" |
34 | #include "gpu/jit/ir/fma.hpp" |
35 | #include "gpu/jit/ir/gemm_schedule.hpp" |
36 | #include "gpu/jit/ir/ir.hpp" |
37 | #include "gpu/jit/ir/message.hpp" |
38 | #include "gpu/jit/ir/mul_add.hpp" |
39 | #include "gpu/jit/ir/reduce.hpp" |
40 | #include "gpu/jit/ir/reorder.hpp" |
41 | #include "gpu/jit/ir/tensor.hpp" |
42 | #include "gpu/jit/pass/pass.hpp" |
43 | #include "gpu/jit/utils/trace.hpp" |
44 | |
45 | namespace dnnl { |
46 | namespace impl { |
47 | namespace gpu { |
48 | namespace jit { |
49 | |
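// Walks an IR statement and checks that every register buffer access (loads,
// stores, and buffer arguments of dpas/mad/reduce/reorder/send calls) stays
// within the bounds of its enclosing alloc_t allocation.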
50 | class buffer_access_verifier_t : public ir_visitor_t { |
51 | public: |
52 | void _visit(const alloc_t &obj) override { |
53 | buf_sizes_.emplace(obj.buf, obj.size); |
54 | ir_visitor_t::_visit(obj); |
55 | buf_sizes_.erase(obj.buf); |
56 | } |
57 | |
58 | void _visit(const func_call_t &obj) override { |
59 | auto &func = obj.func; |
60 | if (auto *dpas = func.as_ptr<dpas_t>()) { |
61 | auto &dst = dpas_t::arg_dst(obj); |
62 | auto &src0 = dpas_t::arg_src0(obj); |
63 | auto &src1 = dpas_t::arg_src1(obj); |
64 | auto &src2 = dpas_t::arg_src2(obj); |
65 | check_access(dst, dpas->dst_size(), obj); |
66 | if (!is_zero(src0)) check_access(src0, dpas->src0_size(), obj); |
67 | check_access(src1, dpas->src1_size(), obj); |
68 | check_access(src2, dpas->src2_size(), obj); |
69 | } else if (func.is<eltwise_t>()) { |
70 | auto &elems = eltwise_t::arg_elems(obj); |
71 | auto &data = eltwise_t::arg_data(obj); |
72 | int size = to_cpp<int>(elems) * sizeof(float); |
73 | check_access(data, size, obj); |
74 | } else if (auto *mad = func.as_ptr<mad_t>()) { |
75 | auto &dst = mad_t::arg_dst(obj); |
76 | auto &src0 = mad_t::arg_src0(obj); |
77 | auto &src1 = mad_t::arg_src1(obj); |
78 | auto &src2 = mad_t::arg_src2(obj); |
79 | check_access(dst, mad->dst_size(), obj); |
80 | if (!is_zero(src0)) check_access(src0, mad->src0_size(), obj); |
81 | check_access(src1, mad->src1_size(), obj); |
82 | check_access(src2, mad->src2_size(), obj); |
83 | } else if (auto *reduce = func.as_ptr<reduce_t>()) { |
84 | auto &dst_buf = reduce_t::arg_dst_buf(obj); |
85 | auto &src_buf = reduce_t::arg_src_buf(obj); |
86 | check_access(dst_buf, reduce->dst_layout.size(), obj); |
87 | check_access(src_buf, reduce->src_layout.size(), obj); |
88 | } else if (auto *reorder = func.as_ptr<reorder_t>()) { |
89 | auto &dst_buf = reorder_t::arg_dst_buf(obj); |
90 | auto &src_buf = reorder_t::arg_src_buf(obj); |
91 | check_access(dst_buf, reorder->dst_layout.size(), obj); |
92 | check_access(src_buf, reorder->src_layout.size(), obj); |
93 | return; |
94 | } else if (auto *send = func.as_ptr<send_t>()) { |
95 | if (!send->is_prefetch() && !send->is_prefetch_2d()) { |
96 | auto ®_buf = send_t::arg_reg_buf(obj); |
97 | int size = send->payload_size(); |
98 | check_access(reg_buf, size, obj); |
99 | } |
100 | return; |
101 | } else if (func.is<builtin_t>()) { |
102 | // No buffers to check. |
103 | } else { |
104 | ir_error_not_expected() << "Unhandled function: " << obj; |
105 | } |
106 | |
107 | ir_visitor_t::_visit(obj); |
108 | } |
109 | |
110 | void _visit(const load_t &obj) override { |
111 | auto elem_type = obj.type.scalar(); |
112 | int stride_bytes |
113 | = (obj.has_default_stride() ? elem_type.size() : obj.stride); |
114 | int off = to_cpp<int>(obj.off); |
115 | auto stride = (obj.type.elems() - 1) * stride_bytes; |
116 | check_access(obj.buf + off, stride + elem_type.size(), obj); |
117 | ir_visitor_t::_visit(obj); |
118 | } |
119 | |
120 | void _visit(const store_t &obj) override { |
121 | auto elem_type = obj.value.type().scalar(); |
122 | int stride_bytes |
123 | = (obj.has_default_stride() ? elem_type.size() : obj.stride); |
124 | int off = to_cpp<int>(obj.off); |
125 | auto stride = (obj.value.type().elems() - 1) * stride_bytes; |
126 | check_access(obj.buf + off, stride + elem_type.size(), obj); |
127 | ir_visitor_t::_visit(obj); |
128 | } |
129 | |
130 | private: |
131 | void check_access(const expr_t &buf, int size, const object_t &obj) { |
132 | auto &base = (is_var(buf) ? buf : buf.as<ptr_t>().base); |
133 | int off = (is_var(buf) ? 0 : to_cpp<int>(buf.as<ptr_t>().off)); |
134 | auto it = buf_sizes_.find(base); |
135 | ir_assert(it != buf_sizes_.end()) |
136 | << "Can't find allocation for buffer: " << buf; |
137 | int buf_size = it->second; |
138 | ir_assert(off + size <= buf_size) |
139 | << "Invalid access:\n " << obj << "\n Buffer " << base |
140 | << " has size: " << buf_size; |
141 | } |
142 | |
143 | object_map_t<expr_t, int> buf_sizes_; |
144 | }; |
145 | |
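// Runs buffer access verification over a statement; traced as an IR pass.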
146 | void verify_buffer_access(const stmt_t &s, ir_context_t &ir_ctx) { |
147 | trace_start(); |
148 | buffer_access_verifier_t verifier; |
149 | verifier.visit(s); |
150 | trace_pass("verify_buffer_access", s, ir_ctx); |
151 | } |
152 | |
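// Emits the multiplication statement for one subtile: decomposes C += A * B
// into a sequence of dpas/dpasw or mad instructions depending on the
// configured FMA kind and exposes the resulting C register layout.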
153 | class multiply_builder_t { |
154 | public: |
155 | multiply_builder_t() = default; |
156 | |
157 | multiply_builder_t(const conv_config_t &cfg, |
158 | const bmnk_mapper_t &bmnk_mapper, const view_t &a_view, |
159 | const view_t &b_view, const expr_t &a_buf, const expr_t &b_buf, |
160 | const expr_t &c_buf) |
161 | : hw_(cfg.hw()) |
162 | , simd_size_(cfg.simd()) |
163 | , bmnk_mapper_(bmnk_mapper) |
164 | , a_view_(a_view) |
165 | , b_view_(b_view) |
166 | , a_buf_(a_buf) |
167 | , b_buf_(b_buf) |
168 | , c_buf_(c_buf) { |
169 | switch (cfg.fma_kind()) { |
170 | case fma_kind_t::dp4a: |
171 | case fma_kind_t::dpas: |
172 | case fma_kind_t::dpasw: |
173 | if (try_build_dpas()) return; |
174 | break; |
175 | case fma_kind_t::mad: |
176 | if (try_build_mad()) return; |
177 | break; |
178 | default: ir_error_not_expected() << "Unknown FMA kind."; |
179 | } |
180 | |
181 | ir_error_not_expected() |
182 | << "Can't decompose into multiplication instructions. A view: " |
183 | << a_view << ". B view: " << b_view; |
184 | } |
185 | |
186 | const stmt_t &stmt() const { return stmt_; } |
187 | |
188 | const layout_t &c_layout() const { return c_layout_; } |
189 | |
190 | bool do_transpose() const { return do_transpose_; } |
191 | |
192 | std::string str() const { |
193 | std::ostringstream oss; |
194 | oss << "A view: " << a_view_ << std::endl; |
195 | oss << "B view: " << b_view_ << std::endl; |
196 | oss << "C layout: " << c_layout_ << std::endl; |
197 | oss << "Statement: " << std::endl << stmt_; |
198 | return oss.str(); |
199 | } |
200 | |
201 | private: |
202 | struct loop_info_t { |
203 | loop_info_t() = default; |
204 | |
205 | loop_info_t(const expr_t &var, bmnk_kind_t bmnk_kind, int dim) |
206 | : var(var), bmnk_kind(bmnk_kind), dim(dim) {} |
207 | |
208 | expr_t var; |
209 | bmnk_kind_t bmnk_kind; |
210 | |
211 | int dim; |
212 | int a_idx = -1; |
213 | int b_idx = -1; |
214 | int c_idx = -1; |
215 | int block = 1; |
216 | }; |
217 | |
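// Tries to build the multiplication with dpas instructions: maps A/B to MNK
// layouts, picks rcount, and falls back to the transposed form
// (C^T = B^T * A^T) when the direct form doesn't match.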
218 | bool try_build_dpas() { |
219 | ir_assert(a_view_.can_convert_to_vlayout()) |
220 | << "Views are not supported with dpas/dpasw."; |
221 | ir_assert(b_view_.can_convert_to_vlayout()) |
222 | << "Views are not supported with dpas/dpasw."; |
223 | |
224 | auto a_layout = a_view_.create_vlayout(); |
225 | auto b_layout = b_view_.create_vlayout(); |
226 | |
227 | check_k_blocks_order(a_layout, b_layout); |
228 | |
229 | bmnk_block_mapper_t from_bmnk_mapper(bmnk_mapper_); |
230 | from_bmnk_mapper.push_blocks(abc_kind_t::a, a_layout.blocks()); |
231 | from_bmnk_mapper.push_blocks(abc_kind_t::b, b_layout.blocks()); |
232 | |
233 | // Convert to MNK layouts. |
234 | a_layout = bmnk_mapper_.map_to_bmnk( |
235 | abc_kind_t::a, {bmnk_kind_t::m, bmnk_kind_t::k}, a_layout); |
236 | b_layout = bmnk_mapper_.map_to_bmnk( |
237 | abc_kind_t::b, {bmnk_kind_t::k, bmnk_kind_t::n}, b_layout); |
238 | |
239 | multiply_desc_t desc(a_layout, b_layout, /*force_c_upconvert=*/true); |
240 | if (!dpas_t::matches_types( |
241 | hw_, desc.a_type(), desc.b_type(), desc.c_type())) |
242 | return false; |
243 | |
244 | int sdepth = 8; |
245 | int rcount = std::min(utils::rnd_up_pow2(desc.n()), 8); |
246 | auto _dpas = dpas_t::make(/*is_dpasw=*/false, simd_size_, sdepth, |
247 | rcount, desc.c_type(), desc.a_type(), desc.b_type()); |
248 | if (_dpas.as<dpas_t>().matches(desc)) { |
249 | build_dpas(from_bmnk_mapper, _dpas.as<dpas_t>(), desc); |
250 | return true; |
251 | } |
252 | |
253 | // Try to transpose and flip: C += A * B -> C^T = B^T * A^T. |
254 | rcount = std::min(desc.m(), 8); |
255 | desc = multiply_desc_t( |
256 | b_layout.transpose(), a_layout.transpose(), true); |
257 | _dpas = dpas_t::make(/*is_dpasw=*/false, /*exec_size=*/simd_size_, |
258 | sdepth, rcount, desc.c_type(), desc.a_type(), desc.b_type()); |
259 | |
260 | if (_dpas.as<dpas_t>().matches(desc)) { |
261 | do_transpose_ = true; |
262 | build_dpas(from_bmnk_mapper, _dpas.as<dpas_t>(), desc); |
263 | return true; |
264 | } |
265 | return false; |
266 | } |
267 | |
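// Asserts that the K blocks of A and B appear in the same relative order in
// both layouts.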
268 | void check_k_blocks_order(const layout_t &a, const layout_t &b) const { |
269 | object_map_t<expr_t, int> k_vars; |
270 | auto k_sub_layout = [&](abc_kind_t abc_kind, const layout_t &l) { |
271 | layout_t k_layout = layout_t(type_t::u8(), 0, |
272 | std::vector<dim_t>(layout_t::max_ndims, 1)); |
273 | for (auto &b : l.blocks()) { |
274 | auto bmnk_kind = bmnk_mapper_.bmnk_kind(abc_kind, b.dim_idx); |
275 | if (bmnk_kind != bmnk_kind_t::k) continue; |
276 | auto &var = bmnk_mapper_.var(abc_kind, b.dim_idx); |
277 | auto ret = k_vars.emplace(var, (int)k_vars.size()); |
278 | k_layout = k_layout.add_outer_block(ret.first->second, b.block); |
279 | } |
280 | return k_layout; |
281 | }; |
282 | auto a_k = k_sub_layout(abc_kind_t::a, a); |
283 | auto b_k = k_sub_layout(abc_kind_t::b, b); |
284 | ir_assert(a_k == b_k) |
285 | << "Order of K dimensions doesn't match in A and B. A layout: " |
286 | << a << ", B layout: " << b; |
287 | } |
288 | |
289 | void build_dpas(const bmnk_block_mapper_t &from_bmnk_mapper, |
290 | const dpas_t &dpas, const multiply_desc_t &desc) { |
291 | int m_blk = dpas.exec_size; |
292 | int n_blk = dpas.rcount; |
293 | int k_blk = dpas.sdepth * 4 / dpas.src1_type.size(); |
294 | |
295 | c_layout_ = compute_dpas_c_layout(m_blk, n_blk, dpas.c_layout(), desc); |
296 | |
297 | expr_t a_buf = a_buf_; |
298 | expr_t b_buf = b_buf_; |
299 | if (do_transpose_) std::swap(a_buf, b_buf); |
300 | auto dpas_tail = dpas_t::make(/*is_dpasw=*/false, dpas.exec_size, |
301 | dpas.sdepth, desc.n() > n_blk ? desc.n() % n_blk : n_blk, |
302 | dpas.dst_type, dpas.src1_type, dpas.src2_type); |
303 | |
304 | for (int i_k = 0; i_k < desc.k(); i_k += k_blk) { |
305 | for (int i_m = 0; i_m < desc.m(); i_m += m_blk) { |
306 | for (int i_n = 0; i_n < desc.n(); i_n += n_blk) { |
307 | std::vector<int> a_args = {i_m, i_k}; |
308 | std::vector<int> b_args = {i_k, i_n}; |
309 | std::vector<int> c_args = {i_m, i_n}; |
310 | auto a = a_buf[desc.a_layout()(a_args) |
311 | * desc.a_type().size()]; |
312 | auto b = b_buf[desc.b_layout()(b_args) |
313 | * desc.b_type().size()]; |
314 | auto c = c_buf_[c_layout_(c_args) * desc.c_type().size()]; |
315 | auto &_dpas = (i_n + n_blk > desc.n()) |
316 | ? dpas_tail.as<dpas_t>() |
317 | : dpas; |
318 | stmt_ = stmt_.append(_dpas(c, c, a, b)); |
319 | } |
320 | } |
321 | } |
322 | |
323 | // Transpose C layout back if needed. |
324 | if (do_transpose_) c_layout_ = c_layout_.transpose(); |
325 | |
326 | // Convert C layout back to problem notation. |
327 | c_layout_ = from_bmnk_mapper.map_from_bmnk( |
328 | abc_kind_t::c, {bmnk_kind_t::m, bmnk_kind_t::n}, c_layout_); |
329 | } |
330 | |
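// Derives the C register layout produced by the dpas sequence from the dpas
// block layout and the multiply descriptor.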
331 | static layout_t compute_dpas_c_layout(int m_blk, int n_blk, |
332 | const layout_t &blk_layout, const multiply_desc_t &desc) { |
333 | auto c_layout = blk_layout; |
334 | auto new_blocks = c_layout.blocks(); |
335 | if (new_blocks.size() > 1) new_blocks[1].block = desc.n(); |
336 | c_layout = layout_t(c_layout.type(), c_layout.ndims(), |
337 | c_layout.offset(), new_blocks); |
338 | c_layout = c_layout.add_outer_block(0, desc.m() / m_blk); |
339 | return c_layout; |
340 | } |
341 | |
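// Tries to build the multiplication with mad instructions, blocking the
// innermost computation either by the N or by the B dimension.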
342 | bool try_build_mad() { |
343 | auto loops = create_loop_nest(); |
344 | |
345 | if (try_build_mad_kmn_block_by_n(loops)) return true; |
346 | if (try_build_mad_kmn_block_by_b(loops)) return true; |
347 | |
348 | return false; |
349 | } |
350 | |
351 | std::vector<loop_info_t> create_loop_nest() const { |
352 | object_map_t<expr_t, loop_info_t> loops; |
353 | for (auto *view : {&a_view_, &b_view_}) { |
354 | abc_kind_t abc_kind |
355 | = (view == &a_view_ ? abc_kind_t::a : abc_kind_t::b); |
356 | for (int i = 0; i < view->nvdims(); i++) { |
357 | auto &var = bmnk_mapper_.var(abc_kind, i); |
358 | int dim = int(view->vdims()[i]); |
359 | if (dim == 1) continue; |
360 | |
361 | if (loops.count(var) > 0) continue; |
362 | loops[var] = loop_info_t(var, bmnk_mapper_.bmnk_kind(var), dim); |
363 | } |
364 | } |
365 | |
366 | std::vector<loop_info_t> ret; |
367 | for (auto &kv : loops) { |
368 | auto &loop = kv.second; |
369 | loop.a_idx = bmnk_mapper_.dim_idx(abc_kind_t::a, loop.var); |
370 | loop.b_idx = bmnk_mapper_.dim_idx(abc_kind_t::b, loop.var); |
371 | loop.c_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, loop.var); |
372 | ret.push_back(kv.second); |
373 | } |
374 | return ret; |
375 | } |
376 | |
377 | // Order of loops: BKMN, block by N. |
378 | bool try_build_mad_kmn_block_by_n(std::vector<loop_info_t> &_loops) { |
379 | return try_build_mad_impl(_loops, |
380 | {bmnk_kind_t::b, bmnk_kind_t::k, bmnk_kind_t::m, |
381 | bmnk_kind_t::n}, |
382 | bmnk_kind_t::n); |
383 | } |
384 | |
385 | // Order of loops: BKMN, block by B. |
386 | bool try_build_mad_kmn_block_by_b(std::vector<loop_info_t> &_loops) { |
387 | return try_build_mad_impl(_loops, |
388 | {bmnk_kind_t::b, bmnk_kind_t::k, bmnk_kind_t::m, |
389 | bmnk_kind_t::n}, |
390 | bmnk_kind_t::b); |
391 | } |
392 | |
393 | bool try_build_mad_impl(std::vector<loop_info_t> &_loops, |
394 | const std::vector<bmnk_kind_t> &loop_order, |
395 | bmnk_kind_t block_bmnk_kind) { |
396 | auto loops = _loops; |
397 | int nloops = int(loops.size()); |
398 | std::sort(loops.begin(), loops.end(), |
399 | [&](const loop_info_t &a, const loop_info_t &b) { |
400 | int a_key = ir_utils::find_index(loop_order, a.bmnk_kind); |
401 | int b_key = ir_utils::find_index(loop_order, b.bmnk_kind); |
402 | ir_assert(a_key != -1); |
403 | ir_assert(b_key != -1); |
404 | return a_key < b_key; |
405 | }); |
406 | |
407 | int block_idx = -1; |
408 | for (int i = 0; i < nloops; i++) { |
409 | auto &l = loops[i]; |
410 | if (l.bmnk_kind == block_bmnk_kind) { |
411 | ir_assert(block_idx == -1) << "Can't block 2+ dimensions."; |
412 | block_idx = i; |
413 | } |
414 | } |
415 | |
416 | // Couldn't find the dimension to block by, try a different blocking scheme. |
417 | if (block_idx == -1) return false; |
418 | |
419 | auto &block_loop = loops[block_idx]; |
420 | |
421 | int block = simd_size_; |
422 | while (block >= 1) { |
423 | if (block_loop.dim % block == 0) break; |
424 | block /= 2; |
425 | } |
426 | |
427 | ir_assert(block >= 1) << "Invalid block size."; |
428 | block_loop.block = block; |
429 | |
430 | int a_stride = 0; |
431 | int b_stride = 0; |
432 | |
433 | // Ensure that A tile is dense. |
434 | if (block_loop.a_idx != -1) { |
435 | std::vector<dim_t> tile_dims(a_view_.nvdims(), 1); |
436 | tile_dims[block_loop.a_idx] = block; |
437 | auto layout = a_view_.create_pseudo_vlayout(); |
438 | auto tile = layout.map(tensor_t(tile_dims)); |
439 | if (!is_1d_strided(tile)) return false; |
440 | a_stride = tile.blocks()[0].stride; |
441 | } |
442 | |
443 | // Ensure that B tile is dense. |
444 | if (block_loop.b_idx != -1) { |
445 | std::vector<dim_t> tile_dims(b_view_.nvdims(), 1); |
446 | tile_dims[block_loop.b_idx] = block; |
447 | auto layout = b_view_.create_pseudo_vlayout(); |
448 | auto tile = layout.map(tensor_t(tile_dims)); |
449 | if (!is_1d_strided(tile)) return false; |
450 | b_stride = tile.blocks()[0].stride; |
451 | } |
452 | |
453 | build_mad(loops, block_loop, a_stride, b_stride); |
454 | return true; |
455 | } |
456 | |
457 | static bool is_1d_strided(const layout_t &layout) { |
458 | auto &blocks = layout.blocks(); |
459 | if (blocks.size() > 1) return false; |
460 | return true; |
461 | } |
462 | |
463 | void build_mad(const std::vector<loop_info_t> &loops, |
464 | const loop_info_t &block_loop, int a_stride, int b_stride) { |
465 | ir_assert(utils::one_of( |
466 | block_loop.bmnk_kind, bmnk_kind_t::b, bmnk_kind_t::n)) |
467 | << "Unsupported blocking (expected blocking by B or N)."; |
468 | |
469 | auto &a_type = a_view_.type(); |
470 | auto &b_type = b_view_.type(); |
471 | auto c_type = multiply_desc_t::get_c_type(a_type, b_type, |
472 | /*force_c_upconvert=*/false); |
473 | |
474 | int block = block_loop.block; |
475 | auto _mad = mad_t::make( |
476 | hw_, c_type, block, a_type, a_stride, b_type, b_stride); |
477 | auto &mad = _mad.as<mad_t>(); |
478 | |
479 | c_layout_ = compute_mad_c_layout(c_type, loops, block_loop); |
480 | |
481 | int nloops = int(loops.size()); |
482 | std::vector<int> bounds(loops.size()); |
483 | for (int i = 0; i < nloops; i++) { |
484 | bounds[i] = loops[i].dim / loops[i].block; |
485 | } |
486 | std::vector<int> a_idx(a_view_.nvdims()); |
487 | std::vector<int> b_idx(b_view_.nvdims()); |
488 | std::vector<int> c_idx(c_layout_.ndims()); |
489 | ir_utils::for_each(bounds, [&](const std::vector<int> &idx) { |
490 | for (int i = 0; i < nloops; i++) { |
491 | int full_idx = idx[i] * loops[i].block; |
492 | auto &loop = loops[i]; |
493 | if (loop.a_idx != -1) a_idx[loop.a_idx] = full_idx; |
494 | if (loop.b_idx != -1) b_idx[loop.b_idx] = full_idx; |
495 | if (loop.c_idx != -1) c_idx[loop.c_idx] = full_idx; |
496 | } |
497 | int a_off = a_view_(a_idx) * a_type.size(); |
498 | int b_off = b_view_(b_idx) * b_type.size(); |
499 | int c_off = c_layout_(c_idx) * c_type.size(); |
500 | stmt_ = stmt_.append(mad(c_buf_[c_off], c_buf_[c_off], |
501 | a_buf_[a_off], b_buf_[b_off])); |
502 | }); |
503 | } |
504 | |
505 | layout_t compute_mad_c_layout(const type_t &c_type, |
506 | const std::vector<loop_info_t> &loops, |
507 | const loop_info_t &block_loop) const { |
508 | layout_t c_layout(c_type, bmnk_mapper_.ndims(abc_kind_t::c), 0, {}); |
509 | |
510 | int c_dim_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, block_loop.var); |
511 | c_layout = c_layout.add_outer_block(c_dim_idx, block_loop.block); |
512 | |
513 | for (size_t i = 0; i < loops.size(); i++) { |
514 | if (loops[i].bmnk_kind == bmnk_kind_t::k) continue; |
515 | int dim_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, loops[i].var); |
516 | int bound = loops[i].dim / loops[i].block; |
517 | c_layout = c_layout.add_outer_block(dim_idx, bound); |
518 | } |
519 | return c_layout; |
520 | } |
521 | |
522 | ngen::HW hw_; |
523 | int simd_size_; |
524 | bmnk_mapper_t bmnk_mapper_; |
525 | |
526 | bool do_transpose_ = false; |
527 | |
528 | view_t a_view_; |
529 | view_t b_view_; |
530 | layout_t c_layout_; |
531 | |
532 | expr_t a_buf_; |
533 | expr_t b_buf_; |
534 | expr_t c_buf_; |
535 | |
536 | stmt_t stmt_; |
537 | }; |
538 | |
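// Converts A/B layouts into layouts that are friendly to the chosen FMA kind:
// type promotion/densification for mad, dpas-compatible blocking for
// dpas/dpasw (subject to the allow_a/b_grf_reorder flags).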
539 | class fma_helper_t { |
540 | public: |
541 | fma_helper_t(int simd_size, fma_kind_t fma_kind, const type_t &a_type, |
542 | const type_t &b_type, bool allow_a_grf_reorder, |
543 | bool allow_b_grf_reorder, bool is_src1_broadcast) |
544 | : simd_size_(simd_size) |
545 | , fma_kind_(fma_kind) |
546 | , a_type_(a_type) |
547 | , b_type_(b_type) |
548 | , allow_a_grf_reorder_(allow_a_grf_reorder) |
549 | , allow_b_grf_reorder_(allow_b_grf_reorder) |
550 | , is_src1_broadcast_(is_src1_broadcast) {} |
551 | |
552 | fma_kind_t fma_kind() const { return fma_kind_; } |
553 | |
554 | layout_t convert_to_fma_friendly_layout(const layout_t &layout, |
555 | abc_kind_t abc_kind, bool is_slm, const bmnk_mapper_t &bmnk_mapper, |
556 | bool *changed = nullptr) const { |
557 | bool allow_grf_reorder |
558 | = (abc_kind == abc_kind_t::a ? allow_a_grf_reorder_ |
559 | : allow_b_grf_reorder_); |
560 | if (changed) *changed = false; |
561 | if (!allow_grf_reorder) return layout; |
562 | |
563 | // GRF reorder is only supported with dpas/dpasw. |
564 | if (fma_kind_ == fma_kind_t::mad) { |
565 | if (is_slm) return layout; |
566 | // mad may require type conversion, supported for GRF layouts only. |
567 | return convert_to_fma_friendly_type(layout, abc_kind, changed); |
568 | } |
569 | |
570 | std::vector<bmnk_kind_t> bmnk_kinds; |
571 | if (abc_kind == abc_kind_t::a) { |
572 | bmnk_kinds.push_back(bmnk_kind_t::m); |
573 | bmnk_kinds.push_back(bmnk_kind_t::k); |
574 | } else { |
575 | bmnk_kinds.push_back(bmnk_kind_t::k); |
576 | bmnk_kinds.push_back(bmnk_kind_t::n); |
577 | } |
578 | |
579 | auto bmnk_layout |
580 | = bmnk_mapper.map_to_bmnk(abc_kind, bmnk_kinds, layout); |
581 | |
582 | auto dpas_layout = get_dpas_friendly_layout(bmnk_layout, abc_kind); |
583 | if (dpas_layout == bmnk_layout) return layout; |
584 | |
585 | if (changed) *changed = true; |
586 | |
587 | bmnk_block_mapper_t from_bmnk_mapper(bmnk_mapper); |
588 | from_bmnk_mapper.push_blocks(abc_kind, layout.blocks()); |
589 | |
590 | auto fma_layout = from_bmnk_mapper.map_from_bmnk( |
591 | abc_kind, bmnk_kinds, dpas_layout); |
592 | fma_layout = fma_layout.make_dense(); |
593 | return fma_layout; |
594 | } |
595 | |
596 | private: |
597 | layout_t convert_to_fma_friendly_type(const layout_t &layout, |
598 | abc_kind_t abc_kind, bool *changed = nullptr) const { |
599 | if (changed) *changed = false; |
600 | if (fma_kind_ != fma_kind_t::mad) return layout; |
601 | |
602 | // mad with s8/u8 is not supported, promote to strided s16. |
603 | if (a_type_.is_x8() && b_type_.is_x8()) { |
604 | if (changed) *changed = true; |
605 | return layout.retype(type_t::s16()).make_strided(2); |
606 | } |
607 | |
608 | // bf16 mixed mode mad requires src2 to be f32. |
609 | if (abc_kind == abc_kind_t::b && a_type_.is_bf16()) { |
610 | if (changed) *changed = true; |
611 | return layout.retype(type_t::f32()).make_dense(); |
612 | } |
613 | |
614 | // bf16 mixed mode mad requires src1 to be packed; when src1 is |
615 | // broadcast, it needs to be converted to f32. |
616 | if (abc_kind == abc_kind_t::a && a_type_.is_bf16() |
617 | && is_src1_broadcast_) { |
618 | if (changed) *changed = true; |
619 | return layout.retype(type_t::f32()).make_dense(); |
620 | } |
621 | |
622 | // Ensure the layout is dense to align regioning. |
623 | if (!layout.is_dense()) { |
624 | if (changed) *changed = true; |
625 | return layout.make_dense(); |
626 | } |
627 | |
628 | return layout; |
629 | } |
630 | |
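// Returns a dpas-friendly layout for the given BMNK layout of A or B.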
631 | layout_t get_dpas_friendly_layout( |
632 | const layout_t &bmnk_layout, abc_kind_t abc_kind) const { |
633 | bool is_a = (abc_kind == abc_kind_t::a); |
634 | int mn_idx = (is_a ? 0 : 1); |
635 | int k_idx = (is_a ? 1 : 0); |
636 | |
637 | layout_t dpas_layout; |
638 | dim_t mn_blk = bmnk_layout.dim(mn_idx); |
639 | dim_t k_blk = bmnk_layout.dim(k_idx); |
640 | |
641 | std::vector<int> try_rcount; |
642 | if (is_a && mn_blk % 8 == 0) try_rcount.push_back(8); |
643 | // Cannot calculate the correct rcount when !is_a, but rcount is |
644 | // effectively ignored in that case as rcount mainly affects b_layout. |
645 | // Also note that rcount used here may not be supported in hardware and |
646 | // is used solely to compute layout. |
647 | try_rcount.push_back(is_a ? mn_blk : 8); |
648 | for (int rcount : try_rcount) { |
649 | auto _dpas = dpas_t::make(/*is_dpasw=*/false, simd_size_, |
650 | /*sdepth=*/8, rcount, type_t::undef(), b_type_, a_type_); |
651 | auto &dpas = _dpas.as<dpas_t>(); |
652 | |
653 | dpas_layout = (is_a ? dpas.b_layout() : dpas.a_layout()); |
654 | dpas_layout = dpas_layout.transpose(); |
655 | |
656 | auto default_layout = bmnk_layout.retype(is_a ? a_type_ : b_type_); |
657 | if (dpas_layout <= default_layout) return default_layout; |
658 | } |
659 | |
660 | dim_t dpas_mn_blk = dpas_layout.dim(mn_idx); |
661 | dim_t dpas_k_blk = dpas_layout.dim(k_idx); |
662 | ir_assert(k_blk % dpas_k_blk == 0); |
663 | |
664 | dim_t k_outer = ir_utils::safe_divide(k_blk, dpas_k_blk); |
665 | dim_t mn_outer = ir_utils::safe_divide(mn_blk, dpas_mn_blk); |
666 | dpas_layout = dpas_layout.add_outer_block(k_idx, k_outer); |
667 | dpas_layout = dpas_layout.add_outer_block(mn_idx, mn_outer); |
668 | return dpas_layout; |
669 | } |
670 | |
671 | int simd_size_; |
672 | fma_kind_t fma_kind_; |
673 | type_t a_type_; |
674 | type_t b_type_; |
675 | bool allow_a_grf_reorder_; |
676 | bool allow_b_grf_reorder_; |
677 | bool is_src1_broadcast_; |
678 | }; |
679 | |
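// Holds the buffers, views and layouts used to reduce B across K and creates
// the corresponding reduction and store (plain or atomic) statements.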
680 | class b_reduce_context_t { |
681 | public: |
682 | b_reduce_context_t(ir_context_t &ir_ctx, const conv_config_t &cfg) |
683 | : ir_ctx_(ir_ctx), cfg_(cfg), reduce_condition_(true) { |
684 | if (cfg_.reduce_b()) b_reduced_reg_buf_ = make_buffer("b_reduced"); |
685 | } |
686 | |
687 | // Setters for B reduced memory buffer/view. |
688 | void set_b_reduced_mem_buf(const expr_t &buf) { b_reduced_mem_buf_ = buf; } |
689 | void set_b_reduced_view(const view_t &v) { b_reduced_view_ = v; } |
690 | |
691 | // Sets the condition to update the reduced B output. Reduction is done across |
692 | // K for B (KxN tensor), so the M dimension should be checked before the update. |
693 | void set_reduce_condition(const expr_t &cond) { reduce_condition_ = cond; } |
694 | |
695 | // Global memory buffer. |
696 | const expr_t &b_reduced_mem_buf() const { return b_reduced_mem_buf_; } |
697 | |
698 | // Register buffer. |
699 | const expr_t &b_reduced_reg_buf() const { return b_reduced_reg_buf_; } |
700 | int b_reduced_size() const { return b_reduced_size_; } |
701 | |
702 | // Memory view. |
703 | const view_t &b_reduced_thr_view() const { return b_reduced_thr_view_; } |
704 | |
705 | // Register layout. |
706 | const layout_t &b_reduced_reg_layout() const { |
707 | return b_reduced_reg_layout_; |
708 | } |
709 | |
710 | void init_reduced_thr_view( |
711 | const tensor_t &b_thr_tile, const expr_t &cond = expr_t()) { |
712 | ir_assert(b_reduced_thr_view_.is_empty()) << "Can't initialize twice."; |
713 | |
714 | auto b_reduced_thr_tile = b_to_b_reduced_tile(b_thr_tile); |
715 | b_reduced_thr_view_ |
716 | = b_reduced_view_.create_sub_view(b_reduced_thr_tile); |
717 | b_reduced_reg_layout_ = b_reduced_thr_view_.create_dense_vlayout(); |
718 | b_reduced_size_ = b_reduced_reg_layout_.size(); |
719 | b_reduced_size_ = utils::rnd_up(b_reduced_size_, cfg_.grf_size()); |
720 | |
721 | if (!cond.is_empty()) reduce_condition_ &= cond; |
722 | } |
723 | |
724 | stmt_t create_reduce_stmt(const layout_t &b_layout, const expr_t &b_buf, |
725 | const tensor_t &subtile = tensor_t()) { |
726 | auto reduction_stmt |
727 | = jit::create_reduce_stmt(b_layout, b_reduced_reg_layout_, |
728 | b_buf, b_reduced_reg_buf_, subtile, reduction_mask_); |
729 | return reduction_stmt; |
730 | } |
731 | |
732 | stmt_t create_store_stmt(bool use_atomic) const { |
733 | auto r2g = make_access_builder(ir_ctx_, b_reduced_thr_view_, |
734 | b_reduced_mem_buf_, b_reduced_reg_buf_, |
735 | use_atomic ? send_op_t::atomic_fadd : send_op_t::store, |
736 | send_address_t::a64); |
737 | // TODO: Check that layouts match. |
738 | auto ret = r2g.stmt(); |
739 | if (!reduce_condition_.is_empty()) { |
740 | ret = if_t::make(reduce_condition_, ret); |
741 | } |
742 | return ret; |
743 | } |
744 | |
745 | private: |
746 | tensor_t b_to_b_reduced_tile(const tensor_t &b_tile) const { |
747 | std::vector<dim_t> dims; |
748 | std::vector<expr_t> start; |
749 | for (int i = 0; i < b_tile.ndims(); i++) { |
750 | if ((reduction_mask_ & (1 << i)) != 0) { |
751 | dims.push_back(b_tile(i)); |
752 | start.push_back(b_tile.start(i)); |
753 | } |
754 | } |
755 | return tensor_t(dims, start); |
756 | } |
757 | |
758 | ir_context_t &ir_ctx_; |
759 | const conv_config_t &cfg_; |
760 | |
761 | expr_t reduce_condition_; |
762 | |
763 | expr_t b_reduced_mem_buf_; |
764 | expr_t b_reduced_reg_buf_; |
765 | |
766 | view_t b_reduced_view_; |
767 | view_t b_reduced_thr_view_; |
768 | |
769 | layout_t b_reduced_reg_layout_; |
770 | int b_reduced_size_ = 0; |
771 | |
772 | uint32_t reduction_mask_ = (1 << 1) | (1 << 2); |
773 | }; |
774 | |
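// Describes one A or B subtile: loads it from global memory or SLM into
// registers and, when needed, reorders the loaded data into an FMA-friendly
// layout through a temporary buffer.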
775 | class subtile_info_t { |
776 | public: |
777 | using post_load_func_t = std::function<stmt_t( |
778 | const layout_t &, const expr_t &, const tensor_t &)>; |
779 | |
780 | subtile_info_t(ir_context_t &ir_ctx, const gemm_schedule_t &gemm_schedule, |
781 | const fma_helper_t &fma_helper, abc_kind_t abc_kind, bool use_slm, |
782 | bool load_buffered, bool allow_2d_load, int idx, |
783 | const view_t &mem_view, const tensor_t &subtile, |
784 | const expr_t &mem_buf, const expr_t &slm_buf, const expr_t ®_buf, |
785 | const expr_t &tmp_buf) |
786 | : ir_ctx_(ir_ctx) |
787 | , gemm_schedule_(gemm_schedule) |
788 | , fma_helper_(fma_helper) |
789 | , abc_kind_(abc_kind) |
790 | , use_slm_(use_slm) |
791 | , load_buffered_(load_buffered) |
792 | , allow_2d_load_(allow_2d_load) |
793 | , idx_(idx) |
794 | , mem_view_(mem_view) |
795 | , subtile_(subtile) |
796 | , mem_buf_(mem_buf) |
797 | , slm_buf_(slm_buf) |
798 | , reg_buf_(reg_buf) |
799 | , tmp_buf_(tmp_buf) {} |
800 | |
801 | bool is_loaded() const { return is_loaded_; } |
802 | |
803 | void set_loaded() { is_loaded_ = true; } |
804 | |
805 | const view_t ®_view() const { return reg_view_; } |
806 | |
807 | int reg_buf_size() const { |
808 | return utils::rnd_up(reg_layout_.size(), ir_ctx_.grf_size()); |
809 | } |
810 | |
811 | int tmp_buf_size() const { return tmp_buf_size_; } |
812 | |
813 | const stmt_t &s2r_load() const { return s2r_load_; } |
814 | |
815 | const stmt_t &g2r_load() const { return g2r_load_; } |
816 | |
817 | const send_hint_t &send_hint() const { return send_hint_; } |
818 | |
819 | void load(const post_load_func_t &post_load = post_load_func_t()) { |
820 | auto &bmnk_mapper = gemm_schedule_.bmnk_mapper(); |
821 | |
822 | layout_t load_layout; |
823 | stmt_t &stmt = (use_slm_ ? s2r_load_ : g2r_load_); |
824 | load_impl(ir_ctx_, load_layout, reg_view_, send_hint_, stmt); |
825 | |
826 | if (post_load) { |
827 | stmt = stmt.append(post_load(load_layout, reg_buf_, subtile_)); |
828 | } |
829 | |
830 | reg_layout_ = load_layout; |
831 | |
832 | bool changed; |
833 | auto fma_layout = fma_helper_.convert_to_fma_friendly_layout( |
834 | reg_layout_, abc_kind_, |
835 | /*is_slm=*/false, bmnk_mapper, &changed); |
836 | |
837 | if (changed) { |
838 | bool is_reorder_nop |
839 | = fma_layout.retype(reg_layout_.type()) == reg_layout_ |
840 | && reg_layout_.type().is_bitwise_compatible( |
841 | fma_layout.type()); |
842 | |
843 | if (fma_layout.type() != reg_layout_.type()) { |
844 | reg_view_ = reg_view_.retype(fma_layout.type()); |
845 | } |
846 | reg_layout_ = fma_layout; |
847 | reg_view_.set_tlayout(reg_layout_); |
848 | if (!is_reorder_nop) { |
849 | stmt = substitute(stmt, reg_buf_, tmp_buf_); |
850 | stmt = stmt.append(create_reorder_stmt( |
851 | load_layout, reg_layout_, tmp_buf_, reg_buf_)); |
852 | int load_reg_size = int(load_layout.size()); |
853 | load_reg_size |
854 | = utils::rnd_up(load_reg_size, ir_ctx_.grf_size()); |
855 | tmp_buf_size_ = std::max(tmp_buf_size_, load_reg_size); |
856 | } |
857 | } |
858 | } |
859 | |
860 | private: |
861 | void load_impl(ir_context_t &ir_ctx, layout_t &load_layout, |
862 | view_t &load_view, send_hint_t &send_hint, stmt_t &stmt) const { |
863 | view_t mem_view = mem_view_; |
864 | if (load_buffered_) |
865 | mem_view_.try_create_buffer_view(mem_view, load_view); |
866 | |
867 | send_op_t send_op = send_op_t::load; |
868 | send_hint = get_send_hint(ir_ctx_.exec_cfg(), send_op_t::load, |
869 | fma_helper_.fma_kind(), abc_kind_, mem_view, gemm_schedule_, |
870 | allow_2d_load_); |
871 | auto read = make_access_builder(ir_ctx, mem_view, |
872 | use_slm_ ? slm_buf_ : mem_buf_, reg_buf_, send_op, |
873 | use_slm_ ? send_address_t::slm : send_address_t::a64, |
874 | send_hint); |
875 | ir_trace() << (abc_kind_ == abc_kind_t::a ? "A" : "B") |
876 | << " GMEM/SLM to GRF load #" << idx_ << ":\n" |
877 | << read.str() << std::endl; |
878 | |
879 | load_layout = read.reg_layout(); |
880 | if (!load_view.is_empty()) { |
881 | load_view.set_tlayout(load_layout); |
882 | } else { |
883 | load_view = view_t(load_layout); |
884 | } |
885 | stmt = read.stmt(); |
886 | } |
887 | |
888 | ir_context_t &ir_ctx_; |
889 | const gemm_schedule_t &gemm_schedule_; |
890 | const fma_helper_t &fma_helper_; |
891 | abc_kind_t abc_kind_; |
892 | bool use_slm_; |
893 | bool load_buffered_; |
894 | bool allow_2d_load_; |
895 | int idx_; |
896 | view_t mem_view_; |
897 | tensor_t subtile_; |
898 | |
899 | expr_t mem_buf_; |
900 | expr_t slm_buf_; |
901 | expr_t reg_buf_; |
902 | expr_t tmp_buf_; |
903 | |
904 | bool is_loaded_ = false; |
905 | view_t reg_view_; |
906 | layout_t reg_layout_; |
907 | int tmp_buf_size_ = 0; |
908 | stmt_t s2r_load_; |
909 | stmt_t g2r_load_; |
910 | send_hint_t send_hint_; |
911 | }; |
912 | |
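// Builds the load-multiply part of the kernel: splits the per-thread A/B
// views into subtiles, emits their loads, multiplies every (i, j) pair of
// subtiles and assembles the final C register layout.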
913 | class load_multiply_builder_t { |
914 | public: |
915 | load_multiply_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx, |
916 | const gemm_schedule_t &gemm_schedule, |
917 | const fma_helper_t &fma_helper, b_reduce_context_t &b_reduce_ctx, |
918 | const expr_t &ap_buf, const expr_t &a_slm_buf, const expr_t &bp_buf, |
919 | const expr_t &b_slm_buf, const view_t &ap_x_view, |
920 | const view_t &bp_x_view, const kernel_info_t &kernel_info) |
921 | : prb_(cfg.prb()) |
922 | , cfg_(cfg) |
923 | , ir_ctx_(ir_ctx) |
924 | , gemm_schedule_(gemm_schedule) |
925 | , fma_helper_(fma_helper) |
926 | , b_reduce_ctx_(b_reduce_ctx) |
927 | , ap_buf_(ap_buf) |
928 | , a_slm_buf_(a_slm_buf) |
929 | , bp_buf_(bp_buf) |
930 | , b_slm_buf_(b_slm_buf) |
931 | , kernel_info_(kernel_info) { |
932 | ir_assert(cfg_.subtiles().a() == 1 || cfg_.subtiles().b() == 1) |
933 | << "At most one tensor can be tiled."; |
934 | |
935 | ab_tmp_buf_ = make_buffer("ab_tmp"); |
936 | a_buf_ = make_buffer("a"); |
937 | b_buf_ = make_buffer("b"); |
938 | c_buf_ = make_buffer("c"); |
939 | |
940 | // Views to multiply by a thread. |
941 | a_thr_view_ = ap_x_view.create_sub_view(gemm_schedule_.a_thr_tile()); |
942 | b_thr_view_ = bp_x_view.create_sub_view(gemm_schedule_.b_thr_tile()); |
943 | |
944 | // Initialize view for reduced B. |
945 | if (cfg_.reduce_b() && !cfg_.slm().b()) { |
946 | b_reduce_ctx_.init_reduced_thr_view( |
947 | gemm_schedule_.b_thr_tile(/*is_relative=*/false)); |
948 | } |
949 | |
950 | // TODO: Specify loops over subtiles in the schedule, use unrolling. |
951 | // Sub-tile indices. |
952 | a_idx_ = ir_ctx_.create_tmp_var(type_t::s32(), "a_idx"); |
953 | b_idx_ = ir_ctx_.create_tmp_var(type_t::s32(), "b_idx"); |
954 | |
955 | // Sub-tile views. |
956 | a_i_view_ = create_subtile_view(abc_kind_t::a, a_thr_view_, |
957 | cfg_.subtiles().a(), a_idx_, bmnk_kind_t::m, &a_i_outer_blocks_, |
958 | a_i_tile_); |
959 | b_j_view_ = create_subtile_view(abc_kind_t::b, b_thr_view_, |
960 | cfg_.subtiles().b(), b_idx_, bmnk_kind_t::n, &b_j_outer_blocks_, |
961 | b_j_tile_); |
962 | |
963 | build(); |
964 | } |
965 | |
966 | const std::vector<stmt_t> &allocs() const { return allocs_; } |
967 | |
968 | const stmt_t &load_mul_stmt() const { return load_mul_stmt_; } |
969 | |
970 | const expr_t &c_buf() const { return c_buf_; } |
971 | |
972 | const layout_t &c_reg_layout() const { return c_reg_layout_; } |
973 | |
974 | private: |
975 | view_t create_subtile_view(abc_kind_t abc_kind, const view_t &thr_view, |
976 | int subtiles, const expr_t &idx, bmnk_kind_t bmnk_kind, |
977 | std::vector<block_t> *outer_blocks, tensor_t &subtile) const { |
978 | auto &bmnk_mapper = gemm_schedule_.bmnk_mapper(); |
979 | auto layout = thr_view.create_pseudo_vlayout(); |
980 | dim_t mn_dim = 1; |
981 | for (auto &b : layout.blocks()) { |
982 | auto b_bmnk_kind = bmnk_mapper.bmnk_kind(abc_kind, b.dim_idx); |
983 | if (b_bmnk_kind == bmnk_kind) mn_dim *= b.block; |
984 | } |
985 | |
986 | std::vector<dim_t> subtile_dims(thr_view.nvdims(), 1); |
987 | dim_t mn_subtile_dim = ir_utils::safe_divide(mn_dim, dim_t(subtiles)); |
988 | for (auto &b : layout.blocks()) { |
989 | auto b_bmnk_kind = bmnk_mapper.bmnk_kind(abc_kind, b.dim_idx); |
990 | if (b_bmnk_kind == bmnk_kind) { |
991 | if (mn_subtile_dim == 1) continue; |
992 | dim_t next_block; |
993 | if (mn_subtile_dim % b.block == 0) { |
994 | next_block = b.block; |
995 | } else { |
996 | ir_assert(b.block % mn_subtile_dim == 0); |
997 | next_block = mn_subtile_dim; |
998 | } |
999 | subtile_dims[b.dim_idx] *= next_block; |
1000 | mn_subtile_dim /= next_block; |
1001 | } else { |
1002 | subtile_dims[b.dim_idx] *= b.block; |
1003 | } |
1004 | } |
1005 | grid_info_t grid({subtiles}, {idx}); |
1006 | subtile = layout.split(tensor_t(subtile_dims), grid, outer_blocks); |
1007 | return thr_view.create_sub_view(subtile); |
1008 | } |
1009 | |
1010 | void build() { |
1011 | int max_iters = 2; |
1012 | bool load_ok = false; |
1013 | for (int iter = 0; iter < max_iters; iter++) { |
1014 | if (try_load_subtiles(/*allow_2d_load=*/iter == 0)) { |
1015 | load_ok = true; |
1016 | break; |
1017 | } |
1018 | } |
1019 | ir_assert(load_ok) << "Can't generate load statements for subtiles."; |
1020 | |
1021 | auto a_subtiles = cfg_.subtiles().a(); |
1022 | auto b_subtiles = cfg_.subtiles().b(); |
1023 | |
1024 | for (int i = 0; i < a_subtiles; i++) { |
1025 | for (int j = 0; j < b_subtiles; j++) { |
1026 | build_subtile(i, j); |
1027 | } |
1028 | } |
1029 | |
1030 | if (zp_buf_size_ > 0) |
1031 | register_buffer(zp_buf_, zp_buf_size_, alloc_kind_t::grf); |
1032 | if (zp_mask_size_ > 0) |
1033 | register_buffer(zp_mask_, zp_mask_size_, alloc_kind_t::grf); |
1034 | |
1035 | // Handle temporary buffer in case of GRF reorders. |
1036 | int tmp_buf_size = 0; |
1037 | for (int i = 0; i < a_subtiles; i++) |
1038 | tmp_buf_size |
1039 | = std::max(tmp_buf_size, a_subtiles_[i].tmp_buf_size()); |
1040 | for (int j = 0; j < b_subtiles; j++) |
1041 | tmp_buf_size |
1042 | = std::max(tmp_buf_size, b_subtiles_[j].tmp_buf_size()); |
1043 | if (tmp_buf_size > 0) |
1044 | register_buffer(ab_tmp_buf_, tmp_buf_size, alloc_kind_t::grf); |
1045 | |
1046 | // C layout in problem notation. |
1047 | auto c_layout = c_subtile_layout_; |
1048 | |
1049 | // Add outer blocks coming from A/B subtiles. |
1050 | auto &bmnk_mapper = gemm_schedule_.bmnk_mapper(); |
1051 | for (auto &b : a_i_outer_blocks_) { |
1052 | auto &var = bmnk_mapper.var(abc_kind_t::a, b.dim_idx); |
1053 | int c_dim_idx = bmnk_mapper.dim_idx(abc_kind_t::c, var); |
1054 | c_layout = c_layout.add_outer_block(c_dim_idx, b.block); |
1055 | } |
1056 | for (auto &b : b_j_outer_blocks_) { |
1057 | auto &var = bmnk_mapper.var(abc_kind_t::b, b.dim_idx); |
1058 | int c_dim_idx = bmnk_mapper.dim_idx(abc_kind_t::c, var); |
1059 | c_layout = c_layout.add_outer_block(c_dim_idx, b.block); |
1060 | } |
1061 | |
1062 | c_reg_layout_ = c_layout; |
1063 | } |
1064 | |
1065 | bool can_use_2d_load(const abc_kind_t &abc_kind, const view_t &view) const { |
1066 | bool is_blocked = view.tlayout().innermost_block_layout().elems() > 1; |
1067 | if (!is_blocked) return true; |
1068 | |
1069 | // In general we want to skip expensive logic to check requirements for |
1070 | // 2D block messages with block layouts as performance with 1D messages |
1071 | // is good enough. However there are a few cases (backward by weights |
1072 | // with dpas) when 2D block messages give a boost even for block layouts |
1073 | // due to VNNI/transpose features. |
1074 | if (prb_.is_bwd_w && cfg_.is_dp_fma()) { |
1075 | auto &bmnk_mapper = gemm_schedule_.bmnk_mapper(); |
1076 | auto &blocks = view.tlayout().blocks(); |
1077 | if (blocks.size() < 2) return false; |
1078 | int b1_dim_idx = blocks[1].dim_idx; |
1079 | return bmnk_mapper.bmnk_kind(abc_kind, b1_dim_idx) |
1080 | == bmnk_kind_t::k; |
1081 | } |
1082 | return false; |
1083 | } |
1084 | |
1085 | bool try_load_subtiles(bool allow_2d_load) { |
1086 | int a_subtiles = cfg_.subtiles().a(); |
1087 | int b_subtiles = cfg_.subtiles().b(); |
1088 | bool use_a_slm = cfg_.slm().a(); |
1089 | bool use_b_slm = cfg_.slm().b(); |
1090 | a_subtiles_.clear(); |
1091 | b_subtiles_.clear(); |
1092 | for (int i = 0; i < a_subtiles; i++) { |
1093 | auto view = a_i_view_.substitute(a_idx_, i); |
1094 | auto tile = a_i_tile_.substitute(a_idx_, i); |
1095 | // Using buffered view is enabled only when: |
1096 | // - Loading directly from global memory |
1097 | // - FMA kind is mad (dpas implementation is more strict and requires |
1098 | // layouts, not views) |
1099 | // - Loading A tensor (A - activations for FWD/BWD_D where we may have |
1100 | // overlapping when applying KW blocking) |
1101 | bool load_buffered = cfg_.ow_kw_grf_cache() && !use_a_slm |
1102 | && cfg_.fma_kind() == fma_kind_t::mad; |
1103 | a_subtiles_.emplace_back(ir_ctx_, gemm_schedule_, fma_helper_, |
1104 | abc_kind_t::a, use_a_slm, load_buffered, |
1105 | allow_2d_load && can_use_2d_load(abc_kind_t::a, a_i_view_), |
1106 | i, view, tile, ap_buf_, a_slm_buf_, a_buf_, ab_tmp_buf_); |
1107 | a_subtiles_.back().load(); |
1108 | } |
1109 | subtile_info_t::post_load_func_t b_post_load; |
1110 | if (!use_b_slm && cfg_.reduce_b()) { |
1111 | b_post_load = [&](const layout_t ®_layout, const expr_t ®_buf, |
1112 | const tensor_t &tile) { |
1113 | return b_reduce_ctx_.create_reduce_stmt( |
1114 | reg_layout, reg_buf, tile); |
1115 | }; |
1116 | } |
1117 | for (int j = 0; j < b_subtiles; j++) { |
1118 | auto view = b_j_view_.substitute(b_idx_, j); |
1119 | auto tile = b_j_tile_.substitute(b_idx_, j); |
1120 | b_subtiles_.emplace_back(ir_ctx_, gemm_schedule_, fma_helper_, |
1121 | abc_kind_t::b, use_b_slm, |
1122 | /*load_buffered=*/false, |
1123 | allow_2d_load && can_use_2d_load(abc_kind_t::b, b_j_view_), |
1124 | j, view, tile, bp_buf_, b_slm_buf_, b_buf_, ab_tmp_buf_); |
1125 | |
1126 | b_subtiles_.back().load(b_post_load); |
1127 | } |
1128 | |
1129 | // Validate subtile loads: when VNNI permutation is applied, both A and B |
1130 | // have to use the same pattern. |
1131 | int vnni_permute_factor |
1132 | = a_subtiles_[0].send_hint().hint_2d.vnni_permute_factor; |
1133 | for (int i = 1; i < a_subtiles; i++) { |
1134 | int f = a_subtiles_[i].send_hint().hint_2d.vnni_permute_factor; |
1135 | if (f != vnni_permute_factor) return false; |
1136 | } |
1137 | for (int j = 0; j < b_subtiles; j++) { |
1138 | int f = b_subtiles_[j].send_hint().hint_2d.vnni_permute_factor; |
1139 | if (f != vnni_permute_factor) return false; |
1140 | } |
1141 | return true; |
1142 | } |
1143 | |
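// Precomputes the masks used for source zero-point compensation: collects
// per-tile masks from the A thread view, detects the scalar/SIMD/constant
// special cases and stores the packed masks into the zp_mask buffer.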
1144 | class src_zp_mask_info_t { |
1145 | public: |
1146 | src_zp_mask_info_t() = delete; |
1147 | src_zp_mask_info_t(load_multiply_builder_t &lmb, int m_blk, int k_blk, |
1148 | int desc_m, int desc_n, int channels, int a_stride, bool is_mad, |
1149 | const view_t &a_view, expr_t &zp_mask, int &zp_mask_size) |
1150 | : lmb_(lmb) |
1151 | , is_const_(true) |
1152 | , is_simd_(true) |
1153 | , is_scalar_(false) |
1154 | , is_wide_(m_blk < 16) |
1155 | , zp_mask_(zp_mask) { |
1156 | const auto tile |
1157 | = lmb_.gemm_schedule_.a_thr_tile(/*is_relative=*/false); |
1158 | const auto a_thr_view |
1159 | = lmb_.gemm_schedule_.a_view().create_sub_view(tile); |
1160 | const auto ic_dim = (!is_mad) ? 2 : 1; |
1161 | ic_start_ = a_thr_view.vstart(ic_dim); |
1162 | |
1163 | // 0. Are the masks at all required? |
1164 | const auto &prb = lmb_.prb_; |
1165 | const auto dims = tile.dims()[3] * tile.dims()[4] * tile.dims()[5]; |
1166 | const auto is_scalar = !is_mad && (dims <= 1); |
1167 | |
1168 | const auto has_pad = (prb.pd + 1) * (prb.ph + 1) * (prb.pw + 1) > 1; |
1169 | const auto has_stride_bd |
1170 | = prb.is_bwd_d && (prb.sd * prb.sh * prb.sw > 1); |
1171 | |
1172 | // 1. Get the raw representation of the buffer's masks |
1173 | auto mask_tensor |
1174 | = a_thr_view.create_mask_tensor(lmb_.ir_ctx_.cset()); |
1175 | |
1176 | // 2. Collect the masks, transforming the dimensions as needed |
1177 | int channels_blk = std::min(channels, |
1178 | (int)a_thr_view.tlayout().normalize().blocks()[0].block); |
1179 | if (channels_blk > 32) channels_blk = 32; |
1180 | const auto c_blk = std::min(channels_blk, m_blk); |
1181 | auto a_tdims = a_view.tlayout().dims(); |
1182 | auto mask_blk = is_mad ? c_blk : channels_blk; |
1183 | size_ = ((prb.kd * prb.kh * prb.kw > 1) || has_pad || has_stride_bd) |
1184 | * ((!is_scalar) ? accumulate(a_tdims.begin(), a_tdims.end(), |
1185 | 1, std::multiplies<dim_t>()) |
1186 | / mask_blk |
1187 | : 1); |
1188 | if (size_ == 0) return; |
1189 | |
1190 | mask_tensor_t masks( |
1191 | layout_t(type_t::_bool(), 0, std::vector<dim_t> {size_})); |
1192 | std::vector<dim_t> a_dims(a_view.tlayout().ndims(), 1); |
1193 | a_dims[ic_dim] = mask_blk; |
1194 | int i = 0; |
1195 | a_view.tlayout().for_each_tile( |
1196 | tensor_t(a_dims), [&](const std::vector<dim_t> &start) { |
1197 | std::vector<dim_t> a_st(a_thr_view.nvdims(), 0); |
1198 | for (int idx = 0; idx < (int)start.size(); idx++) { |
1199 | auto tdim_to_vdims = [&](int idx) { |
1200 | const auto &tdim = a_view.tdim(idx); |
1201 | auto vidx0 = tdim.vidx(0); |
1202 | auto vidx1 = -1; |
1203 | int vdim1 = 0; |
1204 | ir_assert(tdim.nvargs() <= 2); |
1205 | for (int vdim0 = 0; |
1206 | vdim0 < a_thr_view.vdims()[vidx0]; |
1207 | vdim0++) { |
1208 | auto tdim_expr = substitute(tdim.expr(), |
1209 | a_view.vvars()[vidx0], |
1210 | to_expr(vdim0)); |
1211 | if (tdim.nvargs() == 2) { |
1212 | vidx1 = tdim.vidx(1); |
1213 | for (vdim1 = 0; vdim1 |
1214 | < a_thr_view.vdims()[vidx1]; |
1215 | vdim1++) { |
1216 | auto tdim_expr2 = substitute( |
1217 | tdim_expr, |
1218 | a_view.vvars()[vidx1], |
1219 | to_expr(vdim1)); |
1220 | if (to_cpp<dim_t>( |
1221 | simplify(tdim_expr2)) |
1222 | == start[idx]) { |
1223 | a_st[vidx1] = vdim1; |
1224 | a_st[vidx0] = vdim0; |
1225 | return; |
1226 | } |
1227 | } |
1228 | tdim_expr = substitute(tdim_expr, |
1229 | a_view.vvars()[vidx1], |
1230 | to_expr(0)); |
1231 | } |
1232 | if (to_cpp<dim_t>(simplify(tdim_expr)) |
1233 | == start[idx]) { |
1234 | a_st[vidx0] = vdim0; |
1235 | break; |
1236 | } |
1237 | } |
1238 | }; |
1239 | tdim_to_vdims(idx); |
1240 | } |
1241 | auto off = a_thr_view.create_pseudo_vlayout() |
1242 | .make_dense() |
1243 | .offset<dim_t>(a_st); |
1244 | if (i >= size_) return; |
1245 | masks.set_mask(i, mask_tensor.mask(off)); |
1246 | i++; |
1247 | }); |
1248 | |
1249 | // 3. Compute some basic properties of the masks just collected |
1250 | for (int n = 0; n < size_; n++) { |
1251 | auto *sh = masks.mask(n).as_ptr<shuffle_t>(); |
1252 | is_simd_ &= !sh || sh->is_broadcast(); |
1253 | is_const_ &= !!sh; |
1254 | for (int v = (sh) ? 0 : c_blk; v < c_blk; v++) |
1255 | is_const_ &= sh->vec[sh->idx[v]].is<bool_imm_t>(); |
1256 | } |
1257 | |
1258 | // 4. Scalarize if the masks permit, transform to shorts otherwise |
1259 | for (int n = 0; n < size_; n++) |
1260 | if (is_simd_) { |
1261 | object_map_t<expr_t, std::vector<expr_t>> vars; |
1262 | expr_scalarizer_t sc(c_blk, 0, vars); |
1263 | masks.set_mask(n, sc.mutate(masks.mask(n))); |
1264 | } else if (is_const_) { |
1265 | uint16_t mask = 0; |
1266 | auto &sh = masks.mask(n).as<shuffle_t>(); |
1267 | for (int v = c_blk; v; v--) |
1268 | mask = mask * 2 |
1269 | + sh.vec[sh.idx[v - 1]].as<bool_imm_t>().value; |
1270 | masks.set_mask(n, mask); |
1271 | } else { |
1272 | ir_error_not_expected() << "Non-SIMD non-constant masks!"; |
1273 | } |
1274 | |
1275 | // 5. Assume lack of masks if they all are true |
1276 | bool all_true = true; |
1277 | for (int n = 0; all_true && (n < size_); n++) |
1278 | all_true &= masks.mask(n).is_equal(expr_t(true)); |
1279 | if (all_true) { |
1280 | is_const_ = true; |
1281 | is_simd_ = true; |
1282 | size_ = 0; |
1283 | return; |
1284 | } |
1285 | is_scalar_ = is_scalar; |
1286 | |
1287 | // 6. zp_mask_ gets created by the caller; allocate var_mask_ |
1288 | var_mask_ = lmb_.ir_ctx_.create_tmp_var( |
1289 | (!is_bool()) ? type_t::s16() : type_t::_bool(16)); |
1290 | |
1291 | // 7. Vectorize everything for easier computation and emit the IR |
1292 | if (!is_scalar) { |
1293 | std::vector<expr_t> exprs; |
1294 | object_eq_map_t<expr_t, expr_t> vars; |
1295 | |
1296 | // Here we assume two important things: |
1297 | // - C has exactly one N block like 4c16f8c (where f is ow) |
1298 | // - The innermost block is by M and it matches the SIMD size |
1299 | |
1300 | std::vector<expr_t> proto(size_); |
1301 | if (is_wide_) { |
1302 | for (int n = 0; n < int(proto.size()); n++) { |
1303 | const auto r = (n / 2 / k_blk) * 2 * k_blk |
1304 | + (n / 2) % k_blk + (n & 1) * k_blk; |
1305 | proto[n] = masks.mask(r % size_); |
1306 | } |
1307 | } else { |
1308 | for (int n = 0; n < int(proto.size()); n++) |
1309 | proto[n] = masks.mask(n % size_); |
1310 | } |
1311 | for (; size_ >= m_blk * 2; size_ /= 2) { |
1312 | auto c = [](expr_t &a, expr_t &b) { return a.is_equal(b); }; |
1313 | auto half = proto.begin() + size_ / 2; |
1314 | if (!std::equal(proto.begin(), half, half, c)) break; |
1315 | } |
1316 | |
1317 | const auto blk |
1318 | = (size_ > m_blk) ? std::min(m_blk * 2, 16) : m_blk; |
1319 | for (int n = 0; n < size_; n += blk) { |
1320 | std::vector<expr_t> e(blk); |
1321 | for (int m = 0; m < blk; m++) { |
1322 | e[m] = proto[n + m % (size_ - n)]; |
1323 | } |
1324 | int ntrue = 0, nfalse = 0; |
1325 | for (int m = 0; m < blk; m++) { |
1326 | if (e[m].is<bool_imm_t>()) |
1327 | ((e[m].as<bool_imm_t>().value) ? ntrue : nfalse)++; |
1328 | } |
1329 | ir_assert((ntrue == 0) || (ntrue + nfalse == blk)); |
1330 | if ((ntrue == 0) && (nfalse > 0) && (nfalse < blk)) { |
1331 | auto nb = *std::find_if(e.begin(), e.end(), |
1332 | [](expr_t &x) { return !x.is<bool_imm_t>(); }); |
1333 | for (int m = 0; m < blk; m++) { |
1334 | e[m] = (e[m].is<bool_imm_t>()) |
1335 | ? (nb & expr_t(false)) |
1336 | : (e[m] & expr_t(true)); |
1337 | } |
1338 | } |
1339 | exprs.emplace_back(vector2expr(e, vars)); |
1340 | } |
1341 | |
1342 | const auto real_size = std::max( |
1343 | size_ * ((is_wide_) ? 8 : 1), int(exprs.size()) * blk); |
1344 | zp_mask_size = std::max(zp_mask_size, real_size * w_stride()); |
1345 | for (int i = 0; i < int(exprs.size()); i++) { |
1346 | auto expr = cast_t::make(w_type(blk), exprs[i]); |
1347 | stmt_ = stmt_.append( |
1348 | store_t::make(zp_mask_, i * blk * w_stride(), |
1349 | (is_simd_) ? -expr : expr, w_stride())); |
1350 | } |
1351 | if (is_wide_) { |
1352 | auto wide_scalar = [&](const expr_t &a, const expr_t &b) { |
1353 | std::vector<int> idx(16, 1); |
1354 | for (int i = 0; i < 8; i++) |
1355 | idx[i] = 0; |
1356 | return shuffle_t::make(std::vector<expr_t> {a, b}, idx); |
1357 | }; |
1358 | for (int i = size_ - 2; i > 0; i -= 2) { |
1359 | auto load_l = load_t::make( |
1360 | type_t::s16(), zp_mask_, i * w_stride()); |
1361 | auto load_h = load_t::make( |
1362 | type_t::s16(), zp_mask_, (i + 1) * w_stride()); |
1363 | auto load = wide_scalar(load_l, load_h); |
1364 | load = cast_t::make(type_t::s16(16), load); |
1365 | stmt_ = stmt_.append(store_t::make(zp_mask_, |
1366 | i * 8 * w_stride(), load, w_stride())); |
1367 | } |
1368 | if (size_ % 2 == 0) { |
1369 | auto l0h = load_t::make( |
1370 | type_t::s16(), zp_mask_, w_stride()); |
1371 | stmt_ = stmt_.append(store_t::make(zp_mask_, |
1372 | 8 * w_stride(), |
1373 | shuffle_t::make_broadcast(l0h, 8), w_stride())); |
1374 | } |
1375 | auto l0l = load_t::make(type_t::s16(), zp_mask_, 0); |
1376 | stmt_ = stmt_.append(store_t::make(zp_mask_, 0, |
1377 | shuffle_t::make_broadcast(l0l, 8), w_stride())); |
1378 | } |
1379 | |
1380 | for (auto &v : vars) |
1381 | stmt_ = let_t::make(v.second, v.first, stmt_); |
1382 | } else { // is_scalar == true |
1383 | zp_mask_size = std::max(zp_mask_size, type_t::s16().size()); |
1384 | auto expr = cast_t::make(type_t::s16(), masks.mask(0)); |
1385 | if (is_simd_) expr = cast(-expr, type_t::s16()); |
1386 | stmt_ = stmt_.append(store_t::make(zp_mask_, 0, expr)); |
1387 | } |
1388 | } |
1389 | |
1390 | const stmt_t &stmt() const { return stmt_; } |
1391 | expr_t ic_start() const { return ic_start_; } |
1392 | bool is_scalar() const { return is_scalar_; } |
1393 | bool is_simd() const { return is_simd_; } |
1394 | bool is_const() const { return is_const_; } |
1395 | bool is_bool() const { return !size_ || !is_simd() || is_scalar(); } |
1396 | |
1397 | expr_t gen_mask(int base) const { |
1398 | auto null_mask = (is_bool()) ? expr_t() : expr_t(-1); |
1399 | if (!size_ || is_scalar_) return (size_) ? var_mask_ : null_mask; |
1400 | return word2bool(base % size_); |
1401 | } |
1402 | |
1403 | expr_t maybe_gen_mask_let(const stmt_t &loop) const { |
1404 | return (size_ && is_scalar_) |
1405 | ? let_t::make(var_mask_, word2bool(0), loop) |
1406 | : loop; |
1407 | } |
1408 | |
1409 | private: |
1410 | type_t w_type(int width = 1) const { |
1411 | return (is_bool()) ? type_t::u16(width) : type_t::s32(width); |
1412 | } |
1413 | int w_stride() const { return w_type().size(); } |
1414 | |
1415 | expr_t word2bool(int off) const { |
1416 | auto type = (is_bool()) ? type_t::u16() : type_t::s16(16); |
1417 | expr_t load; |
1418 | if (is_wide_ && !is_bool()) { |
1419 | load = load_t::make( |
1420 | type, zp_mask_, off * 8 * w_stride(), w_stride()); |
1421 | } else { |
1422 | load = load_t::make( |
1423 | type.scalar(), zp_mask_, off * w_stride(), w_stride()); |
1424 | if (!is_bool()) load = shuffle_t::make_broadcast(load, 16); |
1425 | } |
1426 | return (is_bool()) ? cast_t::make(type_t::_bool(16), load) : load; |
1427 | } |
1428 | |
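// Folds a vector of per-lane expressions into a single vectorized expression,
// recursively factoring out the most common binary operation and hoisting
// repeated subexpressions into temporary variables (collected in vars).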
1429 | expr_t vector2expr(const std::vector<expr_t> &expr, |
1430 | object_eq_map_t<expr_t, expr_t> &vars) const { |
1431 | constexpr size_t mask = 0x8000; |
1432 | auto hash = [&](const binary_op_t &b) -> size_t { |
1433 | return size_t(b.op_kind) | ((b.b.is<int_imm_t>()) ? mask : 0UL); |
1434 | }; |
1435 | auto fetch_var = [this, &vars](expr_t e) { |
1436 | if (vars.find(e) == vars.end()) { |
1437 | auto var = lmb_.ir_ctx_.create_tmp_var( |
1438 | type_t::s32(e.type().elems()), "zp_mask"); |
1439 | vars.emplace(e, var); |
1440 | } |
1441 | return vars[e]; |
1442 | }; |
1443 | if (expr.empty()) return expr_t(); |
1444 | // Can only vectorize if the element count is a power of 2 |
1445 | ir_assert(math::is_pow2(expr.size())) << "Cannot vectorize."; |
1446 | |
1447 | std::unordered_map<size_t, size_t> kind; |
1448 | for (const expr_t &e : expr) |
1449 | if (const auto *bin = e.as_ptr<binary_op_t>()) { |
1450 | kind[hash(*bin)]++; |
1451 | auto bin_a = bin; |
1452 | while ((bin_a = bin_a->a.as_ptr<binary_op_t>())) { |
1453 | if (kind[hash(*bin_a)] > 0) { |
1454 | kind[hash(*bin)] += kind[hash(*bin_a)]; |
1455 | break; |
1456 | } |
1457 | } |
1458 | } |
1459 | if (!kind.empty()) { |
1460 | using k_type = decltype(kind)::value_type; |
1461 | auto k = std::max_element( |
1462 | kind.begin(), kind.end(), [](k_type &a, k_type &b) { |
1463 | return a.second < b.second; |
1464 | }); |
1465 | const auto k_raw = op_kind_t(k->first & (mask - 1)); |
1466 | std::vector<expr_t> a, b; |
1467 | for (const expr_t &e : expr) { |
1468 | const auto *bin = e.as_ptr<binary_op_t>(); |
1469 | if (bin && (hash(*bin) == k->first)) { |
1470 | a.emplace_back(bin->a); |
1471 | b.emplace_back(bin->b); |
1472 | } else { |
1473 | const int is_mul = (k_raw == op_kind_t::_mul); |
1474 | ir_assert(is_mul || (k_raw == op_kind_t::_add)); |
1475 | a.emplace_back(e); |
1476 | b.emplace_back(is_mul); |
1477 | } |
1478 | } |
1479 | auto a_new = vector2expr(a, vars); |
1480 | auto b_new = vector2expr(b, vars); |
1481 | if (auto *a_bin = a_new.as_ptr<binary_op_t>()) |
1482 | if ((a_bin->op_kind == op_kind_t::_add) && is_var(a_bin->b) |
1483 | && is_cmp_op(k_raw) && is_shuffle_const(b_new)) |
1484 | for (auto &v : vars) |
1485 | if (v.second.is_equal(a_bin->b)) { |
1486 | auto fold = const_fold_non_recursive( |
1487 | b_new - v.first); |
1488 | return binary_op_t::make(negate_cmp_op(k_raw), |
1489 | fetch_var(fold), a_bin->a); |
1490 | } |
1491 | return binary_op_t::make(k_raw, a_new, b_new); |
1492 | } |
1493 | |
1494 | size_t num_ints = 0; |
1495 | for (const expr_t &e : expr) |
1496 | num_ints += e.is<int_imm_t>(); |
1497 | ir_assert((num_ints == 0) || (num_ints == expr.size())); |
1498 | if (num_ints == expr.size()) { |
1499 | auto offs = shuffle_t::make(expr); |
1500 | if (offs.as<shuffle_t>().is_broadcast()) return offs; |
1501 | return fetch_var(offs); |
1502 | } |
1503 | |
1504 | size_t num_bools = 0; |
1505 | for (const expr_t &e : expr) |
1506 | num_bools += e.is<bool_imm_t>(); |
1507 | ir_assert((num_bools == 0) || (num_bools == expr.size())); |
1508 | if (num_bools == expr.size()) return shuffle_t::make(expr); |
1509 | |
1510 | ir_assert(expr.front().is<var_t>()); |
1511 | for (const expr_t &e : expr) |
1512 | ir_assert(e.is_same(expr.front())); |
1513 | return shuffle_t::make_broadcast(expr.front(), int(expr.size())); |
1514 | } |
1515 | |
1516 | load_multiply_builder_t &lmb_; |
1517 | bool is_const_; |
1518 | bool is_simd_; |
1519 | bool is_scalar_; |
1520 | bool is_wide_; |
1521 | int size_; |
1522 | expr_t ic_start_; |
1523 | expr_t var_mask_; |
1524 | const expr_t &zp_mask_; |
1525 | stmt_t stmt_; |
1526 | }; |
1527 | |
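// Wraps the multiplication statement with source zero-point compensation when
// requested: the mad path applies the zero point directly to the A operand in
// registers, while the dpas path accumulates dp4a-based compensation terms and
// applies them to the C buffer, interleaved with the DPAS calls.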
1528 | stmt_t maybe_add_src_zps(const view_t &a_view, const view_t &b_view, |
1529 | const multiply_builder_t &mul_builder, int i_buf, int j_buf) { |
1530 | if (!prb_.zp_cfg.do_src_compensation) return mul_builder.stmt(); |
1531 | const bool is_runtime = prb_.zp_cfg.is_runtime_src_zero_points; |
1532 | const bool is_scalar = prb_.zp_cfg.is_common_src_zero_point; |
1533 | const bool is_mad = (cfg_.fma_kind() == fma_kind_t::mad); |
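// Zero-point "channels" (IC for forward / OC otherwise with dpas, groups with mad)
// rounded up to a power of 2; compensation is processed in channel blocks of at most 32.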
1534 | const int channels = utils::rnd_up_pow2( |
1535 | (!is_mad) ? (prb_.is_fwd) ? prb_.ic : prb_.oc : prb_.g); |
1536 | const int c_blk = (channels < 32) ? channels : 32; |
1537 | const int k_blk = ((channels > 4)) ? 32 / c_blk : 1; |
1538 | const int m_blk = cfg_.simd(); |
1539 | |
1540 | const type_t s_type = a_view.type(); |
1541 | const type_t i_type = type_t::s32(); // x32 type that is always signed |
1542 | auto has_sign = [&]() { |
1543 | if (is_runtime) return s_type.is_signed(); |
1544 | ir_assert(is_scalar); |
1545 | return prb_.zp_cfg.common_src_zero_point < 0; |
1546 | }; |
1547 | const type_t d_type = (has_sign()) ? type_t::s32() : type_t::u32(); |
1548 | ir_assert((is_mad) ? s_type.is_x16() : s_type.is_x8()); |
1549 | |
1550 | const int a_stride |
1551 | = s_type.size() * int(a_view.tlayout().blocks()[0].stride); |
1552 | int desc_m = 0, desc_n = 0; |
1553 | |
1554 | if (!is_mad) { |
1555 | auto &mapper = gemm_schedule_.bmnk_mapper(); |
1556 | auto a_layout = mapper.map_to_bmnk(abc_kind_t::a, |
1557 | {bmnk_kind_t::m, bmnk_kind_t::k}, a_view.create_vlayout()); |
1558 | auto b_layout = mapper.map_to_bmnk(abc_kind_t::b, |
1559 | {bmnk_kind_t::k, bmnk_kind_t::n}, b_view.create_vlayout()); |
1560 | if (mul_builder.do_transpose()) { |
1561 | a_layout = a_layout.transpose(); |
1562 | b_layout = b_layout.transpose(); |
1563 | std::swap(a_layout, b_layout); |
1564 | } |
1565 | multiply_desc_t desc(a_layout, b_layout, true); |
1566 | desc_m = desc.m(); |
1567 | desc_n = desc.n(); |
1568 | } else { |
1569 | desc_n = a_view.tlayout().size() / m_blk / a_stride; |
1570 | desc_m = m_blk; |
1571 | } |
1572 | if (zp_mask_.is_empty()) |
1573 | zp_mask_ = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "zp_mask"); |
1574 | src_zp_mask_info_t masks(*this, m_blk, k_blk, desc_m, desc_n, channels, |
1575 | a_stride, is_mad, a_view, zp_mask_, zp_mask_size_); |
1576 | stmt_t data = masks.stmt(); |
1577 | |
1578 | const int simd_per_ic = utils::div_up( |
1579 | std::min((!is_scalar) ? channels : 1, 32), m_blk); |
1580 | const std::vector<dim_t> dims |
1581 | = {m_blk * std::min((is_mad) ? 1 : 2, simd_per_ic)}; |
1582 | const bool sc_ic = is_scalar || (channels <= 32); |
1583 | expr_t offs = (!sc_ic) ? masks.ic_start() * d_type.size() : 0; |
1584 | |
1585 | if (is_runtime && !sc_ic && !cfg_.pipeline().do_unroll() |
1586 | && (cfg_.slm().bufs() > 1)) { |
1587 | auto buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "zp_mask"); |
1588 | register_buffer(buf, type_t::u32().size(), alloc_kind_t::grf); |
1589 | data = data.append(store_t::make(buf, 0, offs)); |
1590 | offs = load_t::make(type_t::u32(), buf, 0); |
1591 | } |
1592 | |
1593 | auto get_src_zp_size = [](bool scalar, bool runtime, bool mad, int b) { |
1594 | if (scalar) return (!mad) ? b * 2 : ((runtime) ? b : 0); |
1595 | return (!mad) ? std::max(b * 2, 32) : 32; |
1596 | }; |
1597 | const int m_blk_x2 = std::min(m_blk * 2, 16); |
1598 | const int src_zp_size = get_src_zp_size( |
1599 | is_scalar, is_runtime, is_mad, m_blk_x2 * k_blk); |
1600 | if (zp_buf_.is_empty()) |
1601 | zp_buf_ = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "zp_buf"); |
1602 | zp_buf_size_ = std::max(zp_buf_size_, src_zp_size * d_type.size()); |
1603 | |
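// Load runtime zero points from the kernel argument into zp_buf_; with
// compile-time (common) zero points the loop start is INT_MAX, so no loads are emitted.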
1604 | for (int i = (is_runtime) ? 0 : std::numeric_limits<int>::max(); |
1605 | i < m_blk * simd_per_ic; i += dims[0]) { |
1606 | const int b = i * d_type.size(); |
1607 | view_t zpv(layout_t(d_type, 0, dims)); |
1608 | auto read = make_access_builder(ir_ctx_, zpv, |
1609 | kernel_info_.find_arg("src_zero_points")[offs + b], |
1610 | zp_buf_[b], send_op_t::load, send_address_t::a64); |
1611 | data = data.append(read.stmt()); |
1612 | } |
1613 | |
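// mad path: subtract the zero point from the A operand in registers (or fold it
// in via a masked mad), then emit the original multiplication.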
1614 | if (is_mad) { |
1615 | // TODO: for now, only b-blocking (per G) of the MAD loop is ready; |
1616 | // please implement n-blocking (per OC) as well! |
1617 | ir_assert(a_view.tlayout().size() % a_stride == 0); |
1618 | ir_assert(masks.is_simd()); |
1619 | |
1620 | std::vector<stmt_t> loop(std::max(1, 32 / m_blk)); |
1621 | for (int a_off = 0; a_off < a_view.tlayout().size(); |
1622 | a_off += m_blk * a_stride) { |
1623 | int iter = (a_off / m_blk / a_stride) % loop.size(); |
1624 | type_t sv_type(s_type.kind(), m_blk); |
1625 | type_t b_type(s_type.kind(), (!is_scalar) ? m_blk : 1); |
1626 | auto a = load_t::make(sv_type, a_buf_, a_off, a_stride); |
1627 | auto b_off |
1628 | = (!is_scalar && (channels > m_blk)) ? iter * m_blk : 0; |
1629 | auto b = (is_runtime) // '4'-s mean '(|i32| / |i16|) * |i16|' |
1630 | ? load_t::make(b_type, zp_buf_, b_off * 4, 4) |
1631 | : prb_.zp_cfg.common_src_zero_point; |
1632 | auto mask = masks.gen_mask( |
1633 | (utils::div_up(k_blk, 2)) * a_off / m_blk / a_stride); |
1634 | auto mad = (masks.is_bool()) |
1635 | ? binary_op_t::make(op_kind_t::_sub, a, b, sv_type) |
1636 | : ternary_op_t::make( |
1637 | op_kind_t::_mad, a, mask, b, sv_type); |
1638 | loop[iter] = loop[iter].append(store_t::make(a_buf_, a_off, mad, |
1639 | a_stride, (masks.is_bool()) ? mask : expr_t())); |
1640 | } |
1641 | for (size_t i = 1; i < loop.size(); i++) |
1642 | loop[0] = loop[0].append(loop[i]); |
1643 | return data.append(masks.maybe_gen_mask_let( |
1644 | loop[0].append(mul_builder.stmt()))); |
1645 | } |
1646 | |
1647 | if (is_scalar) { |
1648 | expr_t expr = (!is_runtime) |
1649 | ? (prb_.zp_cfg.common_src_zero_point & 0xFF) * 0x01010101 |
1650 | : cast_t::make(type_t::s8(4), |
1651 | shuffle_t::make_broadcast( |
1652 | load_t::make(s_type, zp_buf_, 0), 4)); |
1653 | data = data.append(store_t::make(zp_buf_, 0, expr)); |
1654 | } else { |
1655 | data = data.append(store_t::make(zp_buf_, 0, |
1656 | load_t::make(type_t::u8(m_blk_x2), zp_buf_, 0, 4))); |
1657 | if (channels > 16) |
1658 | data = data.append(store_t::make(zp_buf_, 16, |
1659 | load_t::make(type_t::u8(m_blk_x2), zp_buf_, 64, 4))); |
1660 | if (m_blk_x2 != m_blk) |
1661 | data = data.append(store_t::make(zp_buf_, 32, |
1662 | load_t::make(type_t::u32(4), zp_buf_, 4, 8), 8)); |
1663 | } |
1664 | std::vector<stmt_t> parts; |
1665 | |
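// Helpers to widen operands to the accumulation width: wide_scalar builds a
// blk-wide vector whose first m_blk lanes come from `a` and the rest from `b`
// (plain broadcast when blk == m_blk); wide_vector loads m_blk scalars from `a`
// and repeats them to fill blk lanes.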
1666 | auto wide_scalar = [m_blk](const expr_t &a, const expr_t &b, int blk) { |
1667 | if (blk == m_blk) return shuffle_t::make_broadcast(a, m_blk); |
1668 | std::vector<int> index(blk, 1); |
1669 | for (int i = 0; i < m_blk; i++) |
1670 | index[i] = 0; |
1671 | return shuffle_t::make(std::vector<expr_t> {a, b}, index); |
1672 | }; |
1673 | auto wide_vector = [m_blk, i_type](const expr_t &a, int blk) { |
1674 | if (blk == m_blk) |
1675 | return load_t::make(type_t(i_type.kind(), m_blk), a, 0); |
1676 | std::vector<expr_t> vec(m_blk); |
1677 | std::vector<int> index(blk); |
1678 | for (int i = 0; i < m_blk; i++) { |
1679 | vec[i] = load_t::make(i_type, a, i * i_type.size()); |
1680 | index[i + m_blk] = index[i] = i; |
1681 | } |
1682 | return shuffle_t::make(vec, index); |
1683 | }; |
1684 | std::vector<expr_t> acc; |
1685 | for (int i = 1; i <= 2 * k_blk; i++) |
1686 | acc.emplace_back( |
1687 | zp_buf_[(src_zp_size |
1688 | - utils::div_up( |
1689 | i, m_blk != m_blk_x2 ? 1 : 2) |
1690 | * m_blk) |
1691 | * d_type.size()]); |
1692 | for (int i_m = 0; i_m < desc_m; i_m += m_blk) { |
1693 | const int blk |
1694 | = (masks.is_simd() || masks.is_scalar()) ? m_blk_x2 : m_blk; |
1695 | for (int i = 0; i < k_blk; i++) { |
1696 | for (int i_k = i * (c_blk / 4); i_k |
1697 | < ((channels > 4) ? (c_blk + i * c_blk) / 4 : prb_.kw); |
1698 | i_k += m_blk_x2 / m_blk) { |
1699 | type_t vi(i_type.kind(), m_blk_x2); |
1700 | const int szp_off = (is_scalar) ? 0 : (i_k * d_type.size()); |
1701 | auto b0 = load_t::make(d_type, zp_buf_, szp_off); |
1702 | auto b1 = load_t::make( |
1703 | d_type, zp_buf_, szp_off + m_blk * d_type.size()); |
1704 | auto b = (is_scalar) ? b0 : wide_scalar(b0, b1, m_blk_x2); |
1705 | auto c = load_t::make(vi, b_buf_, |
1706 | (i_m * (32 / 4) + i_k * m_blk) * d_type.size()); |
1707 | if (is_scalar) std::swap(b, c); |
1708 | auto a = (i_k != i * (c_blk / 4)) |
1709 | ? load_t::make(vi, acc[i * 2 + 1], 0) |
1710 | : expr_t(0); |
1711 | parts.emplace_back(store_t::make(acc[i * 2 + 1], 0, |
1712 | ternary_op_t::make(op_kind_t::_dp4a, a, b, c, vi))); |
1713 | } |
1714 | |
1715 | if (m_blk_x2 != m_blk) { |
1716 | type_t vi(i_type.kind(), m_blk); |
1717 | auto a = load_t::make(vi, acc[i * 2 + 1], 0); |
1718 | auto b = load_t::make(vi, acc[i * 2 + 0], 0); |
1719 | parts.emplace_back(store_t::make(acc[i * 2 + 1], 0, a + b)); |
1720 | if (!masks.is_bool() && (blk != m_blk)) |
1721 | parts.emplace_back(store_t::make(acc[i * 2 + 0], 0, a)); |
1722 | } |
1723 | } |
1724 | for (int i_n = 0; i_n < desc_n; i_n += blk / m_blk) { |
1725 | int off_n = i_m / m_blk * desc_n + i_n; |
1726 | const int ij_buf = i_buf * cfg_.subtiles().b() + j_buf; |
1727 | auto dst = c_buf_ + off_n * m_blk * d_type.size() |
1728 | + ij_buf * mul_builder.c_layout().size(); |
1729 | type_t vi(i_type.kind(), blk); |
1730 | auto a = load_t::make(vi, dst, 0); |
1731 | for (int i = 0; i < k_blk; i++) { |
1732 | auto mask = masks.gen_mask(off_n * k_blk + i * blk / m_blk); |
1733 | if (!masks.is_bool()) { |
1734 | auto b = load_t::make(vi, acc[i * 2 + 1], 0); |
1735 | auto mad = ternary_op_t::make( |
1736 | op_kind_t::_mad, a, b, mask); |
1737 | parts.emplace_back(store_t::make(dst, 0, mad)); |
1738 | } else { |
1739 | auto b = wide_vector(acc[i * 2 + 1], blk); |
1740 | auto sub = binary_op_t::make(op_kind_t::_sub, a, b); |
1741 | parts.emplace_back(store_t::make( |
1742 | dst, 0, sub, store_t::default_stride, mask)); |
1743 | } |
1744 | } |
1745 | } |
1746 | } |
1747 | // Interleave the compensations between the DPAS calls for better GPU utilization. |
1748 | auto raw_dpas = flatten_statements(mul_builder.stmt()); |
1749 | std::vector<stmt_t> dpas; |
1750 | stmt_t full; |
1751 | expr_t src1; |
1752 | for (auto &r : raw_dpas) { |
1753 | ir_assert(is_func_call<dpas_t>(r)); |
1754 | auto &this_src1 = dpas_t::arg_src1(r); |
1755 | if (this_src1.is_equal(src1)) { |
1756 | dpas.back() = dpas.back().append(r); |
1757 | } else { |
1758 | src1 = this_src1; |
1759 | dpas.emplace_back(r); |
1760 | } |
1761 | } |
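// Distribute the compensation parts across the DPAS groups: each group of DPAS
// calls (sharing the same src1) is followed by the compensation chunk of the group
// half a period away, so compensation overlaps with unrelated DPAS work.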
1762 | ir_assert(parts.size() % dpas.size() == 0); |
1763 | const int loop_size = int(parts.size()) / int(dpas.size()); |
1764 | for (int i = 0; i < int(dpas.size()); i++) { |
1765 | full = full.append(dpas[i]); |
1766 | const auto k = (i + int(dpas.size()) / 2) % int(dpas.size()); |
1767 | for (int j = k * loop_size; j < (k + 1) * loop_size; j++) |
1768 | full = full.append(parts[j]); |
1769 | } |
1770 | return data.append(masks.maybe_gen_mask_let(full)); |
1771 | } |
1772 | |
1773 | void build_subtile(int i, int j) { |
1774 | bool is_first = (i == 0 && j == 0); |
1775 | |
1776 | stmt_t ab_s2r_load; |
1777 | stmt_t ab_g2r_load; |
1778 | if (!a_subtiles_[i].is_loaded()) { |
1779 | ab_s2r_load = ab_s2r_load.append(a_subtiles_[i].s2r_load()); |
1780 | ab_g2r_load = ab_g2r_load.append(a_subtiles_[i].g2r_load()); |
1781 | a_subtiles_[i].set_loaded(); |
1782 | } |
1783 | if (!b_subtiles_[j].is_loaded()) { |
1784 | ab_s2r_load = ab_s2r_load.append(b_subtiles_[j].s2r_load()); |
1785 | ab_g2r_load = ab_g2r_load.append(b_subtiles_[j].g2r_load()); |
1786 | b_subtiles_[j].set_loaded(); |
1787 | } |
1788 | load_mul_stmt_ = load_mul_stmt_.append( |
1789 | stmt_group_t::make(stmt_label_t::g2r_load(i + j), ab_g2r_load)); |
1790 | load_mul_stmt_ = load_mul_stmt_.append( |
1791 | stmt_group_t::make(stmt_label_t::s2r_load(i + j), ab_s2r_load)); |
1792 | |
1793 | auto &a_i_view = a_subtiles_[i].reg_view(); |
1794 | auto &b_j_view = b_subtiles_[j].reg_view(); |
1795 | |
1796 | // Multiply C_i_j += A_i x B_j in GEMM notation. |
1797 | multiply_builder_t mul_builder(cfg_, gemm_schedule_.bmnk_mapper(), |
1798 | a_i_view, b_j_view, a_buf_, b_buf_, c_buf_[c_buf_off_]); |
1799 | c_subtile_layout_ = mul_builder.c_layout(); |
1800 | |
1801 | auto mul_total |
1802 | = maybe_add_src_zps(a_i_view, b_j_view, mul_builder, i, j); |
1803 | |
1804 | c_buf_off_ += c_subtile_layout_.size(); |
1805 | ir_trace() << "Multiply (" << i << ", " << j << "):\n" |
1806 | << mul_total.str() << std::endl; |
1807 | |
1808 | load_mul_stmt_ = load_mul_stmt_.append( |
1809 | stmt_group_t::make(stmt_label_t::mul(i + j), mul_total)); |
1810 | |
1811 | if (!is_first) { |
1812 | ir_assert(mul_builder.c_layout() == c_subtile_layout_) |
1813 | << "Sub-tile layouts must be equal." ; |
1814 | return; |
1815 | } |
1816 | |
1817 | register_buffer( |
1818 | a_buf_, a_subtiles_[i].reg_buf_size(), alloc_kind_t::grf); |
1819 | register_buffer( |
1820 | b_buf_, b_subtiles_[j].reg_buf_size(), alloc_kind_t::grf); |
1821 | } |
1822 | void register_buffer(const stmt_t &alloc) { |
1823 | ir_assert(alloc.is<alloc_t>()); |
1824 | allocs_.push_back(alloc); |
1825 | } |
1826 | |
1827 | void register_buffer(const expr_t &buf, int size, alloc_kind_t kind, |
1828 | const alloc_attr_t &attr = {}) { |
1829 | register_buffer(alloc_t::make(buf, size, kind, attr)); |
1830 | } |
1831 | |
1832 | const conv_problem_t &prb_; |
1833 | const conv_config_t &cfg_; |
1834 | ir_context_t &ir_ctx_; |
1835 | const gemm_schedule_t &gemm_schedule_; |
1836 | const fma_helper_t &fma_helper_; |
1837 | b_reduce_context_t &b_reduce_ctx_; |
1838 | |
1839 | expr_t ap_buf_; |
1840 | expr_t a_slm_buf_; |
1841 | |
1842 | expr_t bp_buf_; |
1843 | expr_t b_slm_buf_; |
1844 | |
1845 | expr_t zp_buf_; |
1846 | int zp_buf_size_ = 0; |
1847 | |
1848 | expr_t zp_mask_; |
1849 | int zp_mask_size_ = 0; |
1850 | |
1851 | layout_t c_reg_layout_; |
1852 | |
1853 | expr_t ab_tmp_buf_; |
1854 | expr_t a_buf_; |
1855 | expr_t b_buf_; |
1856 | expr_t c_buf_; |
1857 | |
1858 | // Per-thread views to multiply. |
1859 | view_t a_thr_view_; |
1860 | view_t b_thr_view_; |
1861 | |
1862 | // Sub-tile indices. |
1863 | expr_t a_idx_; |
1864 | expr_t b_idx_; |
1865 | |
1866 | // Sub-tile views. |
1867 | view_t a_i_view_; |
1868 | view_t b_j_view_; |
1869 | |
1870 | tensor_t a_i_tile_; |
1871 | tensor_t b_j_tile_; |
1872 | |
1873 | std::vector<subtile_info_t> a_subtiles_; |
1874 | std::vector<subtile_info_t> b_subtiles_; |
1875 | |
1876 | std::vector<block_t> a_i_outer_blocks_; |
1877 | std::vector<block_t> b_j_outer_blocks_; |
1878 | |
1879 | std::vector<stmt_t> allocs_; |
1880 | |
1881 | stmt_t load_mul_stmt_; |
1882 | |
1883 | int c_buf_off_ = 0; |
1884 | layout_t c_subtile_layout_; |
1885 | |
1886 | const kernel_info_t &kernel_info_; |
1887 | }; |
1888 | |
1889 | class compute_builder_t { |
1890 | public: |
1891 | compute_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx, |
1892 | const kernel_info_t &kernel_info) |
1893 | : cfg_(cfg) |
1894 | , ir_ctx_(ir_ctx) |
1895 | , b_reduce_ctx_(ir_ctx, cfg) |
1896 | , g2s_ctx_(ir_ctx) |
1897 | , fma_helper_(cfg.simd(), cfg.fma_kind(), cfg.prb().a_data_type, |
1898 | cfg.prb().b_data_type, cfg.allow_a_grf_reorder(), |
1899 | cfg.allow_b_grf_reorder(), !cfg.prb().is_dw) |
1900 | , kernel_info_(kernel_info) {} |
1901 | |
1902 | int ab_slm_size() const { return ab_slm_size_; } |
1903 | |
1904 | const stmt_t &c_zero_out_stmt() const { return c_zero_out_stmt_; } |
1905 | const stmt_t &b_reduced_zero_out_stmt() const { |
1906 | return b_reduced_zero_out_stmt_; |
1907 | } |
1908 | |
1909 | stmt_t zero_out_stmt() const { |
1910 | stmt_t ret; |
1911 | ret = ret.append(c_zero_out_stmt()); |
1912 | ret = ret.append(b_reduced_zero_out_stmt()); |
1913 | return ret; |
1914 | } |
1915 | |
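// One iteration of the compute loop: either a prefetch group, or a GMEM -> SLM
// load/store pair separated by barriers, followed by the load-multiply statement.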
1916 | stmt_t iter_stmt() const { |
1917 | stmt_t stmt; |
1918 | bool use_prefetch = !prefetch_stmt_.is_empty(); |
1919 | bool use_slm = !g2s_load_stmt_.is_empty(); |
1920 | if (use_prefetch) { |
1921 | stmt = stmt.append(stmt_group_t::make( |
1922 | stmt_label_t::prefetch(), prefetch_stmt_)); |
1923 | } else if (use_slm) { |
1924 | stmt = stmt.append(stmt_group_t::make( |
1925 | stmt_label_t::g2s_load(), g2s_load_stmt_)); |
1926 | stmt = stmt.append(funcs::barrier()); |
1927 | stmt = stmt.append(stmt_group_t::make( |
1928 | stmt_label_t::g2s_store(), g2s_store_stmt_)); |
1929 | stmt = stmt.append(funcs::barrier()); |
1930 | } |
1931 | stmt = stmt.append(load_mul_stmt_); |
1932 | return stmt; |
1933 | } |
1934 | |
1935 | const stmt_t &c_store_stmt() const { return c_store_stmt_; } |
1936 | const stmt_t &b_reduced_store_stmt() const { return b_reduced_store_stmt_; } |
1937 | |
1938 | stmt_t inject_compute_alloc_stmts(const stmt_t &stmt) const { |
1939 | return jit::inject_alloc_stmts(stmt, compute_allocs_); |
1940 | } |
1941 | |
1942 | stmt_t inject_out_alloc_stmts(const stmt_t &stmt) const { |
1943 | return jit::inject_alloc_stmts(stmt, out_allocs_); |
1944 | } |
1945 | |
1946 | stmt_t inject_let_stmts(const stmt_t &stmt) const { |
1947 | return jit::inject_let_stmts(stmt, g2s_ctx_.grid_idx_lets); |
1948 | } |
1949 | |
1950 | void set_gemm_schedule(const gemm_schedule_t &gemm_schedule) { |
1951 | gemm_schedule_ = gemm_schedule; |
1952 | } |
1953 | |
1954 | // Setters for original AP/BP/CP buffers (P - problem notation). |
1955 | void set_ap_buf(const expr_t &buf) { ap_buf_ = buf; } |
1956 | void set_bp_buf(const expr_t &buf) { bp_buf_ = buf; } |
1957 | void set_cp_buf(const expr_t &buf) { cp_buf_ = buf; } |
1958 | void set_b_reduced_mem_buf(const expr_t &buf) { |
1959 | b_reduce_ctx_.set_b_reduced_mem_buf(buf); |
1960 | } |
1961 | |
1962 | void set_b_reduced_view(const view_t &v) { |
1963 | b_reduce_ctx_.set_b_reduced_view(v); |
1964 | } |
1965 | |
1966 | void set_post_op_context(const post_op_context_t &post_op_ctx) { |
1967 | post_op_ctx_ = post_op_ctx; |
1968 | } |
1969 | |
1970 | void set_reduce_condition(const expr_t &cond) { |
1971 | b_reduce_ctx_.set_reduce_condition(cond); |
1972 | } |
1973 | |
1974 | void build() { |
1975 | // Initialize SLM buffers. |
1976 | expr_t a_slm_buf = make_buffer("a_slm"); |
1977 | expr_t b_slm_buf = make_buffer("b_slm"); |
1978 | |
1979 | view_t ap_gmem_view = gemm_schedule_.a_tg_view(); |
1980 | view_t bp_gmem_view = gemm_schedule_.b_tg_view(); |
1981 | |
1982 | // Views to multiply by a thread group (either GMEM or SLM). |
1983 | view_t ap_x_view; |
1984 | view_t bp_x_view; |
1985 | prepare_gmem_to_slm("A" , cfg_.slm().a(), gemm_schedule_.a_tg_tile(), |
1986 | ap_gmem_view, ap_buf_, a_slm_buf, ap_x_view, g2s_ctx_); |
1987 | prepare_gmem_to_slm("B" , cfg_.slm().b(), gemm_schedule_.b_tg_tile(), |
1988 | bp_gmem_view, bp_buf_, b_slm_buf, bp_x_view, g2s_ctx_); |
1989 | prepare_prefetch("A" , cfg_.prefetch(), ap_gmem_view, ap_buf_); |
1990 | prepare_prefetch("B" , cfg_.prefetch(), bp_gmem_view, bp_buf_); |
1991 | |
1992 | if (ap_x_view.is_empty()) ap_x_view = ap_gmem_view; |
1993 | if (bp_x_view.is_empty()) bp_x_view = bp_gmem_view; |
1994 | |
1995 | for (auto &bi : g2s_ctx_.bufs) { |
1996 | register_compute_buffer(bi.buf, bi.size, alloc_kind_t::grf); |
1997 | } |
1998 | |
1999 | load_multiply_builder_t load_mul_builder(cfg_, ir_ctx_, gemm_schedule_, |
2000 | fma_helper_, b_reduce_ctx_, ap_buf_, a_slm_buf, bp_buf_, |
2001 | b_slm_buf, ap_x_view, bp_x_view, kernel_info_); |
2002 | |
2003 | load_mul_stmt_ = load_mul_builder.load_mul_stmt(); |
2004 | compute_allocs_.insert(compute_allocs_.end(), |
2005 | load_mul_builder.allocs().begin(), |
2006 | load_mul_builder.allocs().end()); |
2007 | |
2008 | auto c_buf = load_mul_builder.c_buf(); |
2009 | int c_size = load_mul_builder.c_reg_layout().size(); |
2010 | int c_size_grf_rounded = utils::rnd_up(c_size, ir_ctx_.grf_size()); |
2011 | register_out_buffer(c_buf, c_size_grf_rounded, alloc_kind_t::grf); |
2012 | |
2013 | auto c_thr_reg_layout = load_mul_builder.c_reg_layout(); |
2014 | auto thr_tile = gemm_schedule_.c_thr_tile(/*is_relative=*/false); |
2015 | |
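// With thread-group k-slicing, partial C results are first reduced through SLM
// before the epilogue; the store below is then guarded by the reduce condition.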
2016 | auto reduce_cond = expr_t(); |
2017 | if (gemm_schedule_.with_thread_group_k_slicing()) { |
2018 | slm_reduce_builder_t slm_reduce_builder(ir_ctx_, |
2019 | gemm_schedule_.tg_grid(), c_buf, c_thr_reg_layout, |
2020 | thr_tile); |
2021 | c_store_stmt_ = c_store_stmt_.append(slm_reduce_builder.stmt()); |
2022 | c_thr_reg_layout = slm_reduce_builder.reg_layout(); |
2023 | thr_tile = slm_reduce_builder.thr_tile(); |
2024 | reduce_cond = slm_reduce_builder.reduce_cond(); |
2025 | } |
2026 | |
2027 | auto c_thr_mem_view = gemm_schedule_.c_view().create_sub_view(thr_tile); |
2028 | auto c_m2g_stmt = create_epilogue_stmt(cfg_, ir_ctx_, gemm_schedule_, |
2029 | post_op_ctx_, thr_tile, c_thr_mem_view, c_thr_reg_layout, |
2030 | cp_buf_, c_buf); |
2031 | if (!reduce_cond.is_empty()) |
2032 | c_m2g_stmt = if_t::make(reduce_cond, c_m2g_stmt); |
2033 | ir_trace() << "C GRF to GMEM store:\n" << c_m2g_stmt << std::endl; |
2034 | |
2035 | c_zero_out_stmt_ = stmt_group_t::make(stmt_label_t::c_zero_out(), |
2036 | create_zero_out_stmt(ir_ctx_, c_buf, c_size)); |
2037 | c_store_stmt_ = c_store_stmt_.append(c_m2g_stmt); |
2038 | |
2039 | if (cfg_.reduce_b()) { |
2040 | auto &ctx = b_reduce_ctx_; |
2041 | b_reduced_zero_out_stmt_ = create_zero_out_stmt( |
2042 | ir_ctx_, ctx.b_reduced_reg_buf(), ctx.b_reduced_size()); |
2043 | b_reduced_store_stmt_ = ctx.create_store_stmt( |
2044 | gemm_schedule_.with_kernel_grid_k_slicing() |
2045 | || cfg_.slm().b()); |
2046 | register_out_buffer(ctx.b_reduced_reg_buf(), ctx.b_reduced_size(), |
2047 | alloc_kind_t::grf); |
2048 | } |
2049 | |
2050 | // Replace DPAS by DPASW when applicable. |
2051 | if (cfg_.fma_kind() == fma_kind_t::dpasw) { |
2052 | alloc_updater_t alloc_updater; |
2053 | inject_dpasw(ir_ctx_.hw(), load_mul_stmt_, c_buf, c_store_stmt_, |
2054 | alloc_updater, gemm_schedule_.tg_grid().idx(0)); |
2055 | for (auto &a : compute_allocs_) { |
2056 | a = alloc_updater.update(a); |
2057 | } |
2058 | for (auto &a : out_allocs_) { |
2059 | a = alloc_updater.update(a); |
2060 | } |
2061 | } |
2062 | |
2063 | // Assign {Atomic} for DPAS(W) when applicable. |
2064 | load_mul_stmt_ = inject_atomic(load_mul_stmt_); |
2065 | } |
2066 | |
2067 | private: |
2068 | struct buf_info_t { |
2069 | buf_info_t(const std::string &tag, const expr_t &buf) |
2070 | : tag(tag), buf(buf) {} |
2071 | |
2072 | std::string tag; |
2073 | expr_t buf; |
2074 | int size = 0; |
2075 | }; |
2076 | |
2077 | struct g2s_context_t { |
2078 | g2s_context_t(ir_context_t &ir_ctx) : ir_ctx(ir_ctx) {} |
2079 | |
2080 | expr_t create_buf(const char *tag, bool force_reuse = false) { |
2081 | if (reuse_buffers || force_reuse) { |
2082 | for (auto &bi : bufs) { |
2083 | if (bi.tag == tag) return bi.buf; |
2084 | } |
2085 | } |
2086 | auto buf = ir_ctx.create_tmp_var(type_t::byte_ptr(), tag); |
2087 | bufs.emplace_back(tag, buf); |
2088 | return buf; |
2089 | } |
2090 | |
2091 | void set_buf_size(const expr_t &buf, int size) { |
2092 | for (auto &bi : bufs) { |
2093 | if (bi.buf.is_same(buf)) bi.size = std::max(bi.size, size); |
2094 | } |
2095 | } |
2096 | |
2097 | expr_t create_tmp_grid_idx() { |
2098 | auto var = ir_ctx.create_tmp_var(type_t::s32(), "idx"); |
2099 | tmp_grid_idxs.insert({var, expr_t()}); |
2100 | return var; |
2101 | } |
2102 | |
2103 | void set_grid_idx_value(const expr_t &idx, const expr_t &value) { |
2104 | auto &old = tmp_grid_idxs[idx]; |
2105 | ir_assert(old.is_empty()); |
2106 | old = substitute_grid_idx_value(value); |
2107 | } |
2108 | |
2109 | expr_t substitute_grid_idx_value(const expr_t &_e) { |
2110 | auto e = _e; |
2111 | auto vars = find_unique_objects<var_t>(e); |
2112 | for (auto &v : vars) { |
2113 | auto it = tmp_grid_idxs.find(v); |
2114 | if (it == tmp_grid_idxs.end()) continue; |
2115 | e = substitute(e, v, it->second); |
2116 | } |
2117 | return e; |
2118 | } |
2119 | |
2120 | void register_grid(const grid_info_t &grid) { |
2121 | for (int i = 0; i < grid.ndims(); i++) { |
2122 | auto &idx = grid.idx(i); |
2123 | auto it = tmp_grid_idxs.find(idx); |
2124 | if (it == tmp_grid_idxs.end()) continue; |
2125 | grid_idx_lets.emplace_back(let_t::make(idx, it->second)); |
2126 | } |
2127 | } |
2128 | |
2129 | ir_context_t &ir_ctx; |
2130 | grid_info_t prev_load_grid; |
2131 | bool reuse_buffers = false; |
2132 | std::vector<buf_info_t> bufs; |
2133 | |
2134 | object_map_t<expr_t, expr_t> tmp_grid_idxs; |
2135 | std::vector<stmt_t> grid_idx_lets; |
2136 | }; |
2137 | |
2138 | void register_compute_buffer(const expr_t &buf, int size, alloc_kind_t kind, |
2139 | const alloc_attr_t &attr = {}) { |
2140 | compute_allocs_.push_back(alloc_t::make(buf, size, kind, attr)); |
2141 | } |
2142 | |
2143 | void register_out_buffer(const expr_t &buf, int size, alloc_kind_t kind, |
2144 | const alloc_attr_t &attr = {}) { |
2145 | out_allocs_.push_back(alloc_t::make(buf, size, kind, attr)); |
2146 | } |
2147 | |
2148 | // Handles GMEM to SLM load for A and B. Done in two steps: |
2149 | // 1. Load: GMEM -> GRF (temporary) |
2150 | // 2. Store: GRF (temporary) -> SLM |
2151 | void prepare_gmem_to_slm(const char *tag, bool use_x_slm, |
2152 | const tensor_t &tg_tile, const view_t &x_gmem_view, |
2153 | const expr_t &xp_buf, const expr_t &x_slm_buf, view_t &x_slm_view, |
2154 | g2s_context_t &g2s_ctx) { |
2155 | if (!use_x_slm) return; |
2156 | |
2157 | grid_info_t load_grid = gemm_schedule_.tg_grid(); |
2158 | for (;;) { |
2159 | bool ok = prepare_gmem_to_slm_impl(tag, use_x_slm, tg_tile, |
2160 | x_gmem_view, xp_buf, x_slm_buf, x_slm_view, load_grid, |
2161 | g2s_ctx); |
2162 | if (ok) { |
2163 | g2s_ctx.prev_load_grid = load_grid; |
2164 | g2s_ctx.register_grid(load_grid); |
2165 | return; |
2166 | } |
2167 | |
2168 | // Reduce grid and try again. |
2169 | auto grid_idx = g2s_ctx.create_tmp_grid_idx(); |
2170 | int dim_idx; |
2171 | expr_t grid_idx_value; |
2172 | auto new_load_grid |
2173 | = load_grid.halven(grid_idx, dim_idx, grid_idx_value); |
2174 | if (new_load_grid.is_empty()) break; |
2175 | |
2176 | if (new_load_grid == g2s_ctx.prev_load_grid) { |
2177 | new_load_grid = load_grid.halven( |
2178 | grid_idx, dim_idx, grid_idx_value, /*first=*/false); |
2179 | g2s_ctx.reuse_buffers = true; |
2180 | } |
2181 | g2s_ctx.set_grid_idx_value(grid_idx, grid_idx_value); |
2182 | |
2183 | ir_ctx_.add_constraint(grid_idx >= 0); |
2184 | ir_ctx_.add_constraint(grid_idx < new_load_grid.dim(dim_idx)); |
2185 | |
2186 | load_grid = new_load_grid; |
2187 | } |
2188 | ir_error_not_expected() << "Can't create GMEM -> SLM loads/stores."; |
2189 | } |
2190 | |
2191 | bool prepare_gmem_to_slm_impl(const char *tag, bool use_x_slm, |
2192 | const tensor_t &tg_tile, const view_t &x_gmem_view, |
2193 | const expr_t &xp_buf, const expr_t &x_slm_buf, view_t &x_slm_view, |
2194 | const grid_info_t &load_grid, g2s_context_t &g2s_ctx) { |
2195 | bool is_a = (tag[0] == 'A'); |
2196 | abc_kind_t ab_kind = (is_a ? abc_kind_t::a : abc_kind_t::b); |
2197 | |
2198 | auto xp_slm_layout = create_slm_layout(x_gmem_view, ab_kind, load_grid); |
2199 | |
2200 | auto grid_cond = load_grid.slice_condition(); |
2201 | |
2202 | // Per-thread tile and view to load from GMEM and store to SLM. |
2203 | tensor_t thr_tile; |
2204 | view_t x_g2s_view; |
2205 | if (cfg_.allow_slm_tg_slicing()) { |
2206 | x_g2s_view = x_gmem_view.split(load_grid, thr_tile); |
2207 | } else { |
2208 | thr_tile = xp_slm_layout.split(load_grid); |
2209 | x_g2s_view = x_gmem_view.create_sub_view(thr_tile); |
2210 | } |
2211 | |
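// For A, when the threads collectively load fewer elements than the (padded) SLM
// layout holds, guard the loads with per-dimension bounds so threads whose start
// offsets fall outside the view do not read out of range.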
2212 | auto bound_cond = expr_t(); |
2213 | if (is_a && !cfg_.fuse_spatial() |
2214 | && thr_tile.elems() * load_grid.elems() |
2215 | != xp_slm_layout.elems()) { |
2216 | for (int i = 0; i < x_gmem_view.nvdims(); i++) { |
2217 | if (!x_g2s_view.vstart(i).is_equal(x_gmem_view.vstart(i))) { |
2218 | auto dim_expr |
2219 | = x_g2s_view.vstart(i) - x_gmem_view.vstart(i); |
2220 | if (bound_cond.is_empty()) |
2221 | bound_cond = dim_expr < x_gmem_view.vdims()[i]; |
2222 | else |
2223 | bound_cond &= dim_expr < x_gmem_view.vdims()[i]; |
2224 | } |
2225 | } |
2226 | } |
2227 | if (!bound_cond.is_empty()) { |
2228 | if (!grid_cond.is_empty()) |
2229 | grid_cond = grid_cond & bound_cond; |
2230 | else |
2231 | grid_cond = bound_cond; |
2232 | } |
2233 | |
2234 | auto slm_thr_layout = xp_slm_layout.map(thr_tile); |
2235 | |
2236 | // Ensure that each thread writes a dense region to SLM. If the layout |
2237 | // is not dense, return and retry with a smaller grid. |
2238 | if (!slm_thr_layout.is_dense()) return false; |
2239 | |
2240 | register_compute_buffer( |
2241 | x_slm_buf, xp_slm_layout.size(), alloc_kind_t::slm); |
2242 | ab_slm_size_ += xp_slm_layout.size(); |
2243 | |
2244 | // Temporary GRF buffer. |
2245 | expr_t x_g2s_reg_buf = g2s_ctx.create_buf("g2s"); |
2246 | |
2247 | // GMEM -> GRF load. |
2248 | auto x_read = make_access_builder(ir_ctx_, x_g2s_view, xp_buf, |
2249 | x_g2s_reg_buf, send_op_t::load, send_address_t::a64); |
2250 | ir_trace() << tag << " GMEM to GRF load:\n" |
2251 | << x_read.str() << std::endl; |
2252 | |
2253 | g2s_ctx.set_buf_size(x_g2s_reg_buf, x_read.reg_buf_size()); |
2254 | |
2255 | auto load_stmt = x_read.stmt(); |
2256 | if (!grid_cond.is_empty()) load_stmt = if_t::make(grid_cond, load_stmt); |
2257 | g2s_load_stmt_ = g2s_load_stmt_.append(load_stmt); |
2258 | |
2259 | // GRF -> SLM store. |
2260 | auto x_write = make_access_builder(ir_ctx_, view_t(slm_thr_layout), |
2261 | x_slm_buf, x_g2s_reg_buf, send_op_t::store, |
2262 | send_address_t::slm); |
2263 | ir_trace() << tag << " GRF to SLM store:\n" |
2264 | << x_write.str() << std::endl; |
2265 | auto store_stmt = x_write.stmt(); |
2266 | |
2267 | auto &read_layout = x_read.reg_layout(); |
2268 | auto &write_layout = x_write.reg_layout(); |
2269 | if (read_layout != write_layout) { |
2270 | if (is_a ? cfg_.allow_a_grf_reorder() |
2271 | : cfg_.allow_b_grf_reorder()) { |
2272 | // Temporary GRF buffer. |
2273 | expr_t tmp_buf |
2274 | = g2s_ctx.create_buf("g2s_tmp", /*force_reuse=*/true); |
2275 | auto reorder_stmt = create_reorder_stmt( |
2276 | read_layout, write_layout, x_g2s_reg_buf, tmp_buf); |
2277 | g2s_ctx.set_buf_size(tmp_buf, x_write.reg_buf_size()); |
2278 | store_stmt = substitute(store_stmt, x_g2s_reg_buf, tmp_buf); |
2279 | store_stmt = reorder_stmt.append(store_stmt); |
2280 | } else { |
2281 | ir_error_not_expected() << "Requested register layouts for " |
2282 | << tag << " do not match: " |
2283 | << "read: " << read_layout |
2284 | << ", write: " << write_layout; |
2285 | } |
2286 | } |
2287 | // Generate reduction statement for B. |
2288 | if (!is_a && cfg_.reduce_b()) { |
2289 | auto absolute_thr_tile = tg_tile.create_sub_tensor(thr_tile); |
2290 | b_reduce_ctx_.init_reduced_thr_view(absolute_thr_tile, grid_cond); |
2291 | auto reduce_stmt = b_reduce_ctx_.create_reduce_stmt( |
2292 | read_layout, x_g2s_reg_buf); |
2293 | store_stmt = reduce_stmt.append(store_stmt); |
2294 | } |
2295 | if (!grid_cond.is_empty()) |
2296 | store_stmt = if_t::make(grid_cond, store_stmt); |
2297 | g2s_store_stmt_ = g2s_store_stmt_.append(store_stmt); |
2298 | |
2299 | x_slm_view = view_t(xp_slm_layout); |
2300 | |
2301 | return true; |
2302 | } |
2303 | |
2304 | void prepare_prefetch(const char *tag, bool use_prefetch, |
2305 | const view_t &x_gmem_view, const expr_t &xp_buf) { |
2306 | if (!use_prefetch) return; |
2307 | |
2308 | // Per-thread view to prefetch from GMEM. |
2309 | auto thr_view = x_gmem_view.split(gemm_schedule_.tg_grid()); |
2310 | |
2311 | auto send_hint = get_send_hint(ir_ctx_.exec_cfg(), send_op_t::prefetch, |
2312 | (tag[0] == 'A') ? abc_kind_t::a : abc_kind_t::b, thr_view, |
2313 | gemm_schedule_); |
2314 | |
2315 | // GMEM prefetch. |
2316 | auto x_prefetch = make_access_builder(ir_ctx_, thr_view, xp_buf, |
2317 | expr_t(), send_op_t::prefetch, send_address_t::a64, send_hint); |
2318 | |
2319 | // Too many prefetches degrade performance. |
2320 | if (find_objects<func_call_t>(x_prefetch.stmt()).size() > 16) { |
2321 | ir_warning() << "Dropping excessive prefetches." << std::endl; |
2322 | prefetch_stmt_ = stmt_t(); |
2323 | } else { |
2324 | ir_trace() << tag << " GMEM prefetch:\n" |
2325 | << x_prefetch.str() << std::endl; |
2326 | prefetch_stmt_ = prefetch_stmt_.append(x_prefetch.stmt()); |
2327 | } |
2328 | } |
2329 | |
2330 | layout_t create_slm_layout(const view_t &tg_view, abc_kind_t abc_kind, |
2331 | const grid_info_t &load_grid) const { |
2332 | auto layout = tg_view.create_dense_vlayout(); |
2333 | auto ret = fma_helper_.convert_to_fma_friendly_layout(layout, abc_kind, |
2334 | /*is_slm=*/true, gemm_schedule_.bmnk_mapper()); |
2335 | if (cfg_.pad_slm()) ret = pad_slm_layout(ret, load_grid); |
2336 | return ret.normalize(); |
2337 | } |
2338 | |
2339 | // SLM has 65 dword-granularity banks (Xe_HP): |
2340 | // banks: [bank 0] [bank 1] [bank 2] ... [bank 0] |
2341 | // byte offsets: | 0 | 4 | 8 ... | 4 * 65 |
2342 | // SLM reads don't have conflicts. During SLM writes each fused EU writes |
2343 | // 64 bytes (128 bytes per clock in total). If a bank repeats within those |
2344 | // 128 bytes, the write takes 2 clocks to complete. |
2345 | // Assume that every X-axis thread (across tg_dim[0]) writes the |
2346 | // corresponding outer block of the layout. The goal is to choose a stride |
2347 | // between outer blocks that avoids duplicated banks. |
2348 | layout_t pad_slm_layout( |
2349 | const layout_t &layout, const grid_info_t &load_grid) const { |
2350 | // EUs are not fused in XeHPC+ so no need to pad SLM. |
2351 | if (ir_ctx_.hw() >= ngen::HW::XeHPC) return layout; |
2352 | auto tg_dim0 = load_grid.dim(0); |
2353 | auto tg_dim1 = load_grid.dim(1); |
2354 | int type_size = layout.type().size(); |
2355 | |
2356 | ir_assert(layout.elems() % tg_dim0 == 0) << layout; |
2357 | dim_t inner_block = layout.elems() / tg_dim0; |
2358 | |
2359 | ir_assert((inner_block * type_size) % tg_dim1 == 0) << layout; |
2360 | dim_t per_thr_bytes = (inner_block * type_size) / tg_dim1; |
2361 | |
2362 | std::vector<dim_t> multi_blocks = {inner_block, tg_dim0}; |
2363 | auto l = layout.split_into_multi_blocks(multi_blocks); |
2364 | |
2365 | if (l.is_empty()) { |
2366 | ir_warning() << "Couldn't split layout for SLM padding." |
2367 | << std::endl; |
2368 | return layout; |
2369 | } |
2370 | auto padded_blocks = l.blocks(); |
2371 | dim_t stride = -1; |
2372 | dim_t remaining_elems = inner_block; |
2373 | bool past_inner_block = remaining_elems == 1; |
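// Walk the blocks: once past the per-thread inner block, pad the stride of the
// first outer block to a conflict-free value and keep subsequent blocks dense
// relative to it.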
2374 | for (auto &b : padded_blocks) { |
2375 | if (past_inner_block) { |
2376 | if (stride == -1) { |
2377 | dim_t stride_bytes = find_min_stride_without_conflicts( |
2378 | per_thr_bytes, dim_t(b.stride) * type_size); |
2379 | ir_assert(stride_bytes % type_size == 0); |
2380 | stride = stride_bytes / type_size; |
2381 | } |
2382 | b.stride = stride; |
2383 | stride = b.stride * b.block; |
2384 | continue; |
2385 | } |
2386 | ir_assert(remaining_elems % b.block == 0); |
2387 | remaining_elems /= b.block; |
2388 | if (remaining_elems == 1) past_inner_block = true; |
2389 | } |
2390 | return layout_t( |
2391 | layout.type(), layout.ndims(), layout.offset(), padded_blocks); |
2392 | } |
2393 | |
2394 | dim_t find_min_stride_without_conflicts( |
2395 | dim_t inner_bytes, dim_t dense_stride_bytes) const { |
2396 | int write_step = 64; |
2397 | int stride_step = 16; |
2398 | dim_t stride_beg = dense_stride_bytes; |
2399 | dim_t stride_end = 2 * dense_stride_bytes; |
2400 | auto arch = convert_ngen_arch_to_dnnl(ir_ctx_.hw()); |
2401 | const int slm_banks |
2402 | = compute::device_info_t::slm_memory_bank_count(arch); |
2403 | const int bank_granularity |
2404 | = compute::device_info_t::slm_memory_bank_granularity(arch); |
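// Search strides from the dense stride up to twice its value, in 16-byte steps;
// a stride passes if the banks touched by a 64-byte write at offsets off and
// off + s do not repeat.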
2405 | for (dim_t s = stride_beg; s < stride_end; s += stride_step) { |
2406 | bool ok = true; |
2407 | for (dim_t off0 = 0; off0 < inner_bytes; off0 += write_step) { |
2408 | // Check banks for a single SLM write. |
2409 | std::vector<bool> found(slm_banks, false); |
2410 | for (dim_t off = off0; off < off0 + write_step; |
2411 | off += bank_granularity) { |
2412 | int bank0 = (off / bank_granularity) % slm_banks; |
2413 | int bank1 = ((off + s) / bank_granularity) % slm_banks; |
2414 | if (found[bank0]) { |
2415 | ok = false; |
2416 | break; |
2417 | } |
2418 | found[bank0] = true; |
2419 | if (found[bank1]) { |
2420 | ok = false; |
2421 | break; |
2422 | } |
2423 | found[bank1] = true; |
2424 | } |
2425 | if (ok) return s; |
2426 | } |
2427 | } |
2428 | |
2429 | ir_warning() |
2430 | << "Couldn't find stride without conflicts for SLM padding." |
2431 | << std::endl; |
2432 | |
2433 | return dense_stride_bytes; |
2434 | } |
2435 | |
2436 | const conv_config_t &cfg_; |
2437 | ir_context_t &ir_ctx_; |
2438 | post_op_context_t post_op_ctx_; |
2439 | b_reduce_context_t b_reduce_ctx_; |
2440 | |
2441 | g2s_context_t g2s_ctx_; |
2442 | fma_helper_t fma_helper_; |
2443 | |
2444 | gemm_schedule_t gemm_schedule_; |
2445 | |
2446 | expr_t ap_buf_; |
2447 | expr_t bp_buf_; |
2448 | expr_t cp_buf_; |
2449 | |
2450 | std::vector<stmt_t> compute_allocs_; |
2451 | std::vector<stmt_t> out_allocs_; |
2452 | int ab_slm_size_ = 0; |
2453 | |
2454 | stmt_t g2s_load_stmt_; |
2455 | stmt_t g2s_store_stmt_; |
2456 | stmt_t prefetch_stmt_; |
2457 | stmt_t load_mul_stmt_; |
2458 | |
2459 | stmt_t c_zero_out_stmt_; |
2460 | stmt_t c_store_stmt_; |
2461 | |
2462 | stmt_t b_reduced_zero_out_stmt_; |
2463 | stmt_t b_reduced_store_stmt_; |
2464 | |
2465 | const kernel_info_t &kernel_info_; |
2466 | }; |
2467 | |
2468 | class compute_loop_label_injector_t : public ir_mutator_t { |
2469 | public: |
2470 | object_t _mutate(const for_t &obj) override { |
2471 | if (injected_) return obj; |
2472 | |
2473 | bool found_continue = false; |
2474 | auto calls = find_objects<func_call_t>(obj); |
2475 | for (auto &_c : calls) { |
2476 | auto &c = _c.as<func_call_t>(); |
2477 | if (c.func.is_equal(funcs::continue_func())) found_continue = true; |
2478 | } |
2479 | |
2480 | if (!found_continue) { |
2481 | injected_ = true; |
2482 | return stmt_group_t::make(stmt_label_t::compute_loop(), obj); |
2483 | } |
2484 | return ir_mutator_t::_mutate(obj); |
2485 | } |
2486 | |
2487 | private: |
2488 | bool injected_ = false; |
2489 | }; |
2490 | |
2491 | // Injects the compute_loop statement label into the outermost loop that can be |
2492 | // pipelined. If a loop contains a "continue" function call, it can't be |
2493 | // pipelined because of the conditional control flow. |
2494 | stmt_t inject_compute_loop_label(const stmt_t &s) { |
2495 | return compute_loop_label_injector_t().mutate(s); |
2496 | } |
2497 | |
2498 | void conv_ir_builder_t::build() { |
2499 | constraint_set_t init_cset; |
2500 | |
2501 | trace_reset(); |
2502 | |
2503 | std::vector<stmt_t> init_stmts; |
2504 | init_kernel_grid(cfg_.kernel_grid(), cfg_.thread_group_grid(), cfg_.simd(), |
2505 | init_cset, init_stmts); |
2506 | |
2507 | gemm_schedule_t gemm_schedule( |
2508 | init_cset, cfg_.kernel_grid(), cfg_.thread_group_grid()); |
2509 | |
2510 | // Initialize memory buffers. |
2511 | std::vector<stmt_t> inner_lets; |
2512 | |
2513 | view_t a_view; |
2514 | view_t b_view; |
2515 | view_t c_view; |
2516 | view_t bp_reduced_view; |
2517 | |
2518 | expr_t ap_buf; |
2519 | expr_t bp_buf; |
2520 | expr_t cp_buf; |
2521 | expr_t b_reduced_mem_buf; |
2522 | expr_t b_reduction_condition; |
2523 | |
2524 | if (cfg_.prb().is_fwd) { |
2525 | init_fwd(gemm_schedule, a_view, b_view, c_view, ap_buf, bp_buf, cp_buf); |
2526 | } else if (cfg_.prb().is_bwd_d) { |
2527 | init_bwd_d( |
2528 | gemm_schedule, a_view, b_view, c_view, ap_buf, bp_buf, cp_buf); |
2529 | } else if (cfg_.prb().is_bwd_w) { |
2530 | init_bwd_w(gemm_schedule, a_view, b_view, c_view, bp_reduced_view, |
2531 | ap_buf, bp_buf, cp_buf, b_reduced_mem_buf, |
2532 | b_reduction_condition); |
2533 | } else { |
2534 | ir_error_not_expected(); |
2535 | } |
2536 | |
2537 | gemm_schedule.finalize(); |
2538 | |
2539 | trace_stamp("GEMM Schedule" ); |
2540 | |
2541 | ir_context_t ir_ctx(cfg_.exec_cfg(), init_cset); |
2542 | post_op_context_t post_op_ctx(cfg_, gemm_schedule, kernel_info_); |
2543 | compute_builder_t cb(cfg_, ir_ctx, kernel_info_); |
2544 | |
2545 | cb.set_gemm_schedule(gemm_schedule); |
2546 | cb.set_ap_buf(ap_buf); |
2547 | cb.set_bp_buf(bp_buf); |
2548 | cb.set_cp_buf(cp_buf); |
2549 | cb.set_b_reduced_mem_buf(b_reduced_mem_buf); |
2550 | cb.set_b_reduced_view(bp_reduced_view); |
2551 | cb.set_post_op_context(post_op_ctx); |
2552 | cb.set_reduce_condition(b_reduction_condition); |
2553 | |
2554 | cb.build(); |
2555 | |
2556 | trace_stamp("Compute Builder" ); |
2557 | |
2558 | std::vector<stmt_t> allocs; |
2559 | for (int i = 0; i < kernel_info_.nargs(); i++) { |
2560 | auto &var = kernel_info_.arg_var(i); |
2561 | if (!var.type().is_ptr()) continue; |
2562 | allocs.push_back(alloc_t::make(var, 0, alloc_kind_t::global)); |
2563 | } |
2564 | |
2565 | // Create IR statements. |
2566 | stmt_t loop_stmt = cb.iter_stmt(); |
2567 | loop_stmt = gemm_schedule.create_loop_nest(loop_stmt); |
2568 | loop_stmt = inject_compute_loop_label(loop_stmt); |
2569 | loop_stmt = cb.inject_compute_alloc_stmts(loop_stmt); |
2570 | |
2571 | stmt_t c_store_stmt; |
2572 | c_store_stmt = c_store_stmt.append(cb.b_reduced_store_stmt()); |
2573 | c_store_stmt = c_store_stmt.append(cb.c_store_stmt()); |
2574 | c_store_stmt = stmt_group_t::make(stmt_label_t::c_store(), c_store_stmt); |
2575 | |
2576 | stmt_ = loop_stmt; |
2577 | stmt_ = stmt_seq_t::make(cb.zero_out_stmt(), stmt_); |
2578 | stmt_ = stmt_seq_t::make(stmt_, c_store_stmt); |
2579 | |
2580 | stmt_ = cb.inject_out_alloc_stmts(stmt_); |
2581 | stmt_ = cb.inject_let_stmts(stmt_); |
2582 | |
2583 | stmt_ = gemm_schedule.create_bind_stmt(stmt_); |
2584 | stmt_ = inject_let_stmts(stmt_, init_stmts); |
2585 | stmt_ = inject_alloc_stmts(stmt_, allocs); |
2586 | trace_stop("Create Inital IR" ); |
2587 | |
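// Lower and optimize the IR: inject SLM buffering or a prefetch pipeline, emit
// send messages, hoist and simplify expressions, unroll the compute loop, and
// apply the remaining peephole/bank-conflict passes.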
2588 | stmt_ = inject_external_var_let(stmt_, ir_ctx); |
2589 | stmt_ = merge_slm_buffers(stmt_, ir_ctx); |
2590 | if (!cfg_.pipeline().do_unroll() && cfg_.slm()) { |
2591 | stmt_ = inject_simple_slm_buffering( |
2592 | stmt_, ir_ctx, cfg_, cb.ab_slm_size()); |
2593 | } else if (!cfg_.pipeline().do_unroll() && cfg_.prefetch()) { |
2594 | // Simplify to remove loops with only 1 iteration |
2595 | stmt_ = simplify(stmt_, ir_ctx); |
2596 | stmt_ = inject_prefetch_pipeline(stmt_, ir_ctx, cfg_); |
2597 | } |
2598 | stmt_ = inject_slm_reorder( |
2599 | stmt_, ir_ctx, cfg_.thread_group_grid(), cfg_.slm()); |
2600 | stmt_ = lift_buffer_offsets_in_send(stmt_, ir_ctx); |
2601 | stmt_ = simplify(stmt_, ir_ctx); |
2602 | stmt_ = inject_send(stmt_, ir_ctx); |
2603 | stmt_ = split_wide_stores(stmt_, ir_ctx); |
2604 | stmt_ = lift_alloc(stmt_, ir_ctx, cfg_.pipeline().reuse_headers()); |
2605 | stmt_ = lift_send_2d_header_store(stmt_, ir_ctx); |
2606 | stmt_ = hoist_send_masks(stmt_, ir_ctx, stmt_label_t::c_store(), false); |
2607 | stmt_ = split_shuffle(stmt_, ir_ctx); |
2608 | stmt_ = eliminate_common_subexprs( |
2609 | stmt_, ir_ctx, cfg_.reserved_regs(), cfg_.slm().gmem_bufs()); |
2610 | stmt_ = hoist_exprs(stmt_, ir_ctx, cfg_.reserved_regs()); |
2611 | if (cfg_.pipeline().do_unroll()) |
2612 | stmt_ = loop_strength_reduce(stmt_, ir_ctx); |
2613 | stmt_ = optimize_alloc_let(stmt_, ir_ctx); |
2614 | if (cfg_.pipeline().do_unroll()) { |
2615 | stmt_ = update_loops_for_unrolling(stmt_, ir_ctx); |
2616 | stmt_ = inject_unrolling(stmt_, ir_ctx, cfg_, cb.ab_slm_size()); |
2617 | } |
2618 | if (cfg_.hoist_masks_from_compute_loop()) { |
2619 | stmt_ = hoist_send_masks( |
2620 | stmt_, ir_ctx, stmt_label_t::compute_loop(), true); |
2621 | } |
2622 | stmt_ = fixup_if_conditions(stmt_, ir_ctx); |
2623 | stmt_ = unroll_loops(stmt_, ir_ctx); |
2624 | stmt_ = simplify(stmt_, ir_ctx); |
2625 | stmt_ = maybe_strip_prefetches(stmt_, ir_ctx, cfg_.reserved_regs()); |
2626 | stmt_ = optimize_alloc_let(stmt_, ir_ctx); |
2627 | if (cfg_.hoist_masks_from_compute_loop()) { |
2628 | stmt_ = remove_spurious_send_mask_cast(stmt_, ir_ctx); |
2629 | } |
2630 | stmt_ = fix_int32_overflow(stmt_, ir_ctx); |
2631 | stmt_ = optimize_peephole(stmt_, ir_ctx); |
2632 | stmt_ = optimize_barrier(stmt_, ir_ctx); |
2633 | if (cfg_.fma_kind() == fma_kind_t::dp4a) stmt_ = inject_dp4a(stmt_, ir_ctx); |
2634 | stmt_ = inject_bank_conflict_attribute(stmt_, ir_ctx); |
2635 | stmt_ = stmt_group_t::make(stmt_label_t::kernel(), stmt_); |
2636 | |
2637 | #if !defined(NDEBUG) || defined(GEN_CONV_DEBUG) |
2638 | verify_buffer_access(stmt_, ir_ctx); |
2639 | #endif |
2640 | |
2641 | ir_trace() << "Convolution kernel body:\n" << stmt_ << std::endl; |
2642 | trace_perf(); |
2643 | } |
2644 | |
2645 | } // namespace jit |
2646 | } // namespace gpu |
2647 | } // namespace impl |
2648 | } // namespace dnnl |
2649 | |