1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/conv/ir_builder.hpp"
18
19#include <algorithm>
20#include <array>
21#include <iostream>
22#include <limits>
23#include <memory>
24#include <numeric>
25#include <utility>
26#include <vector>
27#include <unordered_map>
28
29#include "gpu/jit/conv/config.hpp"
30#include "gpu/jit/conv/epilogue.hpp"
31#include "gpu/jit/conv/pipeline.hpp"
32#include "gpu/jit/conv/post_ops.hpp"
33#include "gpu/jit/conv/slm_reduce_builder.hpp"
34#include "gpu/jit/ir/fma.hpp"
35#include "gpu/jit/ir/gemm_schedule.hpp"
36#include "gpu/jit/ir/ir.hpp"
37#include "gpu/jit/ir/message.hpp"
38#include "gpu/jit/ir/mul_add.hpp"
39#include "gpu/jit/ir/reduce.hpp"
40#include "gpu/jit/ir/reorder.hpp"
41#include "gpu/jit/ir/tensor.hpp"
42#include "gpu/jit/pass/pass.hpp"
43#include "gpu/jit/utils/trace.hpp"
44
45namespace dnnl {
46namespace impl {
47namespace gpu {
48namespace jit {
49
50class buffer_access_verifier_t : public ir_visitor_t {
51public:
52 void _visit(const alloc_t &obj) override {
53 buf_sizes_.emplace(obj.buf, obj.size);
54 ir_visitor_t::_visit(obj);
55 buf_sizes_.erase(obj.buf);
56 }
57
58 void _visit(const func_call_t &obj) override {
59 auto &func = obj.func;
60 if (auto *dpas = func.as_ptr<dpas_t>()) {
61 auto &dst = dpas_t::arg_dst(obj);
62 auto &src0 = dpas_t::arg_src0(obj);
63 auto &src1 = dpas_t::arg_src1(obj);
64 auto &src2 = dpas_t::arg_src2(obj);
65 check_access(dst, dpas->dst_size(), obj);
66 if (!is_zero(src0)) check_access(src0, dpas->src0_size(), obj);
67 check_access(src1, dpas->src1_size(), obj);
68 check_access(src2, dpas->src2_size(), obj);
69 } else if (func.is<eltwise_t>()) {
70 auto &elems = eltwise_t::arg_elems(obj);
71 auto &data = eltwise_t::arg_data(obj);
72 int size = to_cpp<int>(elems) * sizeof(float);
73 check_access(data, size, obj);
74 } else if (auto *mad = func.as_ptr<mad_t>()) {
75 auto &dst = mad_t::arg_dst(obj);
76 auto &src0 = mad_t::arg_src0(obj);
77 auto &src1 = mad_t::arg_src1(obj);
78 auto &src2 = mad_t::arg_src2(obj);
79 check_access(dst, mad->dst_size(), obj);
80 if (!is_zero(src0)) check_access(src0, mad->src0_size(), obj);
81 check_access(src1, mad->src1_size(), obj);
82 check_access(src2, mad->src2_size(), obj);
83 } else if (auto *reduce = func.as_ptr<reduce_t>()) {
84 auto &dst_buf = reduce_t::arg_dst_buf(obj);
85 auto &src_buf = reduce_t::arg_src_buf(obj);
86 check_access(dst_buf, reduce->dst_layout.size(), obj);
87 check_access(src_buf, reduce->src_layout.size(), obj);
88 } else if (auto *reorder = func.as_ptr<reorder_t>()) {
89 auto &dst_buf = reorder_t::arg_dst_buf(obj);
90 auto &src_buf = reorder_t::arg_src_buf(obj);
91 check_access(dst_buf, reorder->dst_layout.size(), obj);
92 check_access(src_buf, reorder->src_layout.size(), obj);
93 return;
94 } else if (auto *send = func.as_ptr<send_t>()) {
95 if (!send->is_prefetch() && !send->is_prefetch_2d()) {
96 auto &reg_buf = send_t::arg_reg_buf(obj);
97 int size = send->payload_size();
98 check_access(reg_buf, size, obj);
99 }
100 return;
101 } else if (func.is<builtin_t>()) {
102 // No buffers to check.
103 } else {
104 ir_error_not_expected() << "Unhandled function: " << obj;
105 }
106
107 ir_visitor_t::_visit(obj);
108 }
109
110 void _visit(const load_t &obj) override {
111 auto elem_type = obj.type.scalar();
112 int stride_bytes
113 = (obj.has_default_stride() ? elem_type.size() : obj.stride);
114 int off = to_cpp<int>(obj.off);
115 auto stride = (obj.type.elems() - 1) * stride_bytes;
116 check_access(obj.buf + off, stride + elem_type.size(), obj);
117 ir_visitor_t::_visit(obj);
118 }
119
120 void _visit(const store_t &obj) override {
121 auto elem_type = obj.value.type().scalar();
122 int stride_bytes
123 = (obj.has_default_stride() ? elem_type.size() : obj.stride);
124 int off = to_cpp<int>(obj.off);
125 auto stride = (obj.value.type().elems() - 1) * stride_bytes;
126 check_access(obj.buf + off, stride + elem_type.size(), obj);
127 ir_visitor_t::_visit(obj);
128 }
129
130private:
131 void check_access(const expr_t &buf, int size, const object_t &obj) {
132 auto &base = (is_var(buf) ? buf : buf.as<ptr_t>().base);
133 int off = (is_var(buf) ? 0 : to_cpp<int>(buf.as<ptr_t>().off));
134 auto it = buf_sizes_.find(base);
135 ir_assert(it != buf_sizes_.end())
136 << "Can't find allocation for buffer: " << buf;
137 int buf_size = it->second;
138 ir_assert(off + size <= buf_size)
139 << "Invalid access:\n " << obj << "\n Buffer " << base
140 << " has size: " << buf_size;
141 }
142
143 object_map_t<expr_t, int> buf_sizes_;
144};
145
146void verify_buffer_access(const stmt_t &s, ir_context_t &ir_ctx) {
147 trace_start();
148 buffer_access_verifier_t verifier;
149 verifier.visit(s);
150 trace_pass("verify_buffer_access", s, ir_ctx);
151}
152
153class multiply_builder_t {
154public:
155 multiply_builder_t() = default;
156
157 multiply_builder_t(const conv_config_t &cfg,
158 const bmnk_mapper_t &bmnk_mapper, const view_t &a_view,
159 const view_t &b_view, const expr_t &a_buf, const expr_t &b_buf,
160 const expr_t &c_buf)
161 : hw_(cfg.hw())
162 , simd_size_(cfg.simd())
163 , bmnk_mapper_(bmnk_mapper)
164 , a_view_(a_view)
165 , b_view_(b_view)
166 , a_buf_(a_buf)
167 , b_buf_(b_buf)
168 , c_buf_(c_buf) {
169 switch (cfg.fma_kind()) {
170 case fma_kind_t::dp4a:
171 case fma_kind_t::dpas:
172 case fma_kind_t::dpasw:
173 if (try_build_dpas()) return;
174 break;
175 case fma_kind_t::mad:
176 if (try_build_mad()) return;
177 break;
178 default: ir_error_not_expected() << "Unknown FMA kind.";
179 }
180
181 ir_error_not_expected()
182 << "Can't decompose into multiplication instructions. A view: "
183 << a_view << ". B view: " << b_view;
184 }
185
186 const stmt_t &stmt() const { return stmt_; }
187
188 const layout_t &c_layout() const { return c_layout_; }
189
190 bool do_transpose() const { return do_transpose_; }
191
192 std::string str() const {
193 std::ostringstream oss;
194 oss << "A view: " << a_view_ << std::endl;
195 oss << "B view: " << b_view_ << std::endl;
196 oss << "C layout: " << c_layout_ << std::endl;
197 oss << "Statement: " << std::endl << stmt_;
198 return oss.str();
199 }
200
201private:
202 struct loop_info_t {
203 loop_info_t() = default;
204
205 loop_info_t(const expr_t &var, bmnk_kind_t bmnk_kind, int dim)
206 : var(var), bmnk_kind(bmnk_kind), dim(dim) {}
207
208 expr_t var;
209 bmnk_kind_t bmnk_kind;
210
211 int dim;
212 int a_idx = -1;
213 int b_idx = -1;
214 int c_idx = -1;
215 int block = 1;
216 };
217
218 bool try_build_dpas() {
219 ir_assert(a_view_.can_convert_to_vlayout())
220 << "Views are not supported with dpas/dpasw.";
221 ir_assert(b_view_.can_convert_to_vlayout())
222 << "Views are not supported with dpas/dpasw.";
223
224 auto a_layout = a_view_.create_vlayout();
225 auto b_layout = b_view_.create_vlayout();
226
227 check_k_blocks_order(a_layout, b_layout);
228
229 bmnk_block_mapper_t from_bmnk_mapper(bmnk_mapper_);
230 from_bmnk_mapper.push_blocks(abc_kind_t::a, a_layout.blocks());
231 from_bmnk_mapper.push_blocks(abc_kind_t::b, b_layout.blocks());
232
233 // Convert to MNK layouts.
234 a_layout = bmnk_mapper_.map_to_bmnk(
235 abc_kind_t::a, {bmnk_kind_t::m, bmnk_kind_t::k}, a_layout);
236 b_layout = bmnk_mapper_.map_to_bmnk(
237 abc_kind_t::b, {bmnk_kind_t::k, bmnk_kind_t::n}, b_layout);
238
239 multiply_desc_t desc(a_layout, b_layout, /*force_c_upconvert=*/true);
240 if (!dpas_t::matches_types(
241 hw_, desc.a_type(), desc.b_type(), desc.c_type()))
242 return false;
243
244 int sdepth = 8;
245 int rcount = std::min(utils::rnd_up_pow2(desc.n()), 8);
246 auto _dpas = dpas_t::make(/*is_dpasw=*/false, simd_size_, sdepth,
247 rcount, desc.c_type(), desc.a_type(), desc.b_type());
248 if (_dpas.as<dpas_t>().matches(desc)) {
249 build_dpas(from_bmnk_mapper, _dpas.as<dpas_t>(), desc);
250 return true;
251 }
252
253 // Try to transpose and flip: C += A * B -> C^T = B^T * A^T.
254 rcount = std::min(desc.m(), 8);
255 desc = multiply_desc_t(
256 b_layout.transpose(), a_layout.transpose(), true);
257 _dpas = dpas_t::make(/*is_dpasw=*/false, /*exec_size=*/simd_size_,
258 sdepth, rcount, desc.c_type(), desc.a_type(), desc.b_type());
259
260 if (_dpas.as<dpas_t>().matches(desc)) {
261 do_transpose_ = true;
262 build_dpas(from_bmnk_mapper, _dpas.as<dpas_t>(), desc);
263 return true;
264 }
265 return false;
266 }
267
268 void check_k_blocks_order(const layout_t &a, const layout_t &b) const {
269 object_map_t<expr_t, int> k_vars;
270 auto k_sub_layout = [&](abc_kind_t abc_kind, const layout_t &l) {
271 layout_t k_layout = layout_t(type_t::u8(), 0,
272 std::vector<dim_t>(layout_t::max_ndims, 1));
273 for (auto &b : l.blocks()) {
274 auto bmnk_kind = bmnk_mapper_.bmnk_kind(abc_kind, b.dim_idx);
275 if (bmnk_kind != bmnk_kind_t::k) continue;
276 auto &var = bmnk_mapper_.var(abc_kind, b.dim_idx);
277 auto ret = k_vars.emplace(var, (int)k_vars.size());
278 k_layout = k_layout.add_outer_block(ret.first->second, b.block);
279 }
280 return k_layout;
281 };
282 auto a_k = k_sub_layout(abc_kind_t::a, a);
283 auto b_k = k_sub_layout(abc_kind_t::b, b);
284 ir_assert(a_k == b_k)
285 << "Order of K dimensions doesn't match in A and B. A layout: "
286 << a << ", B layout: " << b;
287 }
288
289 void build_dpas(const bmnk_block_mapper_t &from_bmnk_mapper,
290 const dpas_t &dpas, const multiply_desc_t &desc) {
291 int m_blk = dpas.exec_size;
292 int n_blk = dpas.rcount;
293 int k_blk = dpas.sdepth * 4 / dpas.src1_type.size();
294
295 c_layout_ = compute_dpas_c_layout(m_blk, n_blk, dpas.c_layout(), desc);
296
297 expr_t a_buf = a_buf_;
298 expr_t b_buf = b_buf_;
299 if (do_transpose_) std::swap(a_buf, b_buf);
300 auto dpas_tail = dpas_t::make(/*is_dpasw=*/false, dpas.exec_size,
301 dpas.sdepth, desc.n() > n_blk ? desc.n() % n_blk : n_blk,
302 dpas.dst_type, dpas.src1_type, dpas.src2_type);
303
304 for (int i_k = 0; i_k < desc.k(); i_k += k_blk) {
305 for (int i_m = 0; i_m < desc.m(); i_m += m_blk) {
306 for (int i_n = 0; i_n < desc.n(); i_n += n_blk) {
307 std::vector<int> a_args = {i_m, i_k};
308 std::vector<int> b_args = {i_k, i_n};
309 std::vector<int> c_args = {i_m, i_n};
310 auto a = a_buf[desc.a_layout()(a_args)
311 * desc.a_type().size()];
312 auto b = b_buf[desc.b_layout()(b_args)
313 * desc.b_type().size()];
314 auto c = c_buf_[c_layout_(c_args) * desc.c_type().size()];
315 auto &_dpas = (i_n + n_blk > desc.n())
316 ? dpas_tail.as<dpas_t>()
317 : dpas;
318 stmt_ = stmt_.append(_dpas(c, c, a, b));
319 }
320 }
321 }
322
323 // Transpose C layout back if needed.
324 if (do_transpose_) c_layout_ = c_layout_.transpose();
325
326 // Convert C layout back to problem notation.
327 c_layout_ = from_bmnk_mapper.map_from_bmnk(
328 abc_kind_t::c, {bmnk_kind_t::m, bmnk_kind_t::n}, c_layout_);
329 }
330
331 static layout_t compute_dpas_c_layout(int m_blk, int n_blk,
332 const layout_t &blk_layout, const multiply_desc_t &desc) {
333 auto c_layout = blk_layout;
334 auto new_blocks = c_layout.blocks();
335 if (new_blocks.size() > 1) new_blocks[1].block = desc.n();
336 c_layout = layout_t(c_layout.type(), c_layout.ndims(),
337 c_layout.offset(), new_blocks);
338 c_layout = c_layout.add_outer_block(0, desc.m() / m_blk);
339 return c_layout;
340 }
341
342 bool try_build_mad() {
343 auto loops = create_loop_nest();
344
345 if (try_build_mad_kmn_block_by_n(loops)) return true;
346 if (try_build_mad_kmn_block_by_b(loops)) return true;
347
348 return false;
349 }
350
351 std::vector<loop_info_t> create_loop_nest() const {
352 object_map_t<expr_t, loop_info_t> loops;
353 for (auto *view : {&a_view_, &b_view_}) {
354 abc_kind_t abc_kind
355 = (view == &a_view_ ? abc_kind_t::a : abc_kind_t::b);
356 for (int i = 0; i < view->nvdims(); i++) {
357 auto &var = bmnk_mapper_.var(abc_kind, i);
358 int dim = int(view->vdims()[i]);
359 if (dim == 1) continue;
360
361 if (loops.count(var) > 0) continue;
362 loops[var] = loop_info_t(var, bmnk_mapper_.bmnk_kind(var), dim);
363 }
364 }
365
366 std::vector<loop_info_t> ret;
367 for (auto &kv : loops) {
368 auto &loop = kv.second;
369 loop.a_idx = bmnk_mapper_.dim_idx(abc_kind_t::a, loop.var);
370 loop.b_idx = bmnk_mapper_.dim_idx(abc_kind_t::b, loop.var);
371 loop.c_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, loop.var);
372 ret.push_back(kv.second);
373 }
374 return ret;
375 }
376
377 // Order of loops: BKMN, block by N.
378 bool try_build_mad_kmn_block_by_n(std::vector<loop_info_t> &_loops) {
379 return try_build_mad_impl(_loops,
380 {bmnk_kind_t::b, bmnk_kind_t::k, bmnk_kind_t::m,
381 bmnk_kind_t::n},
382 bmnk_kind_t::n);
383 }
384
385 // Order of loops: BKMN, block by B.
386 bool try_build_mad_kmn_block_by_b(std::vector<loop_info_t> &_loops) {
387 return try_build_mad_impl(_loops,
388 {bmnk_kind_t::b, bmnk_kind_t::k, bmnk_kind_t::m,
389 bmnk_kind_t::n},
390 bmnk_kind_t::b);
391 }
392
393 bool try_build_mad_impl(std::vector<loop_info_t> &_loops,
394 const std::vector<bmnk_kind_t> &loop_order,
395 bmnk_kind_t block_bmnk_kind) {
396 auto loops = _loops;
397 int nloops = int(loops.size());
398 std::sort(loops.begin(), loops.end(),
399 [&](const loop_info_t &a, const loop_info_t &b) {
400 int a_key = ir_utils::find_index(loop_order, a.bmnk_kind);
401 int b_key = ir_utils::find_index(loop_order, b.bmnk_kind);
402 ir_assert(a_key != -1);
403 ir_assert(b_key != -1);
404 return a_key < b_key;
405 });
406
407 int block_idx = -1;
408 for (int i = 0; i < nloops; i++) {
409 auto &l = loops[i];
410 if (l.bmnk_kind == block_bmnk_kind) {
411 ir_assert(block_idx == -1) << "Can't block 2+ dimensions.";
412 block_idx = i;
413 }
414 }
415
416 // Couldn't find N dimension, try different blocking scheme.
417 if (block_idx == -1) return false;
418
419 auto &block_loop = loops[block_idx];
420
421 int block = simd_size_;
422 while (block >= 1) {
423 if (block_loop.dim % block == 0) break;
424 block /= 2;
425 }
426
427 ir_assert(block >= 1) << "Invalid block size.";
428 block_loop.block = block;
429
430 int a_stride = 0;
431 int b_stride = 0;
432
433 // Ensure that A tile is dense.
434 if (block_loop.a_idx != -1) {
435 std::vector<dim_t> tile_dims(a_view_.nvdims(), 1);
436 tile_dims[block_loop.a_idx] = block;
437 auto layout = a_view_.create_pseudo_vlayout();
438 auto tile = layout.map(tensor_t(tile_dims));
439 if (!is_1d_strided(tile)) return false;
440 a_stride = tile.blocks()[0].stride;
441 }
442
443 // Ensure that B tile is dense.
444 if (block_loop.b_idx != -1) {
445 std::vector<dim_t> tile_dims(b_view_.nvdims(), 1);
446 tile_dims[block_loop.b_idx] = block;
447 auto layout = b_view_.create_pseudo_vlayout();
448 auto tile = layout.map(tensor_t(tile_dims));
449 if (!is_1d_strided(tile)) return false;
450 b_stride = tile.blocks()[0].stride;
451 }
452
453 build_mad(loops, block_loop, a_stride, b_stride);
454 return true;
455 }
456
457 static bool is_1d_strided(const layout_t &layout) {
458 auto &blocks = layout.blocks();
459 if (blocks.size() > 1) return false;
460 return true;
461 }
462
463 void build_mad(const std::vector<loop_info_t> &loops,
464 const loop_info_t &block_loop, int a_stride, int b_stride) {
465 ir_assert(utils::one_of(
466 block_loop.bmnk_kind, bmnk_kind_t::b, bmnk_kind_t::n))
467 << "Unsupported blocking (expected blocking by B or N).";
468
469 auto &a_type = a_view_.type();
470 auto &b_type = b_view_.type();
471 auto c_type = multiply_desc_t::get_c_type(a_type, b_type,
472 /*force_c_upconvert=*/false);
473
474 int block = block_loop.block;
475 auto _mad = mad_t::make(
476 hw_, c_type, block, a_type, a_stride, b_type, b_stride);
477 auto &mad = _mad.as<mad_t>();
478
479 c_layout_ = compute_mad_c_layout(c_type, loops, block_loop);
480
481 int nloops = int(loops.size());
482 std::vector<int> bounds(loops.size());
483 for (int i = 0; i < nloops; i++) {
484 bounds[i] = loops[i].dim / loops[i].block;
485 }
486 std::vector<int> a_idx(a_view_.nvdims());
487 std::vector<int> b_idx(b_view_.nvdims());
488 std::vector<int> c_idx(c_layout_.ndims());
489 ir_utils::for_each(bounds, [&](const std::vector<int> &idx) {
490 for (int i = 0; i < nloops; i++) {
491 int full_idx = idx[i] * loops[i].block;
492 auto &loop = loops[i];
493 if (loop.a_idx != -1) a_idx[loop.a_idx] = full_idx;
494 if (loop.b_idx != -1) b_idx[loop.b_idx] = full_idx;
495 if (loop.c_idx != -1) c_idx[loop.c_idx] = full_idx;
496 }
497 int a_off = a_view_(a_idx) * a_type.size();
498 int b_off = b_view_(b_idx) * b_type.size();
499 int c_off = c_layout_(c_idx) * c_type.size();
500 stmt_ = stmt_.append(mad(c_buf_[c_off], c_buf_[c_off],
501 a_buf_[a_off], b_buf_[b_off]));
502 });
503 }
504
505 layout_t compute_mad_c_layout(const type_t &c_type,
506 const std::vector<loop_info_t> &loops,
507 const loop_info_t &block_loop) const {
508 layout_t c_layout(c_type, bmnk_mapper_.ndims(abc_kind_t::c), 0, {});
509
510 int c_dim_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, block_loop.var);
511 c_layout = c_layout.add_outer_block(c_dim_idx, block_loop.block);
512
513 for (size_t i = 0; i < loops.size(); i++) {
514 if (loops[i].bmnk_kind == bmnk_kind_t::k) continue;
515 int dim_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, loops[i].var);
516 int bound = loops[i].dim / loops[i].block;
517 c_layout = c_layout.add_outer_block(dim_idx, bound);
518 }
519 return c_layout;
520 }
521
522 ngen::HW hw_;
523 int simd_size_;
524 bmnk_mapper_t bmnk_mapper_;
525
526 bool do_transpose_ = false;
527
528 view_t a_view_;
529 view_t b_view_;
530 layout_t c_layout_;
531
532 expr_t a_buf_;
533 expr_t b_buf_;
534 expr_t c_buf_;
535
536 stmt_t stmt_;
537};
538
539class fma_helper_t {
540public:
541 fma_helper_t(int simd_size, fma_kind_t fma_kind, const type_t &a_type,
542 const type_t &b_type, bool allow_a_grf_reorder,
543 bool allow_b_grf_reorder, bool is_src1_broadcast)
544 : simd_size_(simd_size)
545 , fma_kind_(fma_kind)
546 , a_type_(a_type)
547 , b_type_(b_type)
548 , allow_a_grf_reorder_(allow_a_grf_reorder)
549 , allow_b_grf_reorder_(allow_b_grf_reorder)
550 , is_src1_broadcast_(is_src1_broadcast) {}
551
552 fma_kind_t fma_kind() const { return fma_kind_; }
553
554 layout_t convert_to_fma_friendly_layout(const layout_t &layout,
555 abc_kind_t abc_kind, bool is_slm, const bmnk_mapper_t &bmnk_mapper,
556 bool *changed = nullptr) const {
557 bool allow_grf_reorder
558 = (abc_kind == abc_kind_t::a ? allow_a_grf_reorder_
559 : allow_b_grf_reorder_);
560 if (changed) *changed = false;
561 if (!allow_grf_reorder) return layout;
562
563 // GRF reorder is only supported with dpas/dpasw.
564 if (fma_kind_ == fma_kind_t::mad) {
565 if (is_slm) return layout;
566 // mad may require type conversion, supported for GRF layouts only.
567 return convert_to_fma_friendly_type(layout, abc_kind, changed);
568 }
569
570 std::vector<bmnk_kind_t> bmnk_kinds;
571 if (abc_kind == abc_kind_t::a) {
572 bmnk_kinds.push_back(bmnk_kind_t::m);
573 bmnk_kinds.push_back(bmnk_kind_t::k);
574 } else {
575 bmnk_kinds.push_back(bmnk_kind_t::k);
576 bmnk_kinds.push_back(bmnk_kind_t::n);
577 }
578
579 auto bmnk_layout
580 = bmnk_mapper.map_to_bmnk(abc_kind, bmnk_kinds, layout);
581
582 auto dpas_layout = get_dpas_friendly_layout(bmnk_layout, abc_kind);
583 if (dpas_layout == bmnk_layout) return layout;
584
585 if (changed) *changed = true;
586
587 bmnk_block_mapper_t from_bmnk_mapper(bmnk_mapper);
588 from_bmnk_mapper.push_blocks(abc_kind, layout.blocks());
589
590 auto fma_layout = from_bmnk_mapper.map_from_bmnk(
591 abc_kind, bmnk_kinds, dpas_layout);
592 fma_layout = fma_layout.make_dense();
593 return fma_layout;
594 }
595
596private:
597 layout_t convert_to_fma_friendly_type(const layout_t &layout,
598 abc_kind_t abc_kind, bool *changed = nullptr) const {
599 if (changed) *changed = false;
600 if (fma_kind_ != fma_kind_t::mad) return layout;
601
602 // mad with s8/u8 is not supported, promote to strided s16.
603 if (a_type_.is_x8() && b_type_.is_x8()) {
604 if (changed) *changed = true;
605 return layout.retype(type_t::s16()).make_strided(2);
606 }
607
608 // bf16 mixed mode mad requires src2 to be f32.
609 if (abc_kind == abc_kind_t::b && a_type_.is_bf16()) {
610 if (changed) *changed = true;
611 return layout.retype(type_t::f32()).make_dense();
612 }
613
614 // bf16 mixed mode mad requires src1 to be packed, when src1 is
615 // broadcasted it needs to be converted to f32.
616 if (abc_kind == abc_kind_t::a && a_type_.is_bf16()
617 && is_src1_broadcast_) {
618 if (changed) *changed = true;
619 return layout.retype(type_t::f32()).make_dense();
620 }
621
622 // Ensure the layout is dense to align regioning.
623 if (!layout.is_dense()) {
624 if (changed) *changed = true;
625 return layout.make_dense();
626 }
627
628 return layout;
629 }
630
631 layout_t get_dpas_friendly_layout(
632 const layout_t &bmnk_layout, abc_kind_t abc_kind) const {
633 bool is_a = (abc_kind == abc_kind_t::a);
634 int mn_idx = (is_a ? 0 : 1);
635 int k_idx = (is_a ? 1 : 0);
636
637 layout_t dpas_layout;
638 dim_t mn_blk = bmnk_layout.dim(mn_idx);
639 dim_t k_blk = bmnk_layout.dim(k_idx);
640
641 std::vector<int> try_rcount;
642 if (is_a && mn_blk % 8 == 0) try_rcount.push_back(8);
643 // Cannot calculate correct r_count when !is_a, but rcount is
644 // effectively ignored in that case as rcount mainly affects b_layout.
645 // Also note that rcount used here may not be supported in hardware and
646 // is used solely to compute layout.
647 try_rcount.push_back(is_a ? mn_blk : 8);
648 for (int rcount : try_rcount) {
649 auto _dpas = dpas_t::make(/*is_dpasw=*/false, simd_size_,
650 /*sdepth=*/8, rcount, type_t::undef(), b_type_, a_type_);
651 auto &dpas = _dpas.as<dpas_t>();
652
653 dpas_layout = (is_a ? dpas.b_layout() : dpas.a_layout());
654 dpas_layout = dpas_layout.transpose();
655
656 auto default_layout = bmnk_layout.retype(is_a ? a_type_ : b_type_);
657 if (dpas_layout <= default_layout) return default_layout;
658 }
659
660 dim_t dpas_mn_blk = dpas_layout.dim(mn_idx);
661 dim_t dpas_k_blk = dpas_layout.dim(k_idx);
662 ir_assert(k_blk % dpas_k_blk == 0);
663
664 dim_t k_outer = ir_utils::safe_divide(k_blk, dpas_k_blk);
665 dim_t mn_outer = ir_utils::safe_divide(mn_blk, dpas_mn_blk);
666 dpas_layout = dpas_layout.add_outer_block(k_idx, k_outer);
667 dpas_layout = dpas_layout.add_outer_block(mn_idx, mn_outer);
668 return dpas_layout;
669 }
670
671 int simd_size_;
672 fma_kind_t fma_kind_;
673 type_t a_type_;
674 type_t b_type_;
675 bool allow_a_grf_reorder_;
676 bool allow_b_grf_reorder_;
677 bool is_src1_broadcast_;
678};
679
680class b_reduce_context_t {
681public:
682 b_reduce_context_t(ir_context_t &ir_ctx, const conv_config_t &cfg)
683 : ir_ctx_(ir_ctx), cfg_(cfg), reduce_condition_(true) {
684 if (cfg_.reduce_b()) b_reduced_reg_buf_ = make_buffer("b_reduced");
685 }
686
687 // Setters for B reduced memory buffer/view.
688 void set_b_reduced_mem_buf(const expr_t &buf) { b_reduced_mem_buf_ = buf; }
689 void set_b_reduced_view(const view_t &v) { b_reduced_view_ = v; }
690
691 // Sets the condition to update B reduced output. Reduction is done across
692 // K for B (KxN tensor) so M dimension should be checked before the update.
693 void set_reduce_condition(const expr_t &cond) { reduce_condition_ = cond; }
694
695 // Global memory buffer.
696 const expr_t &b_reduced_mem_buf() const { return b_reduced_mem_buf_; }
697
698 // Register buffer.
699 const expr_t &b_reduced_reg_buf() const { return b_reduced_reg_buf_; }
700 int b_reduced_size() const { return b_reduced_size_; }
701
702 // Memory view.
703 const view_t &b_reduced_thr_view() const { return b_reduced_thr_view_; }
704
705 // Register layout.
706 const layout_t &b_reduced_reg_layout() const {
707 return b_reduced_reg_layout_;
708 }
709
710 void init_reduced_thr_view(
711 const tensor_t &b_thr_tile, const expr_t &cond = expr_t()) {
712 ir_assert(b_reduced_thr_view_.is_empty()) << "Can't initialize twice.";
713
714 auto b_reduced_thr_tile = b_to_b_reduced_tile(b_thr_tile);
715 b_reduced_thr_view_
716 = b_reduced_view_.create_sub_view(b_reduced_thr_tile);
717 b_reduced_reg_layout_ = b_reduced_thr_view_.create_dense_vlayout();
718 b_reduced_size_ = b_reduced_reg_layout_.size();
719 b_reduced_size_ = utils::rnd_up(b_reduced_size_, cfg_.grf_size());
720
721 if (!cond.is_empty()) reduce_condition_ &= cond;
722 }
723
724 stmt_t create_reduce_stmt(const layout_t &b_layout, const expr_t &b_buf,
725 const tensor_t &subtile = tensor_t()) {
726 auto reduction_stmt
727 = jit::create_reduce_stmt(b_layout, b_reduced_reg_layout_,
728 b_buf, b_reduced_reg_buf_, subtile, reduction_mask_);
729 return reduction_stmt;
730 }
731
732 stmt_t create_store_stmt(bool use_atomic) const {
733 auto r2g = make_access_builder(ir_ctx_, b_reduced_thr_view_,
734 b_reduced_mem_buf_, b_reduced_reg_buf_,
735 use_atomic ? send_op_t::atomic_fadd : send_op_t::store,
736 send_address_t::a64);
737 // TODO: Check that layouts match.
738 auto ret = r2g.stmt();
739 if (!reduce_condition_.is_empty()) {
740 ret = if_t::make(reduce_condition_, ret);
741 }
742 return ret;
743 }
744
745private:
746 tensor_t b_to_b_reduced_tile(const tensor_t &b_tile) const {
747 std::vector<dim_t> dims;
748 std::vector<expr_t> start;
749 for (int i = 0; i < b_tile.ndims(); i++) {
750 if ((reduction_mask_ & (1 << i)) != 0) {
751 dims.push_back(b_tile(i));
752 start.push_back(b_tile.start(i));
753 }
754 }
755 return tensor_t(dims, start);
756 }
757
758 ir_context_t &ir_ctx_;
759 const conv_config_t &cfg_;
760
761 expr_t reduce_condition_;
762
763 expr_t b_reduced_mem_buf_;
764 expr_t b_reduced_reg_buf_;
765
766 view_t b_reduced_view_;
767 view_t b_reduced_thr_view_;
768
769 layout_t b_reduced_reg_layout_;
770 int b_reduced_size_ = 0;
771
772 uint32_t reduction_mask_ = (1 << 1) | (1 << 2);
773};
774
775class subtile_info_t {
776public:
777 using post_load_func_t = std::function<stmt_t(
778 const layout_t &, const expr_t &, const tensor_t &)>;
779
780 subtile_info_t(ir_context_t &ir_ctx, const gemm_schedule_t &gemm_schedule,
781 const fma_helper_t &fma_helper, abc_kind_t abc_kind, bool use_slm,
782 bool load_buffered, bool allow_2d_load, int idx,
783 const view_t &mem_view, const tensor_t &subtile,
784 const expr_t &mem_buf, const expr_t &slm_buf, const expr_t &reg_buf,
785 const expr_t &tmp_buf)
786 : ir_ctx_(ir_ctx)
787 , gemm_schedule_(gemm_schedule)
788 , fma_helper_(fma_helper)
789 , abc_kind_(abc_kind)
790 , use_slm_(use_slm)
791 , load_buffered_(load_buffered)
792 , allow_2d_load_(allow_2d_load)
793 , idx_(idx)
794 , mem_view_(mem_view)
795 , subtile_(subtile)
796 , mem_buf_(mem_buf)
797 , slm_buf_(slm_buf)
798 , reg_buf_(reg_buf)
799 , tmp_buf_(tmp_buf) {}
800
801 bool is_loaded() const { return is_loaded_; }
802
803 void set_loaded() { is_loaded_ = true; }
804
805 const view_t &reg_view() const { return reg_view_; }
806
807 int reg_buf_size() const {
808 return utils::rnd_up(reg_layout_.size(), ir_ctx_.grf_size());
809 }
810
811 int tmp_buf_size() const { return tmp_buf_size_; }
812
813 const stmt_t &s2r_load() const { return s2r_load_; }
814
815 const stmt_t &g2r_load() const { return g2r_load_; }
816
817 const send_hint_t &send_hint() const { return send_hint_; }
818
819 void load(const post_load_func_t &post_load = post_load_func_t()) {
820 auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
821
822 layout_t load_layout;
823 stmt_t &stmt = (use_slm_ ? s2r_load_ : g2r_load_);
824 load_impl(ir_ctx_, load_layout, reg_view_, send_hint_, stmt);
825
826 if (post_load) {
827 stmt = stmt.append(post_load(load_layout, reg_buf_, subtile_));
828 }
829
830 reg_layout_ = load_layout;
831
832 bool changed;
833 auto fma_layout = fma_helper_.convert_to_fma_friendly_layout(
834 reg_layout_, abc_kind_,
835 /*is_slm=*/false, bmnk_mapper, &changed);
836
837 if (changed) {
838 bool is_reorder_nop
839 = fma_layout.retype(reg_layout_.type()) == reg_layout_
840 && reg_layout_.type().is_bitwise_compatible(
841 fma_layout.type());
842
843 if (fma_layout.type() != reg_layout_.type()) {
844 reg_view_ = reg_view_.retype(fma_layout.type());
845 }
846 reg_layout_ = fma_layout;
847 reg_view_.set_tlayout(reg_layout_);
848 if (!is_reorder_nop) {
849 stmt = substitute(stmt, reg_buf_, tmp_buf_);
850 stmt = stmt.append(create_reorder_stmt(
851 load_layout, reg_layout_, tmp_buf_, reg_buf_));
852 int load_reg_size = int(load_layout.size());
853 load_reg_size
854 = utils::rnd_up(load_reg_size, ir_ctx_.grf_size());
855 tmp_buf_size_ = std::max(tmp_buf_size_, load_reg_size);
856 }
857 }
858 }
859
860private:
861 void load_impl(ir_context_t &ir_ctx, layout_t &load_layout,
862 view_t &load_view, send_hint_t &send_hint, stmt_t &stmt) const {
863 view_t mem_view = mem_view_;
864 if (load_buffered_)
865 mem_view_.try_create_buffer_view(mem_view, load_view);
866
867 send_op_t send_op = send_op_t::load;
868 send_hint = get_send_hint(ir_ctx_.exec_cfg(), send_op_t::load,
869 fma_helper_.fma_kind(), abc_kind_, mem_view, gemm_schedule_,
870 allow_2d_load_);
871 auto read = make_access_builder(ir_ctx, mem_view,
872 use_slm_ ? slm_buf_ : mem_buf_, reg_buf_, send_op,
873 use_slm_ ? send_address_t::slm : send_address_t::a64,
874 send_hint);
875 ir_trace() << (abc_kind_ == abc_kind_t::a ? "A" : "B")
876 << " GMEM/SLM to GRF load #" << idx_ << ":\n"
877 << read.str() << std::endl;
878
879 load_layout = read.reg_layout();
880 if (!load_view.is_empty()) {
881 load_view.set_tlayout(load_layout);
882 } else {
883 load_view = view_t(load_layout);
884 }
885 stmt = read.stmt();
886 }
887
888 ir_context_t &ir_ctx_;
889 const gemm_schedule_t &gemm_schedule_;
890 const fma_helper_t &fma_helper_;
891 abc_kind_t abc_kind_;
892 bool use_slm_;
893 bool load_buffered_;
894 bool allow_2d_load_;
895 int idx_;
896 view_t mem_view_;
897 tensor_t subtile_;
898
899 expr_t mem_buf_;
900 expr_t slm_buf_;
901 expr_t reg_buf_;
902 expr_t tmp_buf_;
903
904 bool is_loaded_ = false;
905 view_t reg_view_;
906 layout_t reg_layout_;
907 int tmp_buf_size_ = 0;
908 stmt_t s2r_load_;
909 stmt_t g2r_load_;
910 send_hint_t send_hint_;
911};
912
913class load_multiply_builder_t {
914public:
915 load_multiply_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
916 const gemm_schedule_t &gemm_schedule,
917 const fma_helper_t &fma_helper, b_reduce_context_t &b_reduce_ctx,
918 const expr_t &ap_buf, const expr_t &a_slm_buf, const expr_t &bp_buf,
919 const expr_t &b_slm_buf, const view_t &ap_x_view,
920 const view_t &bp_x_view, const kernel_info_t &kernel_info)
921 : prb_(cfg.prb())
922 , cfg_(cfg)
923 , ir_ctx_(ir_ctx)
924 , gemm_schedule_(gemm_schedule)
925 , fma_helper_(fma_helper)
926 , b_reduce_ctx_(b_reduce_ctx)
927 , ap_buf_(ap_buf)
928 , a_slm_buf_(a_slm_buf)
929 , bp_buf_(bp_buf)
930 , b_slm_buf_(b_slm_buf)
931 , kernel_info_(kernel_info) {
932 ir_assert(cfg_.subtiles().a() == 1 || cfg_.subtiles().b() == 1)
933 << "At most one tensor can be tiled.";
934
935 ab_tmp_buf_ = make_buffer("ab_tmp");
936 a_buf_ = make_buffer("a");
937 b_buf_ = make_buffer("b");
938 c_buf_ = make_buffer("c");
939
940 // Views to multiply by a thread.
941 a_thr_view_ = ap_x_view.create_sub_view(gemm_schedule_.a_thr_tile());
942 b_thr_view_ = bp_x_view.create_sub_view(gemm_schedule_.b_thr_tile());
943
944 // Initialize view for reduced B.
945 if (cfg_.reduce_b() && !cfg_.slm().b()) {
946 b_reduce_ctx_.init_reduced_thr_view(
947 gemm_schedule_.b_thr_tile(/*is_relative=*/false));
948 }
949
950 // TODO: Specify loops over subtiles in the schedule, use unrolling.
951 // Sub-tile indices.
952 a_idx_ = ir_ctx_.create_tmp_var(type_t::s32(), "a_idx");
953 b_idx_ = ir_ctx_.create_tmp_var(type_t::s32(), "b_idx");
954
955 // Sub-tile views.
956 a_i_view_ = create_subtile_view(abc_kind_t::a, a_thr_view_,
957 cfg_.subtiles().a(), a_idx_, bmnk_kind_t::m, &a_i_outer_blocks_,
958 a_i_tile_);
959 b_j_view_ = create_subtile_view(abc_kind_t::b, b_thr_view_,
960 cfg_.subtiles().b(), b_idx_, bmnk_kind_t::n, &b_j_outer_blocks_,
961 b_j_tile_);
962
963 build();
964 }
965
966 const std::vector<stmt_t> &allocs() const { return allocs_; }
967
968 const stmt_t &load_mul_stmt() const { return load_mul_stmt_; }
969
970 const expr_t &c_buf() const { return c_buf_; }
971
972 const layout_t &c_reg_layout() const { return c_reg_layout_; }
973
974private:
975 view_t create_subtile_view(abc_kind_t abc_kind, const view_t &thr_view,
976 int subtiles, const expr_t &idx, bmnk_kind_t bmnk_kind,
977 std::vector<block_t> *outer_blocks, tensor_t &subtile) const {
978 auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
979 auto layout = thr_view.create_pseudo_vlayout();
980 dim_t mn_dim = 1;
981 for (auto &b : layout.blocks()) {
982 auto b_bmnk_kind = bmnk_mapper.bmnk_kind(abc_kind, b.dim_idx);
983 if (b_bmnk_kind == bmnk_kind) mn_dim *= b.block;
984 }
985
986 std::vector<dim_t> subtile_dims(thr_view.nvdims(), 1);
987 dim_t mn_subtile_dim = ir_utils::safe_divide(mn_dim, dim_t(subtiles));
988 for (auto &b : layout.blocks()) {
989 auto b_bmnk_kind = bmnk_mapper.bmnk_kind(abc_kind, b.dim_idx);
990 if (b_bmnk_kind == bmnk_kind) {
991 if (mn_subtile_dim == 1) continue;
992 dim_t next_block;
993 if (mn_subtile_dim % b.block == 0) {
994 next_block = b.block;
995 } else {
996 ir_assert(b.block % mn_subtile_dim == 0);
997 next_block = mn_subtile_dim;
998 }
999 subtile_dims[b.dim_idx] *= next_block;
1000 mn_subtile_dim /= next_block;
1001 } else {
1002 subtile_dims[b.dim_idx] *= b.block;
1003 }
1004 }
1005 grid_info_t grid({subtiles}, {idx});
1006 subtile = layout.split(tensor_t(subtile_dims), grid, outer_blocks);
1007 return thr_view.create_sub_view(subtile);
1008 }
1009
1010 void build() {
1011 int max_iters = 2;
1012 bool load_ok = false;
1013 for (int iter = 0; iter < max_iters; iter++) {
1014 if (try_load_subtiles(/*allow_2d_load=*/iter == 0)) {
1015 load_ok = true;
1016 break;
1017 }
1018 }
1019 ir_assert(load_ok) << "Can't generate load statements for subtiles.";
1020
1021 auto a_subtiles = cfg_.subtiles().a();
1022 auto b_subtiles = cfg_.subtiles().b();
1023
1024 for (int i = 0; i < a_subtiles; i++) {
1025 for (int j = 0; j < b_subtiles; j++) {
1026 build_subtile(i, j);
1027 }
1028 }
1029
1030 if (zp_buf_size_ > 0)
1031 register_buffer(zp_buf_, zp_buf_size_, alloc_kind_t::grf);
1032 if (zp_mask_size_ > 0)
1033 register_buffer(zp_mask_, zp_mask_size_, alloc_kind_t::grf);
1034
1035 // Handle temporary buffer in case of GRF reorders.
1036 int tmp_buf_size = 0;
1037 for (int i = 0; i < a_subtiles; i++)
1038 tmp_buf_size
1039 = std::max(tmp_buf_size, a_subtiles_[i].tmp_buf_size());
1040 for (int j = 0; j < b_subtiles; j++)
1041 tmp_buf_size
1042 = std::max(tmp_buf_size, b_subtiles_[j].tmp_buf_size());
1043 if (tmp_buf_size > 0)
1044 register_buffer(ab_tmp_buf_, tmp_buf_size, alloc_kind_t::grf);
1045
1046 // C layout in problem notation.
1047 auto c_layout = c_subtile_layout_;
1048
1049 // Add outer blocks coming from A/B subtiles.
1050 auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
1051 for (auto &b : a_i_outer_blocks_) {
1052 auto &var = bmnk_mapper.var(abc_kind_t::a, b.dim_idx);
1053 int c_dim_idx = bmnk_mapper.dim_idx(abc_kind_t::c, var);
1054 c_layout = c_layout.add_outer_block(c_dim_idx, b.block);
1055 }
1056 for (auto &b : b_j_outer_blocks_) {
1057 auto &var = bmnk_mapper.var(abc_kind_t::b, b.dim_idx);
1058 int c_dim_idx = bmnk_mapper.dim_idx(abc_kind_t::c, var);
1059 c_layout = c_layout.add_outer_block(c_dim_idx, b.block);
1060 }
1061
1062 c_reg_layout_ = c_layout;
1063 }
1064
1065 bool can_use_2d_load(const abc_kind_t &abc_kind, const view_t &view) const {
1066 bool is_blocked = view.tlayout().innermost_block_layout().elems() > 1;
1067 if (!is_blocked) return true;
1068
1069 // In general we want to skip expensive logic to check requirements for
1070 // 2D block messages with block layouts as performance with 1D messages
1071 // is good enough. However there are a few cases (backward by weights
1072 // with dpas) when 2D block messages give boost even for block layouts
1073 // due to VNNI/transpose features.
1074 if (prb_.is_bwd_w && cfg_.is_dp_fma()) {
1075 auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
1076 auto &blocks = view.tlayout().blocks();
1077 if (blocks.size() < 2) return false;
1078 int b1_dim_idx = blocks[1].dim_idx;
1079 return bmnk_mapper.bmnk_kind(abc_kind, b1_dim_idx)
1080 == bmnk_kind_t::k;
1081 }
1082 return false;
1083 }
1084
1085 bool try_load_subtiles(bool allow_2d_load) {
1086 int a_subtiles = cfg_.subtiles().a();
1087 int b_subtiles = cfg_.subtiles().b();
1088 bool use_a_slm = cfg_.slm().a();
1089 bool use_b_slm = cfg_.slm().b();
1090 a_subtiles_.clear();
1091 b_subtiles_.clear();
1092 for (int i = 0; i < a_subtiles; i++) {
1093 auto view = a_i_view_.substitute(a_idx_, i);
1094 auto tile = a_i_tile_.substitute(a_idx_, i);
1095 // Using buffered view is enabled only when:
1096 // - Loading directly from global memory
1097 // - FMA kind is mad (dpas implementation is more strict and requires
1098 // layouts, not views)
1099 // - Loading A tensor (A - activations for FWD/BWD_D where we may have
1100 // overlapping when applying KW blocking )
1101 bool load_buffered = cfg_.ow_kw_grf_cache() && !use_a_slm
1102 && cfg_.fma_kind() == fma_kind_t::mad;
1103 a_subtiles_.emplace_back(ir_ctx_, gemm_schedule_, fma_helper_,
1104 abc_kind_t::a, use_a_slm, load_buffered,
1105 allow_2d_load && can_use_2d_load(abc_kind_t::a, a_i_view_),
1106 i, view, tile, ap_buf_, a_slm_buf_, a_buf_, ab_tmp_buf_);
1107 a_subtiles_.back().load();
1108 }
1109 subtile_info_t::post_load_func_t b_post_load;
1110 if (!use_b_slm && cfg_.reduce_b()) {
1111 b_post_load = [&](const layout_t &reg_layout, const expr_t &reg_buf,
1112 const tensor_t &tile) {
1113 return b_reduce_ctx_.create_reduce_stmt(
1114 reg_layout, reg_buf, tile);
1115 };
1116 }
1117 for (int j = 0; j < b_subtiles; j++) {
1118 auto view = b_j_view_.substitute(b_idx_, j);
1119 auto tile = b_j_tile_.substitute(b_idx_, j);
1120 b_subtiles_.emplace_back(ir_ctx_, gemm_schedule_, fma_helper_,
1121 abc_kind_t::b, use_b_slm,
1122 /*load_buffered=*/false,
1123 allow_2d_load && can_use_2d_load(abc_kind_t::b, b_j_view_),
1124 j, view, tile, bp_buf_, b_slm_buf_, b_buf_, ab_tmp_buf_);
1125
1126 b_subtiles_.back().load(b_post_load);
1127 }
1128
1129 // Validate subtile loads, when VNNI permutation is applied, both A/B
1130 // have to use the same pattern.
1131 int vnni_permute_factor
1132 = a_subtiles_[0].send_hint().hint_2d.vnni_permute_factor;
1133 for (int i = 1; i < a_subtiles; i++) {
1134 int f = a_subtiles_[i].send_hint().hint_2d.vnni_permute_factor;
1135 if (f != vnni_permute_factor) return false;
1136 }
1137 for (int j = 0; j < b_subtiles; j++) {
1138 int f = b_subtiles_[j].send_hint().hint_2d.vnni_permute_factor;
1139 if (f != vnni_permute_factor) return false;
1140 }
1141 return true;
1142 }
1143
1144 class src_zp_mask_info_t {
1145 public:
1146 src_zp_mask_info_t() = delete;
1147 src_zp_mask_info_t(load_multiply_builder_t &lmb, int m_blk, int k_blk,
1148 int desc_m, int desc_n, int channels, int a_stride, bool is_mad,
1149 const view_t &a_view, expr_t &zp_mask, int &zp_mask_size)
1150 : lmb_(lmb)
1151 , is_const_(true)
1152 , is_simd_(true)
1153 , is_scalar_(false)
1154 , is_wide_(m_blk < 16)
1155 , zp_mask_(zp_mask) {
1156 const auto tile
1157 = lmb_.gemm_schedule_.a_thr_tile(/*is_relative=*/false);
1158 const auto a_thr_view
1159 = lmb_.gemm_schedule_.a_view().create_sub_view(tile);
1160 const auto ic_dim = (!is_mad) ? 2 : 1;
1161 ic_start_ = a_thr_view.vstart(ic_dim);
1162
1163 // 0. Are the masks at all required?
1164 const auto &prb = lmb_.prb_;
1165 const auto dims = tile.dims()[3] * tile.dims()[4] * tile.dims()[5];
1166 const auto is_scalar = !is_mad && (dims <= 1);
1167
1168 const auto has_pad = (prb.pd + 1) * (prb.ph + 1) * (prb.pw + 1) > 1;
1169 const auto has_stride_bd
1170 = prb.is_bwd_d && (prb.sd * prb.sh * prb.sw > 1);
1171
1172 // 1. Get the raw representation of the buffer`s masks
1173 auto mask_tensor
1174 = a_thr_view.create_mask_tensor(lmb_.ir_ctx_.cset());
1175
1176 // 2. Collect the masks, transforming the dimensions as needed
1177 int channels_blk = std::min(channels,
1178 (int)a_thr_view.tlayout().normalize().blocks()[0].block);
1179 if (channels_blk > 32) channels_blk = 32;
1180 const auto c_blk = std::min(channels_blk, m_blk);
1181 auto a_tdims = a_view.tlayout().dims();
1182 auto mask_blk = is_mad ? c_blk : channels_blk;
1183 size_ = ((prb.kd * prb.kh * prb.kw > 1) || has_pad || has_stride_bd)
1184 * ((!is_scalar) ? accumulate(a_tdims.begin(), a_tdims.end(),
1185 1, std::multiplies<dim_t>())
1186 / mask_blk
1187 : 1);
1188 if (size_ == 0) return;
1189
1190 mask_tensor_t masks(
1191 layout_t(type_t::_bool(), 0, std::vector<dim_t> {size_}));
1192 std::vector<dim_t> a_dims(a_view.tlayout().ndims(), 1);
1193 a_dims[ic_dim] = mask_blk;
1194 int i = 0;
1195 a_view.tlayout().for_each_tile(
1196 tensor_t(a_dims), [&](const std::vector<dim_t> &start) {
1197 std::vector<dim_t> a_st(a_thr_view.nvdims(), 0);
1198 for (int idx = 0; idx < (int)start.size(); idx++) {
1199 auto tdim_to_vdims = [&](int idx) {
1200 const auto &tdim = a_view.tdim(idx);
1201 auto vidx0 = tdim.vidx(0);
1202 auto vidx1 = -1;
1203 int vdim1 = 0;
1204 ir_assert(tdim.nvargs() <= 2);
1205 for (int vdim0 = 0;
1206 vdim0 < a_thr_view.vdims()[vidx0];
1207 vdim0++) {
1208 auto tdim_expr = substitute(tdim.expr(),
1209 a_view.vvars()[vidx0],
1210 to_expr(vdim0));
1211 if (tdim.nvargs() == 2) {
1212 vidx1 = tdim.vidx(1);
1213 for (vdim1 = 0; vdim1
1214 < a_thr_view.vdims()[vidx1];
1215 vdim1++) {
1216 auto tdim_expr2 = substitute(
1217 tdim_expr,
1218 a_view.vvars()[vidx1],
1219 to_expr(vdim1));
1220 if (to_cpp<dim_t>(
1221 simplify(tdim_expr2))
1222 == start[idx]) {
1223 a_st[vidx1] = vdim1;
1224 a_st[vidx0] = vdim0;
1225 return;
1226 }
1227 }
1228 tdim_expr = substitute(tdim_expr,
1229 a_view.vvars()[vidx1],
1230 to_expr(0));
1231 }
1232 if (to_cpp<dim_t>(simplify(tdim_expr))
1233 == start[idx]) {
1234 a_st[vidx0] = vdim0;
1235 break;
1236 }
1237 }
1238 };
1239 tdim_to_vdims(idx);
1240 }
1241 auto off = a_thr_view.create_pseudo_vlayout()
1242 .make_dense()
1243 .offset<dim_t>(a_st);
1244 if (i >= size_) return;
1245 masks.set_mask(i, mask_tensor.mask(off));
1246 i++;
1247 });
1248
1249 // 3. Compute some basic properties of the masks just collected
1250 for (int n = 0; n < size_; n++) {
1251 auto *sh = masks.mask(n).as_ptr<shuffle_t>();
1252 is_simd_ &= !sh || sh->is_broadcast();
1253 is_const_ &= !!sh;
1254 for (int v = (sh) ? 0 : c_blk; v < c_blk; v++)
1255 is_const_ &= sh->vec[sh->idx[v]].is<bool_imm_t>();
1256 }
1257
1258 // 4. Scalarize if the masks permit, transform to shorts otherwise
1259 for (int n = 0; n < size_; n++)
1260 if (is_simd_) {
1261 object_map_t<expr_t, std::vector<expr_t>> vars;
1262 expr_scalarizer_t sc(c_blk, 0, vars);
1263 masks.set_mask(n, sc.mutate(masks.mask(n)));
1264 } else if (is_const_) {
1265 uint16_t mask = 0;
1266 auto &sh = masks.mask(n).as<shuffle_t>();
1267 for (int v = c_blk; v; v--)
1268 mask = mask * 2
1269 + sh.vec[sh.idx[v - 1]].as<bool_imm_t>().value;
1270 masks.set_mask(n, mask);
1271 } else {
1272 ir_error_not_expected() << "Non-SIMD non-constant masks!";
1273 }
1274
1275 // 5. Assume lack of masks if they all are true
1276 bool all_true = true;
1277 for (int n = 0; all_true && (n < size_); n++)
1278 all_true &= masks.mask(n).is_equal(expr_t(true));
1279 if (all_true) {
1280 is_const_ = true;
1281 is_simd_ = true;
1282 size_ = 0;
1283 return;
1284 }
1285 is_scalar_ = is_scalar;
1286
1287 // 6. zp_mask_ gets created by the caller; allocate var_mask_
1288 var_mask_ = lmb_.ir_ctx_.create_tmp_var(
1289 (!is_bool()) ? type_t::s16() : type_t::_bool(16));
1290
1291 // 7. Vectorize everything for easier computation and emit the IR
1292 if (!is_scalar) {
1293 std::vector<expr_t> exprs;
1294 object_eq_map_t<expr_t, expr_t> vars;
1295
1296 // Here we assume two important things:
1297 // - C has exactly one N block like 4c16f8c (where f is ow)
1298 // - The innermost block is by M and it matches the SIMD size
1299
1300 std::vector<expr_t> proto(size_);
1301 if (is_wide_) {
1302 for (int n = 0; n < int(proto.size()); n++) {
1303 const auto r = (n / 2 / k_blk) * 2 * k_blk
1304 + (n / 2) % k_blk + (n & 1) * k_blk;
1305 proto[n] = masks.mask(r % size_);
1306 }
1307 } else {
1308 for (int n = 0; n < int(proto.size()); n++)
1309 proto[n] = masks.mask(n % size_);
1310 }
1311 for (; size_ >= m_blk * 2; size_ /= 2) {
1312 auto c = [](expr_t &a, expr_t &b) { return a.is_equal(b); };
1313 auto half = proto.begin() + size_ / 2;
1314 if (!std::equal(proto.begin(), half, half, c)) break;
1315 }
1316
1317 const auto blk
1318 = (size_ > m_blk) ? std::min(m_blk * 2, 16) : m_blk;
1319 for (int n = 0; n < size_; n += blk) {
1320 std::vector<expr_t> e(blk);
1321 for (int m = 0; m < blk; m++) {
1322 e[m] = proto[n + m % (size_ - n)];
1323 }
1324 int ntrue = 0, nfalse = 0;
1325 for (int m = 0; m < blk; m++) {
1326 if (e[m].is<bool_imm_t>())
1327 ((e[m].as<bool_imm_t>().value) ? ntrue : nfalse)++;
1328 }
1329 ir_assert((ntrue == 0) || (ntrue + nfalse == blk));
1330 if ((ntrue == 0) && (nfalse > 0) && (nfalse < blk)) {
1331 auto nb = *std::find_if(e.begin(), e.end(),
1332 [](expr_t &x) { return !x.is<bool_imm_t>(); });
1333 for (int m = 0; m < blk; m++) {
1334 e[m] = (e[m].is<bool_imm_t>())
1335 ? (nb & expr_t(false))
1336 : (e[m] & expr_t(true));
1337 }
1338 }
1339 exprs.emplace_back(vector2expr(e, vars));
1340 }
1341
1342 const auto real_size = std::max(
1343 size_ * ((is_wide_) ? 8 : 1), int(exprs.size()) * blk);
1344 zp_mask_size = std::max(zp_mask_size, real_size * w_stride());
1345 for (int i = 0; i < int(exprs.size()); i++) {
1346 auto expr = cast_t::make(w_type(blk), exprs[i]);
1347 stmt_ = stmt_.append(
1348 store_t::make(zp_mask_, i * blk * w_stride(),
1349 (is_simd_) ? -expr : expr, w_stride()));
1350 }
1351 if (is_wide_) {
1352 auto wide_scalar = [&](const expr_t &a, const expr_t &b) {
1353 std::vector<int> idx(16, 1);
1354 for (int i = 0; i < 8; i++)
1355 idx[i] = 0;
1356 return shuffle_t::make(std::vector<expr_t> {a, b}, idx);
1357 };
1358 for (int i = size_ - 2; i > 0; i -= 2) {
1359 auto load_l = load_t::make(
1360 type_t::s16(), zp_mask_, i * w_stride());
1361 auto load_h = load_t::make(
1362 type_t::s16(), zp_mask_, (i + 1) * w_stride());
1363 auto load = wide_scalar(load_l, load_h);
1364 load = cast_t::make(type_t::s16(16), load);
1365 stmt_ = stmt_.append(store_t::make(zp_mask_,
1366 i * 8 * w_stride(), load, w_stride()));
1367 }
1368 if (size_ % 2 == 0) {
1369 auto l0h = load_t::make(
1370 type_t::s16(), zp_mask_, w_stride());
1371 stmt_ = stmt_.append(store_t::make(zp_mask_,
1372 8 * w_stride(),
1373 shuffle_t::make_broadcast(l0h, 8), w_stride()));
1374 }
1375 auto l0l = load_t::make(type_t::s16(), zp_mask_, 0);
1376 stmt_ = stmt_.append(store_t::make(zp_mask_, 0,
1377 shuffle_t::make_broadcast(l0l, 8), w_stride()));
1378 }
1379
1380 for (auto &v : vars)
1381 stmt_ = let_t::make(v.second, v.first, stmt_);
1382 } else { // is_scalar == true
1383 zp_mask_size = std::max(zp_mask_size, type_t::s16().size());
1384 auto expr = cast_t::make(type_t::s16(), masks.mask(0));
1385 if (is_simd_) expr = cast(-expr, type_t::s16());
1386 stmt_ = stmt_.append(store_t::make(zp_mask_, 0, expr));
1387 }
1388 }
1389
1390 const stmt_t &stmt() const { return stmt_; }
1391 expr_t ic_start() const { return ic_start_; }
1392 bool is_scalar() const { return is_scalar_; }
1393 bool is_simd() const { return is_simd_; }
1394 bool is_const() const { return is_const_; }
1395 bool is_bool() const { return !size_ || !is_simd() || is_scalar(); }
1396
1397 expr_t gen_mask(int base) const {
1398 auto null_mask = (is_bool()) ? expr_t() : expr_t(-1);
1399 if (!size_ || is_scalar_) return (size_) ? var_mask_ : null_mask;
1400 return word2bool(base % size_);
1401 }
1402
1403 expr_t maybe_gen_mask_let(const stmt_t &loop) const {
1404 return (size_ && is_scalar_)
1405 ? let_t::make(var_mask_, word2bool(0), loop)
1406 : loop;
1407 }
1408
1409 private:
1410 type_t w_type(int width = 1) const {
1411 return (is_bool()) ? type_t::u16(width) : type_t::s32(width);
1412 }
1413 int w_stride() const { return w_type().size(); }
1414
1415 expr_t word2bool(int off) const {
1416 auto type = (is_bool()) ? type_t::u16() : type_t::s16(16);
1417 expr_t load;
1418 if (is_wide_ && !is_bool()) {
1419 load = load_t::make(
1420 type, zp_mask_, off * 8 * w_stride(), w_stride());
1421 } else {
1422 load = load_t::make(
1423 type.scalar(), zp_mask_, off * w_stride(), w_stride());
1424 if (!is_bool()) load = shuffle_t::make_broadcast(load, 16);
1425 }
1426 return (is_bool()) ? cast_t::make(type_t::_bool(16), load) : load;
1427 }
1428
1429 expr_t vector2expr(const std::vector<expr_t> &expr,
1430 object_eq_map_t<expr_t, expr_t> &vars) const {
1431 constexpr size_t mask = 0x8000;
1432 auto hash = [&](const binary_op_t &b) -> size_t {
1433 return size_t(b.op_kind) | ((b.b.is<int_imm_t>()) ? mask : 0UL);
1434 };
1435 auto fetch_var = [this, &vars](expr_t e) {
1436 if (vars.find(e) == vars.end()) {
1437 auto var = lmb_.ir_ctx_.create_tmp_var(
1438 type_t::s32(e.type().elems()), "zp_mask");
1439 vars.emplace(e, var);
1440 }
1441 return vars[e];
1442 };
1443 if (expr.empty()) return expr_t();
1444 // Can only vectorize if the element count is a power of 2
1445 ir_assert(math::is_pow2(expr.size())) << "Cannot vectorize.";
1446
1447 std::unordered_map<size_t, size_t> kind;
1448 for (const expr_t &e : expr)
1449 if (const auto *bin = e.as_ptr<binary_op_t>()) {
1450 kind[hash(*bin)]++;
1451 auto bin_a = bin;
1452 while ((bin_a = bin_a->a.as_ptr<binary_op_t>())) {
1453 if (kind[hash(*bin_a)] > 0) {
1454 kind[hash(*bin)] += kind[hash(*bin_a)];
1455 break;
1456 }
1457 }
1458 }
1459 if (!kind.empty()) {
1460 using k_type = decltype(kind)::value_type;
1461 auto k = std::max_element(
1462 kind.begin(), kind.end(), [](k_type &a, k_type &b) {
1463 return a.second < b.second;
1464 });
1465 const auto k_raw = op_kind_t(k->first & (mask - 1));
1466 std::vector<expr_t> a, b;
1467 for (const expr_t &e : expr) {
1468 const auto *bin = e.as_ptr<binary_op_t>();
1469 if (bin && (hash(*bin) == k->first)) {
1470 a.emplace_back(bin->a);
1471 b.emplace_back(bin->b);
1472 } else {
1473 const int is_mul = (k_raw == op_kind_t::_mul);
1474 ir_assert(is_mul || (k_raw == op_kind_t::_add));
1475 a.emplace_back(e);
1476 b.emplace_back(is_mul);
1477 }
1478 }
1479 auto a_new = vector2expr(a, vars);
1480 auto b_new = vector2expr(b, vars);
1481 if (auto *a_bin = a_new.as_ptr<binary_op_t>())
1482 if ((a_bin->op_kind == op_kind_t::_add) && is_var(a_bin->b)
1483 && is_cmp_op(k_raw) && is_shuffle_const(b_new))
1484 for (auto &v : vars)
1485 if (v.second.is_equal(a_bin->b)) {
1486 auto fold = const_fold_non_recursive(
1487 b_new - v.first);
1488 return binary_op_t::make(negate_cmp_op(k_raw),
1489 fetch_var(fold), a_bin->a);
1490 }
1491 return binary_op_t::make(k_raw, a_new, b_new);
1492 }
1493
1494 size_t num_ints = 0;
1495 for (const expr_t &e : expr)
1496 num_ints += e.is<int_imm_t>();
1497 ir_assert((num_ints == 0) || (num_ints == expr.size()));
1498 if (num_ints == expr.size()) {
1499 auto offs = shuffle_t::make(expr);
1500 if (offs.as<shuffle_t>().is_broadcast()) return offs;
1501 return fetch_var(offs);
1502 }
1503
1504 size_t num_bools = 0;
1505 for (const expr_t &e : expr)
1506 num_bools += e.is<bool_imm_t>();
1507 ir_assert((num_bools == 0) || (num_bools == expr.size()));
1508 if (num_bools == expr.size()) return shuffle_t::make(expr);
1509
1510 ir_assert(expr.front().is<var_t>());
1511 for (const expr_t &e : expr)
1512 ir_assert(e.is_same(expr.front()));
1513 return shuffle_t::make_broadcast(expr.front(), int(expr.size()));
1514 }
1515
1516 load_multiply_builder_t &lmb_;
1517 bool is_const_;
1518 bool is_simd_;
1519 bool is_scalar_;
1520 bool is_wide_;
1521 int size_;
1522 expr_t ic_start_;
1523 expr_t var_mask_;
1524 const expr_t &zp_mask_;
1525 stmt_t stmt_;
1526 };
1527
1528 stmt_t maybe_add_src_zps(const view_t &a_view, const view_t &b_view,
1529 const multiply_builder_t &mul_builder, int i_buf, int j_buf) {
1530 if (!prb_.zp_cfg.do_src_compensation) return mul_builder.stmt();
1531 const bool is_runtime = prb_.zp_cfg.is_runtime_src_zero_points;
1532 const bool is_scalar = prb_.zp_cfg.is_common_src_zero_point;
1533 const bool is_mad = (cfg_.fma_kind() == fma_kind_t::mad);
1534 const int channels = utils::rnd_up_pow2(
1535 (!is_mad) ? (prb_.is_fwd) ? prb_.ic : prb_.oc : prb_.g);
1536 const int c_blk = (channels < 32) ? channels : 32;
1537 const int k_blk = ((channels > 4)) ? 32 / c_blk : 1;
1538 const int m_blk = cfg_.simd();
1539
1540 const type_t s_type = a_view.type();
1541 const type_t i_type = type_t::s32(); // x32 type that is always signed
1542 auto has_sign = [&]() {
1543 if (is_runtime) return s_type.is_signed();
1544 ir_assert(is_scalar);
1545 return prb_.zp_cfg.common_src_zero_point < 0;
1546 };
1547 const type_t d_type = (has_sign()) ? type_t::s32() : type_t::u32();
1548 ir_assert((is_mad) ? s_type.is_x16() : s_type.is_x8());
1549
1550 const int a_stride
1551 = s_type.size() * int(a_view.tlayout().blocks()[0].stride);
1552 int desc_m = 0, desc_n = 0;
1553
1554 if (!is_mad) {
1555 auto &mapper = gemm_schedule_.bmnk_mapper();
1556 auto a_layout = mapper.map_to_bmnk(abc_kind_t::a,
1557 {bmnk_kind_t::m, bmnk_kind_t::k}, a_view.create_vlayout());
1558 auto b_layout = mapper.map_to_bmnk(abc_kind_t::b,
1559 {bmnk_kind_t::k, bmnk_kind_t::n}, b_view.create_vlayout());
1560 if (mul_builder.do_transpose()) {
1561 a_layout = a_layout.transpose();
1562 b_layout = b_layout.transpose();
1563 std::swap(a_layout, b_layout);
1564 }
1565 multiply_desc_t desc(a_layout, b_layout, true);
1566 desc_m = desc.m();
1567 desc_n = desc.n();
1568 } else {
1569 desc_n = a_view.tlayout().size() / m_blk / a_stride;
1570 desc_m = m_blk;
1571 }
1572 if (zp_mask_.is_empty())
1573 zp_mask_ = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "zp_mask");
1574 src_zp_mask_info_t masks(*this, m_blk, k_blk, desc_m, desc_n, channels,
1575 a_stride, is_mad, a_view, zp_mask_, zp_mask_size_);
1576 stmt_t data = masks.stmt();
1577
1578 const int simd_per_ic = utils::div_up(
1579 std::min((!is_scalar) ? channels : 1, 32), m_blk);
1580 const std::vector<dim_t> dims
1581 = {m_blk * std::min((is_mad) ? 1 : 2, simd_per_ic)};
1582 const bool sc_ic = is_scalar || (channels <= 32);
1583 expr_t offs = (!sc_ic) ? masks.ic_start() * d_type.size() : 0;
1584
1585 if (is_runtime && !sc_ic && !cfg_.pipeline().do_unroll()
1586 && (cfg_.slm().bufs() > 1)) {
1587 auto buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "zp_mask");
1588 register_buffer(buf, type_t::u32().size(), alloc_kind_t::grf);
1589 data = data.append(store_t::make(buf, 0, offs));
1590 offs = load_t::make(type_t::u32(), buf, 0);
1591 }
1592
1593 auto get_src_zp_size = [](bool scalar, bool runtime, bool mad, int b) {
1594 if (scalar) return (!mad) ? b * 2 : ((runtime) ? b : 0);
1595 return (!mad) ? std::max(b * 2, 32) : 32;
1596 };
1597 const int m_blk_x2 = std::min(m_blk * 2, 16);
1598 const int src_zp_size = get_src_zp_size(
1599 is_scalar, is_runtime, is_mad, m_blk_x2 * k_blk);
1600 if (zp_buf_.is_empty())
1601 zp_buf_ = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "zp_buf");
1602 zp_buf_size_ = std::max(zp_buf_size_, src_zp_size * d_type.size());
1603
1604 for (int i = (is_runtime) ? 0 : std::numeric_limits<int>::max();
1605 i < m_blk * simd_per_ic; i += dims[0]) {
1606 const int b = i * d_type.size();
1607 view_t zpv(layout_t(d_type, 0, dims));
1608 auto read = make_access_builder(ir_ctx_, zpv,
1609 kernel_info_.find_arg("src_zero_points")[offs + b],
1610 zp_buf_[b], send_op_t::load, send_address_t::a64);
1611 data = data.append(read.stmt());
1612 }
1613
1614 if (is_mad) {
1615 // TODO: for now, only b-blocking (per G) of the MAD loop is ready;
1616 // please implement n-blocking (per OC) as well!
1617 ir_assert(a_view.tlayout().size() % a_stride == 0);
1618 ir_assert(masks.is_simd());
1619
1620 std::vector<stmt_t> loop(std::max(1, 32 / m_blk));
1621 for (int a_off = 0; a_off < a_view.tlayout().size();
1622 a_off += m_blk * a_stride) {
1623 int iter = (a_off / m_blk / a_stride) % loop.size();
1624 type_t sv_type(s_type.kind(), m_blk);
1625 type_t b_type(s_type.kind(), (!is_scalar) ? m_blk : 1);
1626 auto a = load_t::make(sv_type, a_buf_, a_off, a_stride);
1627 auto b_off
1628 = (!is_scalar && (channels > m_blk)) ? iter * m_blk : 0;
1629 auto b = (is_runtime) // '4'-s mean '(|i32| / |i16|) * |i16|'
1630 ? load_t::make(b_type, zp_buf_, b_off * 4, 4)
1631 : prb_.zp_cfg.common_src_zero_point;
1632 auto mask = masks.gen_mask(
1633 (utils::div_up(k_blk, 2)) * a_off / m_blk / a_stride);
1634 auto mad = (masks.is_bool())
1635 ? binary_op_t::make(op_kind_t::_sub, a, b, sv_type)
1636 : ternary_op_t::make(
1637 op_kind_t::_mad, a, mask, b, sv_type);
1638 loop[iter] = loop[iter].append(store_t::make(a_buf_, a_off, mad,
1639 a_stride, (masks.is_bool()) ? mask : expr_t()));
1640 }
1641 for (size_t i = 1; i < loop.size(); i++)
1642 loop[0] = loop[0].append(loop[i]);
1643 return data.append(masks.maybe_gen_mask_let(
1644 loop[0].append(mul_builder.stmt())));
1645 }
1646
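// Repack the zero points as bytes for dp4a: a common scalar zero point is
// replicated into all four bytes of a dword; otherwise one byte per 32-bit
// per-channel zero point is gathered (stride-4 byte loads) into a contiguous
// vector.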
1647 if (is_scalar) {
1648 expr_t expr = (!is_runtime)
1649 ? (prb_.zp_cfg.common_src_zero_point & 0xFF) * 0x01010101
1650 : cast_t::make(type_t::s8(4),
1651 shuffle_t::make_broadcast(
1652 load_t::make(s_type, zp_buf_, 0), 4));
1653 data = data.append(store_t::make(zp_buf_, 0, expr));
1654 } else {
1655 data = data.append(store_t::make(zp_buf_, 0,
1656 load_t::make(type_t::u8(m_blk_x2), zp_buf_, 0, 4)));
1657 if (channels > 16)
1658 data = data.append(store_t::make(zp_buf_, 16,
1659 load_t::make(type_t::u8(m_blk_x2), zp_buf_, 64, 4)));
1660 if (m_blk_x2 != m_blk)
1661 data = data.append(store_t::make(zp_buf_, 32,
1662 load_t::make(type_t::u32(4), zp_buf_, 4, 8), 8));
1663 }
1664 std::vector<stmt_t> parts;
1665
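// wide_scalar() broadcasts one zero-point value (or splices two of them)
// across a blk-wide vector; wide_vector() loads an m_blk-wide accumulator and
// repeats it to fill a wider vector.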
1666 auto wide_scalar = [m_blk](const expr_t &a, const expr_t &b, int blk) {
1667 if (blk == m_blk) return shuffle_t::make_broadcast(a, m_blk);
1668 std::vector<int> index(blk, 1);
1669 for (int i = 0; i < m_blk; i++)
1670 index[i] = 0;
1671 return shuffle_t::make(std::vector<expr_t> {a, b}, index);
1672 };
1673 auto wide_vector = [m_blk, i_type](const expr_t &a, int blk) {
1674 if (blk == m_blk)
1675 return load_t::make(type_t(i_type.kind(), m_blk), a, 0);
1676 std::vector<expr_t> vec(m_blk);
1677 std::vector<int> index(blk);
1678 for (int i = 0; i < m_blk; i++) {
1679 vec[i] = load_t::make(i_type, a, i * i_type.size());
1680 index[i + m_blk] = index[i] = i;
1681 }
1682 return shuffle_t::make(vec, index);
1683 };
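// acc[] points at per-k-block accumulator slots carved out of the tail of
// zp_buf_. The loops below use dp4a to accumulate dot products of the source
// zero points with the B operand and then apply the result to the C sub-tile
// as a compensation term.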
1684 std::vector<expr_t> acc;
1685 for (int i = 1; i <= 2 * k_blk; i++)
1686 acc.emplace_back(
1687 zp_buf_[(src_zp_size
1688 - utils::div_up(
1689 i, m_blk != m_blk_x2 ? 1 : 2)
1690 * m_blk)
1691 * d_type.size()]);
1692 for (int i_m = 0; i_m < desc_m; i_m += m_blk) {
1693 const int blk
1694 = (masks.is_simd() || masks.is_scalar()) ? m_blk_x2 : m_blk;
1695 for (int i = 0; i < k_blk; i++) {
1696 for (int i_k = i * (c_blk / 4); i_k
1697 < ((channels > 4) ? (c_blk + i * c_blk) / 4 : prb_.kw);
1698 i_k += m_blk_x2 / m_blk) {
1699 type_t vi(i_type.kind(), m_blk_x2);
1700 const int szp_off = (is_scalar) ? 0 : (i_k * d_type.size());
1701 auto b0 = load_t::make(d_type, zp_buf_, szp_off);
1702 auto b1 = load_t::make(
1703 d_type, zp_buf_, szp_off + m_blk * d_type.size());
1704 auto b = (is_scalar) ? b0 : wide_scalar(b0, b1, m_blk_x2);
1705 auto c = load_t::make(vi, b_buf_,
1706 (i_m * (32 / 4) + i_k * m_blk) * d_type.size());
1707 if (is_scalar) std::swap(b, c);
1708 auto a = (i_k != i * (c_blk / 4))
1709 ? load_t::make(vi, acc[i * 2 + 1], 0)
1710 : expr_t(0);
1711 parts.emplace_back(store_t::make(acc[i * 2 + 1], 0,
1712 ternary_op_t::make(op_kind_t::_dp4a, a, b, c, vi)));
1713 }
1714
1715 if (m_blk_x2 != m_blk) {
1716 type_t vi(i_type.kind(), m_blk);
1717 auto a = load_t::make(vi, acc[i * 2 + 1], 0);
1718 auto b = load_t::make(vi, acc[i * 2 + 0], 0);
1719 parts.emplace_back(store_t::make(acc[i * 2 + 1], 0, a + b));
1720 if (!masks.is_bool() && (blk != m_blk))
1721 parts.emplace_back(store_t::make(acc[i * 2 + 0], 0, a));
1722 }
1723 }
1724 for (int i_n = 0; i_n < desc_n; i_n += blk / m_blk) {
1725 int off_n = i_m / m_blk * desc_n + i_n;
1726 const int ij_buf = i_buf * cfg_.subtiles().b() + j_buf;
1727 auto dst = c_buf_ + off_n * m_blk * d_type.size()
1728 + ij_buf * mul_builder.c_layout().size();
1729 type_t vi(i_type.kind(), blk);
1730 auto a = load_t::make(vi, dst, 0);
1731 for (int i = 0; i < k_blk; i++) {
1732 auto mask = masks.gen_mask(off_n * k_blk + i * blk / m_blk);
1733 if (!masks.is_bool()) {
1734 auto b = load_t::make(vi, acc[i * 2 + 1], 0);
1735 auto mad = ternary_op_t::make(
1736 op_kind_t::_mad, a, b, mask);
1737 parts.emplace_back(store_t::make(dst, 0, mad));
1738 } else {
1739 auto b = wide_vector(acc[i * 2 + 1], blk);
1740 auto sub = binary_op_t::make(op_kind_t::_sub, a, b);
1741 parts.emplace_back(store_t::make(
1742 dst, 0, sub, store_t::default_stride, mask));
1743 }
1744 }
1745 }
1746 }
// Interleave the compensation statements with the DPAS calls for better
// GPU utilization.
1748 auto raw_dpas = flatten_statements(mul_builder.stmt());
1749 std::vector<stmt_t> dpas;
1750 stmt_t full;
1751 expr_t src1;
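// Group consecutive DPAS calls that share the same src1 operand so that
// compensation statements can be slotted in between the groups.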
1752 for (auto &r : raw_dpas) {
1753 ir_assert(is_func_call<dpas_t>(r));
1754 auto &this_src1 = dpas_t::arg_src1(r);
1755 if (this_src1.is_equal(src1)) {
1756 dpas.back() = dpas.back().append(r);
1757 } else {
1758 src1 = this_src1;
1759 dpas.emplace_back(r);
1760 }
1761 }
1762 ir_assert(parts.size() % dpas.size() == 0);
1763 const int loop_size = int(parts.size()) / int(dpas.size());
1764 for (int i = 0; i < int(dpas.size()); i++) {
1765 full = full.append(dpas[i]);
1766 const auto k = (i + int(dpas.size()) / 2) % int(dpas.size());
1767 for (int j = k * loop_size; j < (k + 1) * loop_size; j++)
1768 full = full.append(parts[j]);
1769 }
1770 return data.append(masks.maybe_gen_mask_let(full));
1771 }
1772
1773 void build_subtile(int i, int j) {
1774 bool is_first = (i == 0 && j == 0);
1775
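// A/B sub-tile data is loaded into registers only once; later (i, j)
// combinations reuse the buffers marked via set_loaded().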
1776 stmt_t ab_s2r_load;
1777 stmt_t ab_g2r_load;
1778 if (!a_subtiles_[i].is_loaded()) {
1779 ab_s2r_load = ab_s2r_load.append(a_subtiles_[i].s2r_load());
1780 ab_g2r_load = ab_g2r_load.append(a_subtiles_[i].g2r_load());
1781 a_subtiles_[i].set_loaded();
1782 }
1783 if (!b_subtiles_[j].is_loaded()) {
1784 ab_s2r_load = ab_s2r_load.append(b_subtiles_[j].s2r_load());
1785 ab_g2r_load = ab_g2r_load.append(b_subtiles_[j].g2r_load());
1786 b_subtiles_[j].set_loaded();
1787 }
1788 load_mul_stmt_ = load_mul_stmt_.append(
1789 stmt_group_t::make(stmt_label_t::g2r_load(i + j), ab_g2r_load));
1790 load_mul_stmt_ = load_mul_stmt_.append(
1791 stmt_group_t::make(stmt_label_t::s2r_load(i + j), ab_s2r_load));
1792
1793 auto &a_i_view = a_subtiles_[i].reg_view();
1794 auto &b_j_view = b_subtiles_[j].reg_view();
1795
1796 // Multiply C_i_j += A_i x B_j in GEMM notation.
1797 multiply_builder_t mul_builder(cfg_, gemm_schedule_.bmnk_mapper(),
1798 a_i_view, b_j_view, a_buf_, b_buf_, c_buf_[c_buf_off_]);
1799 c_subtile_layout_ = mul_builder.c_layout();
1800
1801 auto mul_total
1802 = maybe_add_src_zps(a_i_view, b_j_view, mul_builder, i, j);
1803
1804 c_buf_off_ += c_subtile_layout_.size();
1805 ir_trace() << "Multiply (" << i << ", " << j << "):\n"
1806 << mul_total.str() << std::endl;
1807
1808 load_mul_stmt_ = load_mul_stmt_.append(
1809 stmt_group_t::make(stmt_label_t::mul(i + j), mul_total));
1810
1811 if (!is_first) {
1812 ir_assert(mul_builder.c_layout() == c_subtile_layout_)
1813 << "Sub-tile layouts must be equal.";
1814 return;
1815 }
1816
1817 register_buffer(
1818 a_buf_, a_subtiles_[i].reg_buf_size(), alloc_kind_t::grf);
1819 register_buffer(
1820 b_buf_, b_subtiles_[j].reg_buf_size(), alloc_kind_t::grf);
1821 }
1822 void register_buffer(const stmt_t &alloc) {
1823 ir_assert(alloc.is<alloc_t>());
1824 allocs_.push_back(alloc);
1825 }
1826
1827 void register_buffer(const expr_t &buf, int size, alloc_kind_t kind,
1828 const alloc_attr_t &attr = {}) {
1829 register_buffer(alloc_t::make(buf, size, kind, attr));
1830 }
1831
1832 const conv_problem_t &prb_;
1833 const conv_config_t &cfg_;
1834 ir_context_t &ir_ctx_;
1835 const gemm_schedule_t &gemm_schedule_;
1836 const fma_helper_t &fma_helper_;
1837 b_reduce_context_t &b_reduce_ctx_;
1838
1839 expr_t ap_buf_;
1840 expr_t a_slm_buf_;
1841
1842 expr_t bp_buf_;
1843 expr_t b_slm_buf_;
1844
1845 expr_t zp_buf_;
1846 int zp_buf_size_ = 0;
1847
1848 expr_t zp_mask_;
1849 int zp_mask_size_ = 0;
1850
1851 layout_t c_reg_layout_;
1852
1853 expr_t ab_tmp_buf_;
1854 expr_t a_buf_;
1855 expr_t b_buf_;
1856 expr_t c_buf_;
1857
1858 // Per-thread views to multiply.
1859 view_t a_thr_view_;
1860 view_t b_thr_view_;
1861
1862 // Sub-tile indices.
1863 expr_t a_idx_;
1864 expr_t b_idx_;
1865
1866 // Sub-tile views.
1867 view_t a_i_view_;
1868 view_t b_j_view_;
1869
1870 tensor_t a_i_tile_;
1871 tensor_t b_j_tile_;
1872
1873 std::vector<subtile_info_t> a_subtiles_;
1874 std::vector<subtile_info_t> b_subtiles_;
1875
1876 std::vector<block_t> a_i_outer_blocks_;
1877 std::vector<block_t> b_j_outer_blocks_;
1878
1879 std::vector<stmt_t> allocs_;
1880
1881 stmt_t load_mul_stmt_;
1882
1883 int c_buf_off_ = 0;
1884 layout_t c_subtile_layout_;
1885
1886 const kernel_info_t &kernel_info_;
1887};
1888
1889class compute_builder_t {
1890public:
1891 compute_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
1892 const kernel_info_t &kernel_info)
1893 : cfg_(cfg)
1894 , ir_ctx_(ir_ctx)
1895 , b_reduce_ctx_(ir_ctx, cfg)
1896 , g2s_ctx_(ir_ctx)
1897 , fma_helper_(cfg.simd(), cfg.fma_kind(), cfg.prb().a_data_type,
1898 cfg.prb().b_data_type, cfg.allow_a_grf_reorder(),
1899 cfg.allow_b_grf_reorder(), !cfg.prb().is_dw)
1900 , kernel_info_(kernel_info) {}
1901
1902 int ab_slm_size() const { return ab_slm_size_; }
1903
1904 const stmt_t &c_zero_out_stmt() const { return c_zero_out_stmt_; }
1905 const stmt_t &b_reduced_zero_out_stmt() const {
1906 return b_reduced_zero_out_stmt_;
1907 }
1908
1909 stmt_t zero_out_stmt() const {
1910 stmt_t ret;
1911 ret = ret.append(c_zero_out_stmt());
1912 ret = ret.append(b_reduced_zero_out_stmt());
1913 return ret;
1914 }
1915
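// Body of one compute-loop iteration: an optional GMEM prefetch or
// GMEM -> SLM staging step (load, barrier, store, barrier), followed by the
// load/multiply statement.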
1916 stmt_t iter_stmt() const {
1917 stmt_t stmt;
1918 bool use_prefetch = !prefetch_stmt_.is_empty();
1919 bool use_slm = !g2s_load_stmt_.is_empty();
1920 if (use_prefetch) {
1921 stmt = stmt.append(stmt_group_t::make(
1922 stmt_label_t::prefetch(), prefetch_stmt_));
1923 } else if (use_slm) {
1924 stmt = stmt.append(stmt_group_t::make(
1925 stmt_label_t::g2s_load(), g2s_load_stmt_));
1926 stmt = stmt.append(funcs::barrier());
1927 stmt = stmt.append(stmt_group_t::make(
1928 stmt_label_t::g2s_store(), g2s_store_stmt_));
1929 stmt = stmt.append(funcs::barrier());
1930 }
1931 stmt = stmt.append(load_mul_stmt_);
1932 return stmt;
1933 }
1934
1935 const stmt_t &c_store_stmt() const { return c_store_stmt_; }
1936 const stmt_t &b_reduced_store_stmt() const { return b_reduced_store_stmt_; }
1937
1938 stmt_t inject_compute_alloc_stmts(const stmt_t &stmt) const {
1939 return jit::inject_alloc_stmts(stmt, compute_allocs_);
1940 }
1941
1942 stmt_t inject_out_alloc_stmts(const stmt_t &stmt) const {
1943 return jit::inject_alloc_stmts(stmt, out_allocs_);
1944 }
1945
1946 stmt_t inject_let_stmts(const stmt_t &stmt) const {
1947 return jit::inject_let_stmts(stmt, g2s_ctx_.grid_idx_lets);
1948 }
1949
1950 void set_gemm_schedule(const gemm_schedule_t &gemm_schedule) {
1951 gemm_schedule_ = gemm_schedule;
1952 }
1953
1954 // Setters for original AP/BP/CP buffers (P - problem notation).
1955 void set_ap_buf(const expr_t &buf) { ap_buf_ = buf; }
1956 void set_bp_buf(const expr_t &buf) { bp_buf_ = buf; }
1957 void set_cp_buf(const expr_t &buf) { cp_buf_ = buf; }
1958 void set_b_reduced_mem_buf(const expr_t &buf) {
1959 b_reduce_ctx_.set_b_reduced_mem_buf(buf);
1960 }
1961
1962 void set_b_reduced_view(const view_t &v) {
1963 b_reduce_ctx_.set_b_reduced_view(v);
1964 }
1965
1966 void set_post_op_context(const post_op_context_t &post_op_ctx) {
1967 post_op_ctx_ = post_op_ctx;
1968 }
1969
1970 void set_reduce_condition(const expr_t &cond) {
1971 b_reduce_ctx_.set_reduce_condition(cond);
1972 }
1973
1974 void build() {
1975 // Initialize SLM buffers.
1976 expr_t a_slm_buf = make_buffer("a_slm");
1977 expr_t b_slm_buf = make_buffer("b_slm");
1978
1979 view_t ap_gmem_view = gemm_schedule_.a_tg_view();
1980 view_t bp_gmem_view = gemm_schedule_.b_tg_view();
1981
1982 // Views to multiply by a thread group (either GMEM or SLM).
1983 view_t ap_x_view;
1984 view_t bp_x_view;
1985 prepare_gmem_to_slm("A", cfg_.slm().a(), gemm_schedule_.a_tg_tile(),
1986 ap_gmem_view, ap_buf_, a_slm_buf, ap_x_view, g2s_ctx_);
1987 prepare_gmem_to_slm("B", cfg_.slm().b(), gemm_schedule_.b_tg_tile(),
1988 bp_gmem_view, bp_buf_, b_slm_buf, bp_x_view, g2s_ctx_);
1989 prepare_prefetch("A", cfg_.prefetch(), ap_gmem_view, ap_buf_);
1990 prepare_prefetch("B", cfg_.prefetch(), bp_gmem_view, bp_buf_);
1991
1992 if (ap_x_view.is_empty()) ap_x_view = ap_gmem_view;
1993 if (bp_x_view.is_empty()) bp_x_view = bp_gmem_view;
1994
1995 for (auto &bi : g2s_ctx_.bufs) {
1996 register_compute_buffer(bi.buf, bi.size, alloc_kind_t::grf);
1997 }
1998
1999 load_multiply_builder_t load_mul_builder(cfg_, ir_ctx_, gemm_schedule_,
2000 fma_helper_, b_reduce_ctx_, ap_buf_, a_slm_buf, bp_buf_,
2001 b_slm_buf, ap_x_view, bp_x_view, kernel_info_);
2002
2003 load_mul_stmt_ = load_mul_builder.load_mul_stmt();
2004 compute_allocs_.insert(compute_allocs_.end(),
2005 load_mul_builder.allocs().begin(),
2006 load_mul_builder.allocs().end());
2007
2008 auto c_buf = load_mul_builder.c_buf();
2009 int c_size = load_mul_builder.c_reg_layout().size();
2010 int c_size_grf_rounded = utils::rnd_up(c_size, ir_ctx_.grf_size());
2011 register_out_buffer(c_buf, c_size_grf_rounded, alloc_kind_t::grf);
2012
2013 auto c_thr_reg_layout = load_mul_builder.c_reg_layout();
2014 auto thr_tile = gemm_schedule_.c_thr_tile(/*is_relative=*/false);
2015
2016 auto reduce_cond = expr_t();
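// With thread group k-slicing, partial C results are first reduced across
// the thread group through SLM; the final store is then guarded by the
// reduction condition.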
2017 if (gemm_schedule_.with_thread_group_k_slicing()) {
2018 slm_reduce_builder_t slm_reduce_builder(ir_ctx_,
2019 gemm_schedule_.tg_grid(), c_buf, c_thr_reg_layout,
2020 thr_tile);
2021 c_store_stmt_ = c_store_stmt_.append(slm_reduce_builder.stmt());
2022 c_thr_reg_layout = slm_reduce_builder.reg_layout();
2023 thr_tile = slm_reduce_builder.thr_tile();
2024 reduce_cond = slm_reduce_builder.reduce_cond();
2025 }
2026
2027 auto c_thr_mem_view = gemm_schedule_.c_view().create_sub_view(thr_tile);
2028 auto c_m2g_stmt = create_epilogue_stmt(cfg_, ir_ctx_, gemm_schedule_,
2029 post_op_ctx_, thr_tile, c_thr_mem_view, c_thr_reg_layout,
2030 cp_buf_, c_buf);
2031 if (!reduce_cond.is_empty())
2032 c_m2g_stmt = if_t::make(reduce_cond, c_m2g_stmt);
2033 ir_trace() << "C GRF to GMEM store:\n" << c_m2g_stmt << std::endl;
2034
2035 c_zero_out_stmt_ = stmt_group_t::make(stmt_label_t::c_zero_out(),
2036 create_zero_out_stmt(ir_ctx_, c_buf, c_size));
2037 c_store_stmt_ = c_store_stmt_.append(c_m2g_stmt);
2038
2039 if (cfg_.reduce_b()) {
2040 auto &ctx = b_reduce_ctx_;
2041 b_reduced_zero_out_stmt_ = create_zero_out_stmt(
2042 ir_ctx_, ctx.b_reduced_reg_buf(), ctx.b_reduced_size());
2043 b_reduced_store_stmt_ = ctx.create_store_stmt(
2044 gemm_schedule_.with_kernel_grid_k_slicing()
2045 || cfg_.slm().b());
2046 register_out_buffer(ctx.b_reduced_reg_buf(), ctx.b_reduced_size(),
2047 alloc_kind_t::grf);
2048 }
2049
2050 // Replace DPAS by DPASW when applicable.
2051 if (cfg_.fma_kind() == fma_kind_t::dpasw) {
2052 alloc_updater_t alloc_updater;
2053 inject_dpasw(ir_ctx_.hw(), load_mul_stmt_, c_buf, c_store_stmt_,
2054 alloc_updater, gemm_schedule_.tg_grid().idx(0));
2055 for (auto &a : compute_allocs_) {
2056 a = alloc_updater.update(a);
2057 }
2058 for (auto &a : out_allocs_) {
2059 a = alloc_updater.update(a);
2060 }
2061 }
2062
2063 // Assign {Atomic} for DPAS(W) when applicable.
2064 load_mul_stmt_ = inject_atomic(load_mul_stmt_);
2065 }
2066
2067private:
2068 struct buf_info_t {
2069 buf_info_t(const std::string &tag, const expr_t &buf)
2070 : tag(tag), buf(buf) {}
2071
2072 std::string tag;
2073 expr_t buf;
2074 int size = 0;
2075 };
2076
2077 struct g2s_context_t {
2078 g2s_context_t(ir_context_t &ir_ctx) : ir_ctx(ir_ctx) {}
2079
2080 expr_t create_buf(const char *tag, bool force_reuse = false) {
2081 if (reuse_buffers || force_reuse) {
2082 for (auto &bi : bufs) {
2083 if (bi.tag == tag) return bi.buf;
2084 }
2085 }
2086 auto buf = ir_ctx.create_tmp_var(type_t::byte_ptr(), tag);
2087 bufs.emplace_back(tag, buf);
2088 return buf;
2089 }
2090
2091 void set_buf_size(const expr_t &buf, int size) {
2092 for (auto &bi : bufs) {
2093 if (bi.buf.is_same(buf)) bi.size = std::max(bi.size, size);
2094 }
2095 }
2096
2097 expr_t create_tmp_grid_idx() {
2098 auto var = ir_ctx.create_tmp_var(type_t::s32(), "idx");
2099 tmp_grid_idxs.insert({var, expr_t()});
2100 return var;
2101 }
2102
2103 void set_grid_idx_value(const expr_t &idx, const expr_t &value) {
2104 auto &old = tmp_grid_idxs[idx];
2105 ir_assert(old.is_empty());
2106 old = substitute_grid_idx_value(value);
2107 }
2108
2109 expr_t substitute_grid_idx_value(const expr_t &_e) {
2110 auto e = _e;
2111 auto vars = find_unique_objects<var_t>(e);
2112 for (auto &v : vars) {
2113 auto it = tmp_grid_idxs.find(v);
2114 if (it == tmp_grid_idxs.end()) continue;
2115 e = substitute(e, v, it->second);
2116 }
2117 return e;
2118 }
2119
2120 void register_grid(const grid_info_t &grid) {
2121 for (int i = 0; i < grid.ndims(); i++) {
2122 auto &idx = grid.idx(i);
2123 auto it = tmp_grid_idxs.find(idx);
2124 if (it == tmp_grid_idxs.end()) continue;
2125 grid_idx_lets.emplace_back(let_t::make(idx, it->second));
2126 }
2127 }
2128
2129 ir_context_t &ir_ctx;
2130 grid_info_t prev_load_grid;
2131 bool reuse_buffers = false;
2132 std::vector<buf_info_t> bufs;
2133
2134 object_map_t<expr_t, expr_t> tmp_grid_idxs;
2135 std::vector<stmt_t> grid_idx_lets;
2136 };
2137
2138 void register_compute_buffer(const expr_t &buf, int size, alloc_kind_t kind,
2139 const alloc_attr_t &attr = {}) {
2140 compute_allocs_.push_back(alloc_t::make(buf, size, kind, attr));
2141 }
2142
2143 void register_out_buffer(const expr_t &buf, int size, alloc_kind_t kind,
2144 const alloc_attr_t &attr = {}) {
2145 out_allocs_.push_back(alloc_t::make(buf, size, kind, attr));
2146 }
2147
2148 // Handles GMEM to SLM load for A and B. Done in two steps:
2149 // 1. Load: GMEM -> GRF (temporary)
2150 // 2. Store: GRF (temporary) -> SLM
2151 void prepare_gmem_to_slm(const char *tag, bool use_x_slm,
2152 const tensor_t &tg_tile, const view_t &x_gmem_view,
2153 const expr_t &xp_buf, const expr_t &x_slm_buf, view_t &x_slm_view,
2154 g2s_context_t &g2s_ctx) {
2155 if (!use_x_slm) return;
2156
2157 grid_info_t load_grid = gemm_schedule_.tg_grid();
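// Try to set up the GMEM -> SLM copy with the full thread group grid; if a
// thread's SLM store region ends up non-dense, halve the grid along one
// dimension (introducing a derived grid index) and retry.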
2158 for (;;) {
2159 bool ok = prepare_gmem_to_slm_impl(tag, use_x_slm, tg_tile,
2160 x_gmem_view, xp_buf, x_slm_buf, x_slm_view, load_grid,
2161 g2s_ctx);
2162 if (ok) {
2163 g2s_ctx.prev_load_grid = load_grid;
2164 g2s_ctx.register_grid(load_grid);
2165 return;
2166 }
2167
2168 // Reduce grid and try again.
2169 auto grid_idx = g2s_ctx.create_tmp_grid_idx();
2170 int dim_idx;
2171 expr_t grid_idx_value;
2172 auto new_load_grid
2173 = load_grid.halven(grid_idx, dim_idx, grid_idx_value);
2174 if (new_load_grid.is_empty()) break;
2175
2176 if (new_load_grid == g2s_ctx.prev_load_grid) {
2177 new_load_grid = load_grid.halven(
2178 grid_idx, dim_idx, grid_idx_value, /*first=*/false);
2179 g2s_ctx.reuse_buffers = true;
2180 }
2181 g2s_ctx.set_grid_idx_value(grid_idx, grid_idx_value);
2182
2183 ir_ctx_.add_constraint(grid_idx >= 0);
2184 ir_ctx_.add_constraint(grid_idx < new_load_grid.dim(dim_idx));
2185
2186 load_grid = new_load_grid;
2187 }
2188 ir_error_not_expected() << "Can't create GMEM -> SLM loads/stores.";
2189 }
2190
2191 bool prepare_gmem_to_slm_impl(const char *tag, bool use_x_slm,
2192 const tensor_t &tg_tile, const view_t &x_gmem_view,
2193 const expr_t &xp_buf, const expr_t &x_slm_buf, view_t &x_slm_view,
2194 const grid_info_t &load_grid, g2s_context_t &g2s_ctx) {
2195 bool is_a = (tag[0] == 'A');
2196 abc_kind_t ab_kind = (is_a ? abc_kind_t::a : abc_kind_t::b);
2197
2198 auto xp_slm_layout = create_slm_layout(x_gmem_view, ab_kind, load_grid);
2199
2200 auto grid_cond = load_grid.slice_condition();
2201
2202 // Per-thread tile and view to load from GMEM and store to SLM.
2203 tensor_t thr_tile;
2204 view_t x_g2s_view;
2205 if (cfg_.allow_slm_tg_slicing()) {
2206 x_g2s_view = x_gmem_view.split(load_grid, thr_tile);
2207 } else {
2208 thr_tile = xp_slm_layout.split(load_grid);
2209 x_g2s_view = x_gmem_view.create_sub_view(thr_tile);
2210 }
2211
2212 auto bound_cond = expr_t();
2213 if (is_a && !cfg_.fuse_spatial()
2214 && thr_tile.elems() * load_grid.elems()
2215 != xp_slm_layout.elems()) {
2216 for (int i = 0; i < x_gmem_view.nvdims(); i++) {
2217 if (!x_g2s_view.vstart(i).is_equal(x_gmem_view.vstart(i))) {
2218 auto dim_expr
2219 = x_g2s_view.vstart(i) - x_gmem_view.vstart(i);
2220 if (bound_cond.is_empty())
2221 bound_cond = dim_expr < x_gmem_view.vdims()[i];
2222 else
2223 bound_cond &= dim_expr < x_gmem_view.vdims()[i];
2224 }
2225 }
2226 }
2227 if (!bound_cond.is_empty()) {
2228 if (!grid_cond.is_empty())
2229 grid_cond = grid_cond & bound_cond;
2230 else
2231 grid_cond = bound_cond;
2232 }
2233
2234 auto slm_thr_layout = xp_slm_layout.map(thr_tile);
2235
// Ensure that each thread writes a dense region to SLM. If the layout
// is not dense, return and try again with a smaller grid.
2238 if (!slm_thr_layout.is_dense()) return false;
2239
2240 register_compute_buffer(
2241 x_slm_buf, xp_slm_layout.size(), alloc_kind_t::slm);
2242 ab_slm_size_ += xp_slm_layout.size();
2243
2244 // Temporary GRF buffer.
2245 expr_t x_g2s_reg_buf = g2s_ctx.create_buf("g2s");
2246
2247 // GMEM -> GRF load.
2248 auto x_read = make_access_builder(ir_ctx_, x_g2s_view, xp_buf,
2249 x_g2s_reg_buf, send_op_t::load, send_address_t::a64);
2250 ir_trace() << tag << " GMEM to GRF load:\n"
2251 << x_read.str() << std::endl;
2252
2253 g2s_ctx.set_buf_size(x_g2s_reg_buf, x_read.reg_buf_size());
2254
2255 auto load_stmt = x_read.stmt();
2256 if (!grid_cond.is_empty()) load_stmt = if_t::make(grid_cond, load_stmt);
2257 g2s_load_stmt_ = g2s_load_stmt_.append(load_stmt);
2258
2259 // GRF -> SLM store.
2260 auto x_write = make_access_builder(ir_ctx_, view_t(slm_thr_layout),
2261 x_slm_buf, x_g2s_reg_buf, send_op_t::store,
2262 send_address_t::slm);
2263 ir_trace() << tag << " GRF to SLM store:\n"
2264 << x_write.str() << std::endl;
2265 auto store_stmt = x_write.stmt();
2266
2267 auto &read_layout = x_read.reg_layout();
2268 auto &write_layout = x_write.reg_layout();
2269 if (read_layout != write_layout) {
2270 if (is_a ? cfg_.allow_a_grf_reorder()
2271 : cfg_.allow_b_grf_reorder()) {
2272 // Temporary GRF buffer.
2273 expr_t tmp_buf
2274 = g2s_ctx.create_buf("g2s_tmp", /*force_reuse=*/true);
2275 auto reorder_stmt = create_reorder_stmt(
2276 read_layout, write_layout, x_g2s_reg_buf, tmp_buf);
2277 g2s_ctx.set_buf_size(tmp_buf, x_write.reg_buf_size());
2278 store_stmt = substitute(store_stmt, x_g2s_reg_buf, tmp_buf);
2279 store_stmt = reorder_stmt.append(store_stmt);
2280 } else {
2281 ir_error_not_expected() << "Requested register layouts for "
2282 << tag << " do not match: "
2283 << "read: " << read_layout
2284 << ", write: " << write_layout;
2285 }
2286 }
2287 // Generate reduction statement for B.
2288 if (!is_a && cfg_.reduce_b()) {
2289 auto absolute_thr_tile = tg_tile.create_sub_tensor(thr_tile);
2290 b_reduce_ctx_.init_reduced_thr_view(absolute_thr_tile, grid_cond);
2291 auto reduce_stmt = b_reduce_ctx_.create_reduce_stmt(
2292 read_layout, x_g2s_reg_buf);
2293 store_stmt = reduce_stmt.append(store_stmt);
2294 }
2295 if (!grid_cond.is_empty())
2296 store_stmt = if_t::make(grid_cond, store_stmt);
2297 g2s_store_stmt_ = g2s_store_stmt_.append(store_stmt);
2298
2299 x_slm_view = view_t(xp_slm_layout);
2300
2301 return true;
2302 }
2303
2304 void prepare_prefetch(const char *tag, bool use_prefetch,
2305 const view_t &x_gmem_view, const expr_t &xp_buf) {
2306 if (!use_prefetch) return;
2307
2308 // Per-thread view to prefetch from GMEM.
2309 auto thr_view = x_gmem_view.split(gemm_schedule_.tg_grid());
2310
2311 auto send_hint = get_send_hint(ir_ctx_.exec_cfg(), send_op_t::prefetch,
2312 (tag[0] == 'A') ? abc_kind_t::a : abc_kind_t::b, thr_view,
2313 gemm_schedule_);
2314
2315 // GMEM prefetch.
2316 auto x_prefetch = make_access_builder(ir_ctx_, thr_view, xp_buf,
2317 expr_t(), send_op_t::prefetch, send_address_t::a64, send_hint);
2318
// Too many prefetches degrade performance.
2320 if (find_objects<func_call_t>(x_prefetch.stmt()).size() > 16) {
2321 ir_warning() << "Dropping excessive prefetches." << std::endl;
2322 prefetch_stmt_ = stmt_t();
2323 } else {
2324 ir_trace() << tag << " GMEM prefetch:\n"
2325 << x_prefetch.str() << std::endl;
2326 prefetch_stmt_ = prefetch_stmt_.append(x_prefetch.stmt());
2327 }
2328 }
2329
2330 layout_t create_slm_layout(const view_t &tg_view, abc_kind_t abc_kind,
2331 const grid_info_t &load_grid) const {
2332 auto layout = tg_view.create_dense_vlayout();
2333 auto ret = fma_helper_.convert_to_fma_friendly_layout(layout, abc_kind,
2334 /*is_slm=*/true, gemm_schedule_.bmnk_mapper());
2335 if (cfg_.pad_slm()) ret = pad_slm_layout(ret, load_grid);
2336 return ret.normalize();
2337 }
2338
// SLM has 65 dword-granularity banks (Xe_HP):
//      banks:   [bank 0] [bank 1] [bank 2] ... [bank 0]
// byte offsets: | 0      | 4      | 8      ... | 4 * 65
// SLM reads don't have conflicts. During SLM writes each fused EU writes
// 64 bytes (128 bytes total per clock). If any bank repeats within those
// 128 bytes, the write takes 2 clocks to complete.
// Assume that every X-axis thread (across tg_dim[0]) writes the
// corresponding outer block of the layout. The goal is to choose a stride
// between outer blocks that avoids duplicated banks.
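// For illustration, assuming 65 banks with 4-byte granularity: byte offsets
// 0 and 260 hit the same bank (260 / 4 = 65, and 65 mod 65 = 0), while
// offsets 0 and 256 hit banks 0 and 64 and do not conflict.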
2348 layout_t pad_slm_layout(
2349 const layout_t &layout, const grid_info_t &load_grid) const {
// EUs are not fused on XeHPC+, so there is no need to pad SLM.
2351 if (ir_ctx_.hw() >= ngen::HW::XeHPC) return layout;
2352 auto tg_dim0 = load_grid.dim(0);
2353 auto tg_dim1 = load_grid.dim(1);
2354 int type_size = layout.type().size();
2355
2356 ir_assert(layout.elems() % tg_dim0 == 0) << layout;
2357 dim_t inner_block = layout.elems() / tg_dim0;
2358
2359 ir_assert((inner_block * type_size) % tg_dim1 == 0) << layout;
2360 dim_t per_thr_bytes = (inner_block * type_size) / tg_dim1;
2361
2362 std::vector<dim_t> multi_blocks = {inner_block, tg_dim0};
2363 auto l = layout.split_into_multi_blocks(multi_blocks);
2364
2365 if (l.is_empty()) {
2366 ir_warning() << "Couldn't split layout for SLM padding."
2367 << std::endl;
2368 return layout;
2369 }
2370 auto padded_blocks = l.blocks();
2371 dim_t stride = -1;
2372 dim_t remaining_elems = inner_block;
2373 bool past_inner_block = remaining_elems == 1;
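// Keep strides within the inner (per-thread) block; once past it, give the
// first outer block a conflict-free stride and lay out the remaining outer
// blocks densely on top of it.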
2374 for (auto &b : padded_blocks) {
2375 if (past_inner_block) {
2376 if (stride == -1) {
2377 dim_t stride_bytes = find_min_stride_without_conflicts(
2378 per_thr_bytes, dim_t(b.stride) * type_size);
2379 ir_assert(stride_bytes % type_size == 0);
2380 stride = stride_bytes / type_size;
2381 }
2382 b.stride = stride;
2383 stride = b.stride * b.block;
2384 continue;
2385 }
2386 ir_assert(remaining_elems % b.block == 0);
2387 remaining_elems /= b.block;
2388 if (remaining_elems == 1) past_inner_block = true;
2389 }
2390 return layout_t(
2391 layout.type(), layout.ndims(), layout.offset(), padded_blocks);
2392 }
2393
2394 dim_t find_min_stride_without_conflicts(
2395 dim_t inner_bytes, dim_t dense_stride_bytes) const {
2396 int write_step = 64;
2397 int stride_step = 16;
2398 dim_t stride_beg = dense_stride_bytes;
2399 dim_t stride_end = 2 * dense_stride_bytes;
2400 auto arch = convert_ngen_arch_to_dnnl(ir_ctx_.hw());
2401 const int slm_banks
2402 = compute::device_info_t::slm_memory_bank_count(arch);
2403 const int bank_granularity
2404 = compute::device_info_t::slm_memory_bank_granularity(arch);
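// Scan candidate strides in [dense_stride, 2 * dense_stride). A stride is
// acceptable if, for each 64-byte write chunk, the banks touched at offsets
// off and off + stride are all distinct.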
2405 for (dim_t s = stride_beg; s < stride_end; s += stride_step) {
2406 bool ok = true;
2407 for (dim_t off0 = 0; off0 < inner_bytes; off0 += write_step) {
2408 // Check banks for a single SLM write.
2409 std::vector<bool> found(slm_banks, false);
2410 for (dim_t off = off0; off < off0 + write_step;
2411 off += bank_granularity) {
2412 int bank0 = (off / bank_granularity) % slm_banks;
2413 int bank1 = ((off + s) / bank_granularity) % slm_banks;
2414 if (found[bank0]) {
2415 ok = false;
2416 break;
2417 }
2418 found[bank0] = true;
2419 if (found[bank1]) {
2420 ok = false;
2421 break;
2422 }
2423 found[bank1] = true;
2424 }
if (!ok) break;
}
if (ok) return s;
2427 }
2428
2429 ir_warning()
2430 << "Couldn't find stride without conflicts for SLM padding."
2431 << std::endl;
2432
2433 return dense_stride_bytes;
2434 }
2435
2436 const conv_config_t &cfg_;
2437 ir_context_t &ir_ctx_;
2438 post_op_context_t post_op_ctx_;
2439 b_reduce_context_t b_reduce_ctx_;
2440
2441 g2s_context_t g2s_ctx_;
2442 fma_helper_t fma_helper_;
2443
2444 gemm_schedule_t gemm_schedule_;
2445
2446 expr_t ap_buf_;
2447 expr_t bp_buf_;
2448 expr_t cp_buf_;
2449
2450 std::vector<stmt_t> compute_allocs_;
2451 std::vector<stmt_t> out_allocs_;
2452 int ab_slm_size_ = 0;
2453
2454 stmt_t g2s_load_stmt_;
2455 stmt_t g2s_store_stmt_;
2456 stmt_t prefetch_stmt_;
2457 stmt_t load_mul_stmt_;
2458
2459 stmt_t c_zero_out_stmt_;
2460 stmt_t c_store_stmt_;
2461
2462 stmt_t b_reduced_zero_out_stmt_;
2463 stmt_t b_reduced_store_stmt_;
2464
2465 const kernel_info_t &kernel_info_;
2466};
2467
2468class compute_loop_label_injector_t : public ir_mutator_t {
2469public:
2470 object_t _mutate(const for_t &obj) override {
2471 if (injected_) return obj;
2472
2473 bool found_continue = false;
2474 auto calls = find_objects<func_call_t>(obj);
2475 for (auto &_c : calls) {
2476 auto &c = _c.as<func_call_t>();
2477 if (c.func.is_equal(funcs::continue_func())) found_continue = true;
2478 }
2479
2480 if (!found_continue) {
2481 injected_ = true;
2482 return stmt_group_t::make(stmt_label_t::compute_loop(), obj);
2483 }
2484 return ir_mutator_t::_mutate(obj);
2485 }
2486
2487private:
2488 bool injected_ = false;
2489};
2490
// Injects the compute_loop statement label into the outermost loop that can
// be pipelined. If a loop contains a "continue" function call, it can't be
// pipelined because of the conditional flow.
2494stmt_t inject_compute_loop_label(const stmt_t &s) {
2495 return compute_loop_label_injector_t().mutate(s);
2496}
2497
2498void conv_ir_builder_t::build() {
2499 constraint_set_t init_cset;
2500
2501 trace_reset();
2502
2503 std::vector<stmt_t> init_stmts;
2504 init_kernel_grid(cfg_.kernel_grid(), cfg_.thread_group_grid(), cfg_.simd(),
2505 init_cset, init_stmts);
2506
2507 gemm_schedule_t gemm_schedule(
2508 init_cset, cfg_.kernel_grid(), cfg_.thread_group_grid());
2509
2510 // Initialize memory buffers.
2511 std::vector<stmt_t> inner_lets;
2512
2513 view_t a_view;
2514 view_t b_view;
2515 view_t c_view;
2516 view_t bp_reduced_view;
2517
2518 expr_t ap_buf;
2519 expr_t bp_buf;
2520 expr_t cp_buf;
2521 expr_t b_reduced_mem_buf;
2522 expr_t b_reduction_condition;
2523
2524 if (cfg_.prb().is_fwd) {
2525 init_fwd(gemm_schedule, a_view, b_view, c_view, ap_buf, bp_buf, cp_buf);
2526 } else if (cfg_.prb().is_bwd_d) {
2527 init_bwd_d(
2528 gemm_schedule, a_view, b_view, c_view, ap_buf, bp_buf, cp_buf);
2529 } else if (cfg_.prb().is_bwd_w) {
2530 init_bwd_w(gemm_schedule, a_view, b_view, c_view, bp_reduced_view,
2531 ap_buf, bp_buf, cp_buf, b_reduced_mem_buf,
2532 b_reduction_condition);
2533 } else {
2534 ir_error_not_expected();
2535 }
2536
2537 gemm_schedule.finalize();
2538
2539 trace_stamp("GEMM Schedule");
2540
2541 ir_context_t ir_ctx(cfg_.exec_cfg(), init_cset);
2542 post_op_context_t post_op_ctx(cfg_, gemm_schedule, kernel_info_);
2543 compute_builder_t cb(cfg_, ir_ctx, kernel_info_);
2544
2545 cb.set_gemm_schedule(gemm_schedule);
2546 cb.set_ap_buf(ap_buf);
2547 cb.set_bp_buf(bp_buf);
2548 cb.set_cp_buf(cp_buf);
2549 cb.set_b_reduced_mem_buf(b_reduced_mem_buf);
2550 cb.set_b_reduced_view(bp_reduced_view);
2551 cb.set_post_op_context(post_op_ctx);
2552 cb.set_reduce_condition(b_reduction_condition);
2553
2554 cb.build();
2555
2556 trace_stamp("Compute Builder");
2557
2558 std::vector<stmt_t> allocs;
2559 for (int i = 0; i < kernel_info_.nargs(); i++) {
2560 auto &var = kernel_info_.arg_var(i);
2561 if (!var.type().is_ptr()) continue;
2562 allocs.push_back(alloc_t::make(var, 0, alloc_kind_t::global));
2563 }
2564
2565 // Create IR statements.
2566 stmt_t loop_stmt = cb.iter_stmt();
2567 loop_stmt = gemm_schedule.create_loop_nest(loop_stmt);
2568 loop_stmt = inject_compute_loop_label(loop_stmt);
2569 loop_stmt = cb.inject_compute_alloc_stmts(loop_stmt);
2570
2571 stmt_t c_store_stmt;
2572 c_store_stmt = c_store_stmt.append(cb.b_reduced_store_stmt());
2573 c_store_stmt = c_store_stmt.append(cb.c_store_stmt());
2574 c_store_stmt = stmt_group_t::make(stmt_label_t::c_store(), c_store_stmt);
2575
2576 stmt_ = loop_stmt;
2577 stmt_ = stmt_seq_t::make(cb.zero_out_stmt(), stmt_);
2578 stmt_ = stmt_seq_t::make(stmt_, c_store_stmt);
2579
2580 stmt_ = cb.inject_out_alloc_stmts(stmt_);
2581 stmt_ = cb.inject_let_stmts(stmt_);
2582
2583 stmt_ = gemm_schedule.create_bind_stmt(stmt_);
2584 stmt_ = inject_let_stmts(stmt_, init_stmts);
2585 stmt_ = inject_alloc_stmts(stmt_, allocs);
trace_stop("Create Initial IR");
2587
2588 stmt_ = inject_external_var_let(stmt_, ir_ctx);
2589 stmt_ = merge_slm_buffers(stmt_, ir_ctx);
2590 if (!cfg_.pipeline().do_unroll() && cfg_.slm()) {
2591 stmt_ = inject_simple_slm_buffering(
2592 stmt_, ir_ctx, cfg_, cb.ab_slm_size());
2593 } else if (!cfg_.pipeline().do_unroll() && cfg_.prefetch()) {
2594 // Simplify to remove loops with only 1 iteration
2595 stmt_ = simplify(stmt_, ir_ctx);
2596 stmt_ = inject_prefetch_pipeline(stmt_, ir_ctx, cfg_);
2597 }
2598 stmt_ = inject_slm_reorder(
2599 stmt_, ir_ctx, cfg_.thread_group_grid(), cfg_.slm());
2600 stmt_ = lift_buffer_offsets_in_send(stmt_, ir_ctx);
2601 stmt_ = simplify(stmt_, ir_ctx);
2602 stmt_ = inject_send(stmt_, ir_ctx);
2603 stmt_ = split_wide_stores(stmt_, ir_ctx);
2604 stmt_ = lift_alloc(stmt_, ir_ctx, cfg_.pipeline().reuse_headers());
2605 stmt_ = lift_send_2d_header_store(stmt_, ir_ctx);
2606 stmt_ = hoist_send_masks(stmt_, ir_ctx, stmt_label_t::c_store(), false);
2607 stmt_ = split_shuffle(stmt_, ir_ctx);
2608 stmt_ = eliminate_common_subexprs(
2609 stmt_, ir_ctx, cfg_.reserved_regs(), cfg_.slm().gmem_bufs());
2610 stmt_ = hoist_exprs(stmt_, ir_ctx, cfg_.reserved_regs());
2611 if (cfg_.pipeline().do_unroll())
2612 stmt_ = loop_strength_reduce(stmt_, ir_ctx);
2613 stmt_ = optimize_alloc_let(stmt_, ir_ctx);
2614 if (cfg_.pipeline().do_unroll()) {
2615 stmt_ = update_loops_for_unrolling(stmt_, ir_ctx);
2616 stmt_ = inject_unrolling(stmt_, ir_ctx, cfg_, cb.ab_slm_size());
2617 }
2618 if (cfg_.hoist_masks_from_compute_loop()) {
2619 stmt_ = hoist_send_masks(
2620 stmt_, ir_ctx, stmt_label_t::compute_loop(), true);
2621 }
2622 stmt_ = fixup_if_conditions(stmt_, ir_ctx);
2623 stmt_ = unroll_loops(stmt_, ir_ctx);
2624 stmt_ = simplify(stmt_, ir_ctx);
2625 stmt_ = maybe_strip_prefetches(stmt_, ir_ctx, cfg_.reserved_regs());
2626 stmt_ = optimize_alloc_let(stmt_, ir_ctx);
2627 if (cfg_.hoist_masks_from_compute_loop()) {
2628 stmt_ = remove_spurious_send_mask_cast(stmt_, ir_ctx);
2629 }
2630 stmt_ = fix_int32_overflow(stmt_, ir_ctx);
2631 stmt_ = optimize_peephole(stmt_, ir_ctx);
2632 stmt_ = optimize_barrier(stmt_, ir_ctx);
2633 if (cfg_.fma_kind() == fma_kind_t::dp4a) stmt_ = inject_dp4a(stmt_, ir_ctx);
2634 stmt_ = inject_bank_conflict_attribute(stmt_, ir_ctx);
2635 stmt_ = stmt_group_t::make(stmt_label_t::kernel(), stmt_);
2636
2637#if !defined(NDEBUG) || defined(GEN_CONV_DEBUG)
2638 verify_buffer_access(stmt_, ir_ctx);
2639#endif
2640
2641 ir_trace() << "Convolution kernel body:\n" << stmt_ << std::endl;
2642 trace_perf();
2643}
2644
2645} // namespace jit
2646} // namespace gpu
2647} // namespace impl
2648} // namespace dnnl
2649