/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/ir/message.hpp"

#include "gpu/jit/ir/block_2d_utils.hpp"
#include "gpu/jit/ir/ir.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

std::ostream &operator<<(std::ostream &out, const send_op_t op) {
    const char *s = nullptr;
    switch (op) {
        case send_op_t::atomic_fadd: s = "atomic_fadd"; break;
        case send_op_t::load: s = "load"; break;
        case send_op_t::load_2d: s = "load_2d"; break;
        case send_op_t::prefetch: s = "prefetch"; break;
        case send_op_t::prefetch_2d: s = "prefetch_2d"; break;
        case send_op_t::store: s = "store"; break;
        case send_op_t::store_2d: s = "store_2d"; break;
        default: ir_error_not_expected(); s = "unknown";
    }

    return out << s;
}

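// Stores the memory offset (mem_buf + mem_off for A64, mem_off otherwise)
// into the message header buffer. For legacy (non-LSC) block SLM/BTS messages
// the offset is written at a fixed header position and converted from bytes
// to message data type units (dwords/owords/hwords).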
stmt_t send_t::create_offset_store(const expr_t &header_buf,
        const expr_t &mem_buf, const expr_t &_mem_off,
        bool is_signed_offset) const {
    ir_assert(is_var(mem_buf));
    int header_off = 0;
    int unit_size = 1;
    if (!is_lsc && is_block() && (is_slm() || is_bts())) {
        header_off = 2 * address_type().size();
        // Convert byte offset to dwords/owords/hwords offset.
        unit_size = type.scalar().size();
    }

    expr_t mem_off = _mem_off;
    if (unit_size != 1) mem_off /= unit_size;

    expr_t header_sub_buf = header_buf[header_off];

    expr_t off;
    if (is_a64()) {
        off = cast(mem_buf, address_type());
        if (mem_off.type().is_vector()) {
            off = shuffle_t::make_broadcast(off, mem_off.type().elems());
        }
        off += mem_off;
    } else {
        off = mem_off;
    }
    off = cast(off, address_type(is_signed_offset, off.type().elems()));
    return store_t::make(header_sub_buf, 0, off);
}

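// Returns true if this message is supported by the hardware and by the
// restrictions assumed by the message decomposition logic below.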
bool send_t::is_supported() const {
    int max_access_size
            = (is_2d() && !is_store_2d()) ? 32 * grf_size() : 8 * grf_size();
    if (access_size() > max_access_size) return false;

    // Block messages imply one slot.
    if (is_block() && slots != 1) return false;

    if (is_block() && !utils::one_of(type.elems(), 1, 2, 4, 8, 16))
        return false;

    // owordx8 is the maximum supported unless accessing SLM.
    if (type.is_oword() && !is_slm() && type.elems() > 8) return false;

    // hword is not supported with SLM.
    if (is_slm() && type.is_hword()) return false;

    // Allow only block messages for SLM to reduce offset-related arithmetic.
    if (is_slm() && !is_block()) return false;

    // Only load/store with SLM.
    if (is_slm() && !is_load() && !is_store()) return false;

    // No hword stores before XeHPC.
    if (is_store() && type.is_hword() && !is_xe_hpc_plus()) return false;

    // XXX: Half-GRF stores result in correctness issues on XeHPC.
    if (is_store() && is_block() && is_xe_hpc_plus()
            && type.size() % grf_size() != 0)
        return false;

    // Skip transposing messages; they need additional logic in message
    // decomposition to handle layouts.
    if (type.is_dword() && type.elems() != 1) return false;
    if (type.is_qword() && type.elems() != 1) return false;

    // XXX: Allow only hword x {1,2,4,8} prefetch for now.
    if (is_prefetch() && !type.is_hword()) return false;
    if (is_prefetch() && type.elems() > 8) return false;

    // Expect only float atomics.
    if (is_atomic() && !(type.is_dword() || type.is_qword())) return false;

    if (is_atomic() && !is_xe_hpc_plus() && is_a64() && slots > 8) return false;

    // XXX: Only byte scattered messages have been tested.
    if (is_scattered() && !is_atomic() && !type.is_byte() && !type.is_qword())
        return false;

    if (is_scattered() && !is_atomic()
            && !utils::one_of(type.elems(), 1, 2, 4, 8))
        return false;

    return true;
}

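// Returns all supported messages for the given operation and address model,
// sorted by preference: block messages first, then by decreasing access size;
// block messages with duplicate sizes are dropped.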
std::vector<func_t> send_t::get_all(ngen::HW hw, send_op_t op,
        send_address_t address, const type_t &mem_type,
        send_cache_hint_t cache_hint) {
    std::vector<func_t> filtered;
    for (int slots : {1, 2, 4, 8, 16}) {
        for (int elems : {1, 2, 4, 8, 16}) {
            for (auto &type : {type_t::byte(), type_t::dword(), type_t::qword(),
                         type_t::oword(), type_t::hword()}) {
                // Require an exact data type size match for atomic messages.
                if (op == send_op_t::atomic_fadd
                        && type.size() != mem_type.size())
                    continue;

                auto f = send_t::make(hw, op, address, type.with_elems(elems),
                        slots, cache_hint);
                if (!f.as<send_t>().is_supported()) continue;
                filtered.push_back(f);
            }
        }
    }

    // Sort by total size in descending order.
    std::sort(filtered.begin(), filtered.end(),
            [](const func_t &_a, const func_t &_b) {
                auto &a = _a.as<send_t>();
                auto &b = _b.as<send_t>();
                size_t a_sz = a.access_size();
                size_t b_sz = b.access_size();
                // Put block messages first.
                if (a.is_block() != b.is_block()) return a.is_block();
                // Prefer messages with a smaller type as they have less strict
                // alignment requirements.
                if (a_sz == b_sz)
                    return a.type.scalar().size() < b.type.scalar().size();
                return a_sz > b_sz;
            });

    // Remove block messages with the same size (e.g. owordx4 and hwordx2).
    std::vector<func_t> ret;
    for (size_t i = 0; i < filtered.size(); i++) {
        if (i > 0) {
            auto &s_prev = filtered[i - 1].as<send_t>();
            auto &s_cur = filtered[i].as<send_t>();
            if (s_prev.is_block() && s_cur.is_block()
                    && (s_prev.type.size() == s_cur.type.size()))
                continue;
        }
        ret.push_back(filtered[i]);
    }

    return ret;
}

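// Selects LSC cache settings based on the message kind, the cache hint and
// the target hardware.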
ngen::CacheSettingsLSC get_cache_settings(const send_t &send) {
    auto ret = ngen::CacheSettingsLSC::Default;
    bool is_load = send.is_load() || send.is_load_2d();
    bool is_store = send.is_store() || send.is_store_2d();
    bool is_prefetch = send.is_prefetch() || send.is_prefetch_2d();
    switch (send.cache_hint) {
        case send_cache_hint_t::undef:
            switch (send.hw) {
                case ngen::HW::XeHPG:
                    if (is_store) ret = ngen::CacheSettingsLSC::L1WB_L3WB;
                    break;
                case ngen::HW::XeHPC:
                    if (is_store) {
                        ret = ngen::CacheSettingsLSC::L1UC_L3WB;
                    } else if (is_load || is_prefetch) {
                        ret = ngen::CacheSettingsLSC::L1C_L3C;
                    }
                    break;
                default: break;
            }
            break;
        case send_cache_hint_t::load_once:
            switch (send.hw) {
                case ngen::HW::XeHPG:
                    ret = ngen::CacheSettingsLSC::L1C_L3UC;
                    break;
                case ngen::HW::XeHPC:
                    ret = ngen::CacheSettingsLSC::L1C_L3C;
                    break;
                default: break;
            }
            break;
    }
    return ret;
}

// Helper class to iterate through global memory offsets; provides queries to
// check whether given blocks are dense, properly aligned, etc.
class memory_walker_t {
public:
    memory_walker_t(const constraint_set_t &cset, const view_t &view)
        : view_(view)
        , type_size_(view.type().size())
        , mask_tensor_(view.create_mask_tensor(cset).reinterpret(view.type()))
        , full_size_(view.velems() * type_size_) {
        init_dense_blocks(cset);
        reset();
    }

    void reset() {
        cur_off_ = 0;
        remaining_size_ = full_size_;
    }

    const mask_tensor_t &mask_tensor() const { return mask_tensor_; }

    bool has_next() const { return cur_off_ < full_size_; }

    int remaining_size() const { return remaining_size_; }

    int remaining_elems() const { return remaining_size_ / type_size_; }

    bool is_dense_and_aligned(int off, int size, int alignment) const {
        if (off + size > remaining_size_) return false;
        if (size == 0) return true;
        int beg = cur_off_ + off;
        int end = cur_off_ + off + size;
        if (get_block_index(beg) != get_block_index(end - 1)) return false;
        if (alignment != 0 && get_alignment(beg) < alignment) return false;
        return true;
    }

    // Returns true if each of the given slot regions is dense and aligned.
    bool check_region(int off, int slots, int slot_size, int alignment) const {
        for (int i = 0; i < slots; i++) {
            int off = i * slot_size;
            // Overflow is fine; it is expected to be handled by proper
            // masking.
            if (off >= remaining_size_) return true;
            if ((slot_size * slots) % type_size_ != 0) return false;
            if (!is_dense_and_aligned(off, slot_size, alignment)) return false;
        }
        return true;
    }

    // Returns true if the given region can be masked with `mask_size`
    // granularity and `nmasks` number of masks.
    bool check_mask_size(int off, int size, int mask_size, int nmasks) const {
        auto mask = get_mask(off, size, mask_size, nmasks, /*allow_fail=*/true);
        return !mask.is_empty();
    }

    expr_t get_offset(int off, expr_t &base, int &off_const) const {
        if (off >= remaining_size_) {
            base = expr_t(0);
            off_const = 0;
            return base;
        }
        int block_idx = get_block_index(cur_off_ + off);
        ir_assert(block_idx >= 0 && block_idx < int(block_offs_.size()));
        base = block_offs_[block_idx];
        auto prev_base = block_offs_[block_idx == 0 ? 0 : block_idx - 1];
        auto get_const_summand = [&](expr_t expr) -> int64_t {
            if (!expr.type().is_int()) return 0;
            auto binary_op = expr.as_ptr<binary_op_t>();
            if (binary_op && binary_op->op_kind == op_kind_t::_add
                    && is_const(binary_op->b))
                return to_cpp<int64_t>(binary_op->b);
            return 0;
        };

        auto const_summand = get_const_summand(base);
        auto base1 = simplify(base - const_summand);
        auto base2 = simplify(prev_base - get_const_summand(prev_base));
        bool same_base = base1.is_equal(base2);
        off_const = (cur_off_ + off) % dense_block_size_;
        if (!same_base || const_summand == 0) return base + off_const;
        base = base1;
        off_const += const_summand;
        return base + off_const;
    }

    // Returns a boolean mask expression for the given region to access.
    expr_t get_mask(int off, int size, int mask_size, int nmasks,
            bool allow_fail = false) const {
        ir_assert(size % mask_size == 0) << "Incompatible mask size.";
        auto sub_mask_tensor = create_sub_mask_tensor(off, size);
        sub_mask_tensor = sub_mask_tensor.reinterpret(type_t::u8(mask_size));
        if (sub_mask_tensor.is_empty()) {
            if (allow_fail) return expr_t();
            ir_error_not_expected();
        }
        auto ret = sub_mask_tensor.to_expr(nmasks);
        if (ret.is_empty()) {
            if (allow_fail) return expr_t();
            ir_error_not_expected() << "Can't create mask.";
        }
        return ret;
    }

    // Moves the current position `size` bytes ahead.
    void advance(int size) {
        ir_assert(size % type_size_ == 0);
        size = std::min(size, remaining_size_);
        cur_off_ += size;
        remaining_size_ -= size;
    }

private:
    void init_dense_blocks(const constraint_set_t &cset) {
        auto l = view_.create_pseudo_vlayout();
        // Find the maximum innermost dense tile.
        stride_t stride = 1;
        std::vector<dim_t> dims(l.ndims(), 1);
        for (auto &b : l.blocks()) {
            if (b.stride != stride) break;
            dims[b.dim_idx] *= b.block;
            stride = b.block * b.stride;
        }
        tensor_t tile(dims);
        dense_block_size_ = tile.elems() * type_size_;
        // Split the memory view into dense blocks and precompute block offsets
        // and alignments.
        view_.for_each_tile(tile, [&](const std::vector<dim_t> &start) {
            auto off = view_.offset_in_bytes(expr_cast<expr_t>(start));
            off = simplify(off, cset);

            const int base_alignment = 128;
            int64_t f = get_max_const_factor(off, cset);
            int alignment = f ? ir_utils::max_pow2_divisor(f) : base_alignment;

            block_offs_.push_back(off);
            block_alignments_.push_back(alignment);
        });
    }

    mask_tensor_t create_sub_mask_tensor(int off, int size) const {
        ir_assert(off % type_size_ == 0);
        ir_assert(size % type_size_ == 0);

        std::vector<dim_t> sub_dims = {size / type_size_};
        layout_t sub_layout(view_.type(), 0, sub_dims);
        mask_tensor_t sub_mask_tensor(sub_layout);
        int beg = (cur_off_ + off) / type_size_;
        int end = (cur_off_ + off + size) / type_size_;
        for (int i = beg; i < end; i++) {
            auto mask = (i < mask_tensor_.elems()) ? mask_tensor_.mask(i)
                                                   : expr_t(false);
            sub_mask_tensor.set_mask(i - beg, mask);
        }
        return sub_mask_tensor;
    }

    int get_block_index(int off) const { return off / dense_block_size_; }

    int get_alignment(int off) const {
        int block_alignment = block_alignments_[off / dense_block_size_];
        return ir_utils::max_pow2_divisor(
                block_alignment + off % dense_block_size_);
    }

    view_t view_;
    int type_size_;
    mask_tensor_t mask_tensor_;
    std::vector<expr_t> block_offs_;
    std::vector<int> block_alignments_;
    int cur_off_ = 0;
    int full_size_ = 0;
    int remaining_size_ = 0;
    int dense_block_size_ = 0;
};

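// Helper class to walk through the elements of a GRF layout in logical order,
// tracking the corresponding register offset in bytes.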
class layout_walker_t {
public:
    layout_walker_t() = default;
    layout_walker_t(const layout_t &layout, int grf_size)
        : layout_(layout)
        , grf_size_(grf_size)
        , type_size_(layout.type().size())
        , idxs_(layout.blocks().size()) {}

    int offset_bytes() const { return off_bytes_; }

    bool can_access(int size) const {
        int off = off_bytes_ + size;
        return off <= max_offset_bytes();
    }

    // Returns true if the next `elems` elements can be stored in the layout
    // given the following requirements:
    // - They must be uniformly strided with `stride` (specified in elements)
    // - The last element must be GRF boundary aligned (unless `is_last_region`
    //   is true)
    // - The last element must not cross the layout boundary
    bool can_advance(int stride, int elems, bool is_last_region = false) {
        if (is_last_region) elems = std::min(elems, remaining_elems());
        auto cur_idxs = idxs_;
        int cur_off_bytes = off_bytes_;
        for (int i = 0; i < elems - 1; i++) {
            int next_off_bytes = advance(cur_idxs, cur_off_bytes);
            if (next_off_bytes - cur_off_bytes != stride * type_size_)
                return false;
            cur_off_bytes = next_off_bytes;
        }
        cur_off_bytes = advance(cur_idxs, cur_off_bytes);
        if (cur_off_bytes > max_offset_bytes()) return false;
        if (!is_last_region && cur_off_bytes % grf_size_ != 0) return false;
        return true;
    }

    // Moves the current position `elems` elements ahead.
    void advance(int elems) {
        elems = std::min(elems, remaining_elems());
        for (int i = 0; i < elems; i++) {
            off_bytes_ = advance(idxs_, off_bytes_);
            elems_++;
        }
    }

private:
    int max_offset_bytes() const {
        return utils::rnd_up((int)layout_.size(), grf_size_);
    }

    int remaining_elems() const { return layout_.elems() - elems_; }

    int advance(std::vector<int> &idxs, int off_bytes) const {
        for (size_t i = 0; i < idxs.size(); i++) {
            if (++idxs[i] < layout_.blocks()[i].block) break;
            idxs[i] = 0;
        }
        int off = 0;
        for (size_t i = 0; i < idxs.size(); i++) {
            int stride = (int)layout_.blocks()[i].stride;
            off += idxs[i] * stride;
        }
        return off * type_size_;
    }

    layout_t layout_;
    int grf_size_;
    int type_size_;

    std::vector<int> idxs_;
    int elems_ = 0;
    int off_bytes_ = 0;
};

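// Builds a send decomposition for accessing the memory view. A 2D block
// message decomposition is tried first when hinted; otherwise (or on failure)
// the generic decomposition is used and the 2D hint is reset.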
access_builder_t::access_builder_t(ir_context_t &ir_ctx, const view_t &mem_view,
        const expr_t &mem_buf, const expr_t &reg_buf, send_op_t send_op,
        send_address_t send_address, send_cache_hint_t send_cache_hint,
        send_hint_t &send_hint)
    : ir_ctx_(&ir_ctx)
    , mem_view_(mem_view)
    , mem_buf_(mem_buf)
    , reg_buf_(reg_buf)
    , send_op_(send_op)
    , send_address_(send_address)
    , send_cache_hint_(send_cache_hint)
    , mem_type_(mem_view.type())
    , mem_walker_(
              utils::make_unique<memory_walker_t>(ir_ctx.cset(), mem_view)) {
    if (send_hint.hint_2d.enable) {
        if (try_build_2d(send_hint)) return;
    }
    send_hint.hint_2d = send_2d_hint_t();
    build();
}

access_builder_t::access_builder_t(access_builder_t &&) = default;
access_builder_t::~access_builder_t() = default;

void access_builder_t::build() {
    bool ok = false;
    for (auto &l : candidate_payload_layouts()) {
        // Try to find a send decomposition with the given GRF payload layout.
        if (try_build(l)) {
            ok = true;
            break;
        }
    }
    if (!ok && send_op_ == send_op_t::prefetch) {
        // Do not treat this as an error; skip prefetch messages during
        // generation.
        ir_warning() << "Can't generate send decomposition for prefetch."
                     << std::endl;
        return;
    }
    ir_assert(ok) << "Can't generate send decomposition.";
}

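// Checks that, at the given start coordinates, the target dimension offset
// gets no contribution from the non-strided view dimensions (their
// contributions simplify to zero), so dividing the Y coordinate by the stride
// is exact.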
static bool stride_dimension_ok(const view_t &view, int stride_tidx,
        int stride_vidx, const std::vector<expr_t> &vstart) {
    auto &tdim = view.tdim(stride_tidx);
    auto e = tdim.expr();
    for (int i = 0; i < tdim.nvargs(); i++) {
        int vidx = tdim.vidx(i);
        auto &vvar = view.vvars()[vidx];
        if (vidx == stride_vidx) {
            e = substitute(e, vvar, expr_t(0));
        } else {
            e = substitute(e, vvar, vstart[vidx]);
        }
    }
    e = simplify(e);
    return is_zero(e);
}

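// Tries to reduce a vector expression to an equivalent scalar (e.g. a
// broadcast or an operation on broadcasts). Returns an empty expression when
// the expression cannot be scalarized.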
static expr_t try_scalarize(const expr_t &e) {
    if (e.type().is_scalar()) return e;

    if (auto *shuffle = e.as_ptr<shuffle_t>()) {
        if (shuffle->is_broadcast()) return try_scalarize(shuffle->vec[0]);
        return expr_t();
    }

    if (auto *binary = e.as_ptr<binary_op_t>()) {
        auto a = try_scalarize(binary->a);
        auto b = try_scalarize(binary->b);
        if (a.is_empty() || b.is_empty()) return expr_t();
        return binary_op_t::make(binary->op_kind, a, b);
    }

    ir_error_not_expected() << e;
    return expr_t();
}

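// Promotes a legacy block message to its LSC equivalent on XeHPG+ when the
// mask can be scalarized; otherwise returns the original call unchanged.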
static stmt_t try_promote_to_lsc(const stmt_t &_call) {
    if (_call.is_empty()) return _call;
    auto &call = _call.as<func_call_t>();
    auto &send = call.func.as<send_t>();
    if (send.is_lsc || send.is_2d()) return call;
    if (send.hw < ngen::HW::XeHPG) return call;
    if (send.is_slm() || send.is_bts()) return call;
    if (!send.is_block()) return call;

    auto mask = try_scalarize(send_t::arg_mask(call));
    if (mask.is_empty()) return call;

    auto new_args = call.args;
    send_t::arg_mask(new_args) = mask;

    auto lsc_send = send_t::make(send.hw, send.op, send.address, send.type,
            send.slots, /*is_lsc=*/true, send.cache_hint);
    return lsc_send.call(new_args);
}

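// Attempts to implement the access with 2D block messages based on the 2D
// send hint. Returns false when surface shape, alignment or mask requirements
// cannot be met, in which case the caller falls back to the regular
// decomposition.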
bool access_builder_t::try_build_2d(send_hint_t &send_hint) {
    auto vlayout = mem_view_.create_pseudo_vlayout();
    auto &hint = send_hint.hint_2d;
    // The data may be loaded in a wider data type to get a proper GRF layout.
    if (!hint.type.is_undef()) vlayout = vlayout.reinterpret(hint.type);

    bool is_store = (send_op_ == send_op_t::store);
    auto send_type = type_t::u(vlayout.type().size() * 8);
    auto blocks = vlayout.blocks();
    if (blocks.size() < 2) return false;

    auto &b0 = blocks[0];
    auto &b1 = blocks[1];
    ir_assert(b0.dim_idx != b1.dim_idx);
    if (b0.stride != stride_t(1)) return false;
    if (!b1.stride.is_fixed()) return false;

    auto get_tdim_idx = [&](int vdim_idx, int &stride) {
        int ret = -1;
        for (int i = 0; i < mem_view_.ntdims(); i++) {
            auto &tdim = mem_view_.tdim(i);
            for (int j = 0; j < tdim.nvargs(); j++) {
                if (tdim.vidx(j) == vdim_idx) {
                    ir_assert(ret == -1);
                    stride = (int)tdim.vstride(j);
                    ret = i;
                }
            }
        }
        return ret;
    };

    int w_tstride = 0;
    int h_tstride = 0;
    int w_dim_idx = get_tdim_idx(b0.dim_idx, w_tstride);
    int h_dim_idx = get_tdim_idx(b1.dim_idx, h_tstride);

    if (w_tstride != 1) return false;

    auto &tlayout = mem_view_.tlayout();
    auto get_2d_dim = [&](int tdim_idx) {
        return tlayout.inner_block(tdim_idx, /*skip_outer=*/false);
    };

    int surface_width = 0;
    int surface_height = 0;
    int surface_pitch = b1.stride;
    bool is_w_blocked = (get_2d_dim(w_dim_idx) != tlayout.dim(w_dim_idx));
    bool is_h_blocked = (get_2d_dim(h_dim_idx) != tlayout.dim(h_dim_idx));
    // A virtual surface means loading from the innermost block of a blocked
    // layout, which implies no bound checks embedded into the 2D block
    // message.
    bool use_virtual_surface = is_w_blocked || is_h_blocked;
    if (use_virtual_surface) {
        if (h_tstride != 1) return false;
        surface_width = b0.block;
        surface_height = b1.block;
    } else {
        surface_width = tlayout.dim(w_dim_idx);
        surface_height = tlayout.dim(h_dim_idx);
        if (surface_height % h_tstride != 0) return false;
        surface_height = surface_height / h_tstride;
    }
    int type_factor = ir_utils::safe_divide(send_type.size(), mem_type_.size());
    surface_width /= type_factor;

    int width = hint.width;
    int height = hint.height;
    int count = 1;
    bool vnni = hint.vnni;
    bool transpose = hint.transpose;

    // Try to reduce the number of messages by increasing the count per
    // message.
    int try_count = count * 2;
    int max_count
            = block_2d_max_count(is_store, transpose, width, mem_type_.size());
    while (try_count <= max_count) {
        if (b0.block % (try_count * width) != 0) break;
        count = try_count;
        try_count *= 2;
    }

    int W = surface_width;
    int H = surface_height;
    int P = surface_pitch;
    int w = width;
    int h = height;
    int c = count;
    if (!fixup_send_2d_params(send_type, vnni, transpose,
                /*use_xy=*/!use_virtual_surface, W, H, P, w, h, c,
                hint.vnni_permute_factor))
        return false;

    std::vector<dim_t> dims(vlayout.ndims(), 1);
    dims[b0.dim_idx] = count * width;
    dims[b1.dim_idx] = height;
    tensor_t tile(dims);

    reg_layout_ = layout_t(type_factor == 1 ? mem_type_ : send_type, 0,
            std::vector<dim_t>(vlayout.ndims(), 1));
    int h_inner = vnni ? 4 / send_type.size() : 1;
    int h_outer = ir_utils::safe_divide(height, h_inner);
    reg_layout_ = reg_layout_.add_outer_block(b1.dim_idx, h_inner);
    if (transpose) {
        reg_layout_ = reg_layout_.add_outer_block(b1.dim_idx, h_outer);
        reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, width);
    } else {
        reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, width);
        reg_layout_ = reg_layout_.add_outer_block(b1.dim_idx, h_outer);
    }
    reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, count);

    int w_outermost
            = ir_utils::safe_divide(vlayout.dim(b0.dim_idx), count * width);
    int h_outermost = ir_utils::safe_divide(vlayout.dim(b1.dim_idx), height);
    reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, w_outermost);
    reg_layout_ = reg_layout_.add_outer_block(b1.dim_idx, h_outermost);

    if (type_factor != 1) {
        auto blocks = reg_layout_.blocks();
        reg_layout_ = layout_t(
                mem_type_, 0, std::vector<dim_t>(vlayout.ndims(), 1));
        reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, type_factor);
        for (auto &b : blocks)
            reg_layout_ = reg_layout_.add_outer_block(b.dim_idx, b.block);
    }

    for (auto &b : blocks) {
        if (utils::one_of(b.dim_idx, b0.dim_idx, b1.dim_idx)) continue;
        reg_layout_ = reg_layout_.add_outer_block(b.dim_idx, b.block);
    }

    reg_layout_walker_
            = utils::make_unique<layout_walker_t>(reg_layout_, grf_size());

    // Update the user hint.
    hint.type = send_type;
    hint.enable = true;
    hint.vnni = vnni;
    hint.transpose = transpose;
    hint.width = w;
    hint.height = h;
    auto _send = send_t::make_2d(ir_ctx_->hw(), send_hint.convert(send_op_),
            send_type, W, H, P, w, h, c, vnni, transpose, send_cache_hint_);
    auto &send = _send.as<send_t>();

    stmt_ = stmt_t();
    bool ok = true;
    auto vstart0 = mem_view_.vstart();
    vlayout.for_each_tile(tile, [&](const std::vector<dim_t> &start) {
        if (!ok) return;

        int access_size = send.access_size();
        int access_elems = access_size / mem_type_.size();

        // Check mask requirements.
        expr_t mask;
        if (!check_2d_mask(tensor_t(tile.dims(), start), use_virtual_surface,
                    w_dim_idx, h_dim_idx, mask)) {
            ok = false;
            return;
        }

        if (!send.is_prefetch_2d()) {
            if (!reg_layout_walker_->can_advance(1, access_elems)) {
                ok = false;
                return;
            }

            if (!reg_layout_walker_->can_access(send.payload_size())) {
                ok = false;
                return;
            }
        }

        auto vstart = vstart0;
        for (int i = 0; i < vlayout.ndims(); i++) {
            if (start[i] == 0) continue;
            int factor = (i == b0.dim_idx ? type_factor : 1);
            vstart[i] += factor * start[i];
        }
        auto tstart
                = mem_view_.cvt_vargs_to_targs(vstart, /*ignore_vstart=*/true);

        auto &_x = tstart[w_dim_idx];
        auto &_y = tstart[h_dim_idx];

        expr_t x(0);
        expr_t y(0);

        bool skip_send = false;
        if (!use_virtual_surface) {
            std::swap(x, _x);
            std::swap(y, _y);
            if (type_factor != 1) x /= type_factor;

            if (h_tstride != 1) {
                if (!stride_dimension_ok(
                            mem_view_, h_dim_idx, b1.dim_idx, vstart)) {
                    if (send.is_prefetch_2d()) {
                        skip_send = true;
                    } else {
                        ok = false;
                        return;
                    }
                }
                y /= h_tstride;
            }
        }

        auto off = simplify(
                mem_view_.tlayout().offset_in_bytes(tstart), ir_ctx_->cset());

        // Check alignment requirements.
        int64_t align = get_max_const_factor(off, ir_ctx_->cset());
        if (align % block_2d_base_alignment(ir_ctx_->hw_cfg()) != 0) {
            ok = false;
            return;
        }

        if (!skip_send) {
            if (!ir_ctx_->cset().can_prove(
                        x % block_2d_x_alignment(send_type.size()) == 0)) {
                ok = false;
                return;
            }
            auto reg_buf = (send.is_prefetch_2d()
                            ? expr_t()
                            : reg_buf_ + reg_layout_walker_->offset_bytes());
            auto send_stmt = send(mem_buf_, off, reg_buf, mask, x, y);
            stmt_ = stmt_.append(send_stmt);
        }

        reg_layout_walker_->advance(send.access_size() / mem_type_.size());
    });

    return ok;
}

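// Validates and adjusts 2D block message parameters (surface width/height/
// pitch and block width/height/count) to satisfy hardware restrictions. For
// narrow surfaces a VNNI-friendly reshape may be applied, reported via
// vnni_permute_factor.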
bool access_builder_t::fixup_send_2d_params(const type_t &send_type, bool vnni,
        bool transpose, bool use_xy, int &W, int &H, int &P, int &w, int &h,
        int &c, int &vnni_permute_factor) {
    int surface_width_size = W * send_type.size();
    auto whp_ok = [&]() {
        return block_2d_width_ok(W, send_type.size()) && block_2d_height_ok(H)
                && block_2d_pitch_ok(
                        ir_ctx_->hw_cfg(), P, send_type.size(), use_xy);
    };

    // No VNNI permute by default.
    vnni_permute_factor = 0;

    // The surface width must be >= 64 bytes. For a smaller width we can apply
    // a reshape, e.g. [16a] x [16b] -> [8a] x [2a16b], to get a block with a
    // larger width. Such a reshape impacts width/height handling with the
    // following implications:
    // - The reshape is applied only for the VNNI, no-transpose case. This
    //   allows us to get the same GRF layout but with permuted height
    //   elements:
    //   - Layout without reshape: 8a16b2a
    //   - Layout with reshape:    8a16b2a (the a/height dimension is permuted)
    // - The permutation is safe when it's done for the reduction dimension
    //   (it doesn't matter in which order elements are accumulated).
    // - The permutation pattern must be the same between the A and B tensors.
    if (surface_width_size >= 64) return whp_ok();

    // The reshape is only expected/supported with VNNI.
    if (!vnni || transpose) return false;

    if (64 % surface_width_size != 0) return false;

    int factor = 64 / surface_width_size;
    if (h % factor != 0) return false;

    int max_count = block_2d_max_count(
            send_op_ == send_op_t::store, transpose, w, send_type.size());
    if (factor > max_count) return false;

    vnni_permute_factor = factor;
    W *= factor;
    P *= factor;
    H /= factor;
    h /= factor;
    c = factor;
    return whp_ok();
}

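// Builds an access mask for the given tile. When a mask can't be expressed
// directly, bound conditions along the width/height dimensions are dropped
// for non-virtual surfaces since they are covered by the hardware
// out-of-bound checks of 2D block messages.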
bool access_builder_t::check_2d_mask(const tensor_t &tile,
        bool use_virtual_surface, int w_dim_idx, int h_dim_idx,
        expr_t &mask) const {
    auto sub_view = mem_view_.create_sub_view(tile);
    auto mask_tensor = sub_view.create_mask_tensor(ir_ctx_->cset());
    mask = mask_tensor.to_expr(1);
    if (!mask.is_empty()) return true;

    // Virtual surface implies no out-of-bound send checks.
    if (use_virtual_surface) return false;

    // Remove bound conditions that are covered by out-of-bound send checks.
    uint32_t tmask = 0xFFFFFFFF;
    for (int i = 0; i < sub_view.nvdims(); i++) {
        if (!utils::one_of(i, w_dim_idx, h_dim_idx)) continue;
        for (int j = 0; j < sub_view.ntdims(); j++) {
            auto &tdim = sub_view.tdim(j);
            for (int k = 0; k < tdim.nvargs(); k++) {
                if (tdim.vidx(k) == i) {
                    // TODO: Check if tdim mask is a bound mask.
                    tmask &= ~(1U << i);
                }
            }
        }
    }
    mask_tensor = sub_view.create_mask_tensor(ir_ctx_->cset(), tmask);
    mask = mask_tensor.to_expr(1);
    if (!mask.is_empty()) return true;

    return false;
}

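// Greedily decomposes the access into messages for the given GRF payload
// layout: at each position the largest applicable message from the sorted
// candidate list is selected. Returns false if no message fits, so another
// payload layout can be tried.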
bool access_builder_t::try_build(const layout_t &try_layout) {
    auto &try_layout_blocks = try_layout.blocks();
    int reg_stride
            = (try_layout_blocks.empty() ? 0
                                         : (int)try_layout_blocks[0].stride);
    auto send_list = send_t::get_all(ir_ctx_->hw(), send_op_, send_address_,
            mem_type_, send_cache_hint_);
    reg_layout_walker_
            = utils::make_unique<layout_walker_t>(try_layout, grf_size());
    stmt_ = stmt_t();
    mem_walker_->reset();
    // Iterate through the memory view, greedily selecting messages according
    // to the sorted message list.
    while (mem_walker_->has_next()) {
        func_t _send;
        for (auto &_s : send_list) {
            auto &s = _s.as<send_t>();

            int slot_size = s.type.size();
            int alignment = s.alignment();
            int nmasks = s.nmasks();
            int payload_stride = s.payload_type_stride();
            int access_size = s.access_size();
            int access_elems = access_size / mem_type_.size();
            bool is_last_chunk = mem_walker_->remaining_size() <= access_size;

            if (reg_stride != 1 || payload_stride != slot_size) {
                // Detected a strided GRF layout or a strided payload. In this
                // case require an exact data type and stride match.
                if (reg_stride != 0
                        && payload_stride != reg_stride * mem_type_.size())
                    continue;
                if (s.type.size() != mem_type_.size()) continue;
            }
            // Prefetches don't have a payload, so skip these conditions for
            // prefetch.
            if (!s.is_prefetch()) {
                if (!reg_layout_walker_->can_advance(
                            reg_stride, access_elems, is_last_chunk))
                    continue;

                if (!reg_layout_walker_->can_access(s.payload_size())) continue;
            }

            // Check if slots are contiguous and aligned.
            if (!mem_walker_->check_region(0, s.slots, slot_size, alignment))
                continue;

            // Check mask requirements.
            // XXX: Postpone the mask check for prefetch until send call
            // generation. If the mask cannot be generated, skip the prefetch.
            if (!s.is_prefetch()
                    && !mem_walker_->check_mask_size(
                            0, access_size, s.mask_size(), nmasks))
                continue;

            _send = _s;
            break;
        }
        // Can't find a message - try another GRF layout for the payload.
        if (_send.is_empty()) return false;

        auto &send = _send.as<send_t>();
        auto send_stmt = create_send_stmt(send);
        send_stmt = try_promote_to_lsc(send_stmt);
        stmt_ = stmt_.append(send_stmt);

        reg_layout_walker_->advance(send.access_size() / mem_type_.size());
        mem_walker_->advance(send.access_size());
    }
    reg_layout_ = try_layout;
    return true;
}

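// Returns GRF payload layouts to try: a dense layout matching the memory view
// and, for 1- and 2-byte types, dword-strided variants matching byte
// scattered message payloads.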
std::vector<layout_t> access_builder_t::candidate_payload_layouts() const {
    int type_size = mem_type_.size();
    auto vlayout = mem_view_.create_dense_vlayout();

    std::vector<layout_t> ret;

    // Dense payload layout directly mapping to the memory view.
    ret.push_back(vlayout);

    // These payload layouts are meant to match the payload of byte x {1,2}
    // scattered messages (which are dword-strided).
    if (type_size == 2) ret.push_back(vlayout.make_strided(2));
    if (type_size == 1) ret.push_back(vlayout.make_strided(4));

    return ret;
}

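// Creates a send statement for the current position of the memory and
// register walkers, computing per-slot offsets and the access mask.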
stmt_t access_builder_t::create_send_stmt(const send_t &send) {
    std::vector<expr_t> off_vec;
    // Try to detect a common base and const vector offset to reduce further
    // arithmetic.
    expr_t off_base0;
    int off_const0 = -1;
    bool is_same_base = true;
    std::vector<expr_t> off_const_vec;
    for (int i = 0; i < send.slots; i++) {
        expr_t off_base;
        int off_const;
        auto off = mem_walker_->get_offset(
                i * send.type.size(), off_base, off_const);
        if (off_base0.is_empty()) {
            off_base0 = off_base;
            off_const0 = off_const;
        } else if (!off_base.is_equal(off_base0)) {
            is_same_base = false;
        }
        off_vec.push_back(off);
        off_const_vec.push_back(off_const - off_const0);
    }
    expr_t off;
    if (send.slots == 1 || !is_same_base) {
        off = shuffle_t::make(off_vec);
    } else {
        off = shuffle_t::make_broadcast(off_base0, send.slots)
                + shuffle_t::make_broadcast(off_const0, send.slots)
                + shuffle_t::make(off_const_vec);
    }
    bool allow_fail = send.is_prefetch();
    auto _mask = mem_walker_->get_mask(
            0, send.access_size(), send.mask_size(), send.nmasks(), allow_fail);
    if (_mask.is_empty()) return stmt_t();

    auto _reg_buf = (send.is_prefetch()
                    ? expr_t()
                    : reg_buf_ + reg_layout_walker_->offset_bytes());
    auto ret = send(mem_buf_, off, _reg_buf, _mask);
    return ret;
}

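// Computes a 2D block message hint (block width/height, VNNI/transpose flags
// and data type) for the given tile, applying hardware restrictions on
// supported data sizes and block dimensions.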
send_2d_hint_t get_send_2d_hint(send_op_t send_op, const type_t &_type,
        bool vnni, bool transpose, int w_tile, int h_tile, int w_blk = 0,
        int h_blk = 0) {
    auto type = _type;
    bool orig_vnni = vnni;
    bool orig_transpose = transpose;

    if (vnni && transpose) {
        // This combination is not supported, so try to replace it with an
        // upconvert and a transpose.
        if (type.size() == 2) {
            type = type_t::u32();
            w_tile = ir_utils::safe_divide(w_tile, 2);
            w_blk = ir_utils::safe_divide(w_blk, 2);
            vnni = false;
            orig_vnni = false;
        }
    }

    // XXX: Convert transpose to VNNI when transpose is not supported. This
    // requires an additional reorder, but a reorder from a "partially
    // transposed" VNNI-transformed layout is cheaper.
    if (transpose && type.size() != 4) {
        vnni = true;
        transpose = false;
    }

    bool is_load_or_prefetch
            = utils::one_of(send_op, send_op_t::load, send_op_t::prefetch);
    bool is_store = (send_op == send_op_t::store);

    // Only D8, D16 and D32 are implemented.
    if (!utils::one_of(type.size(), 1, 2, 4)) return send_2d_hint_t();

    // VNNI and transpose are mutually exclusive.
    if (vnni && transpose) return send_2d_hint_t();

    // VNNI and transpose are supported with load only.
    if (is_store && (vnni || transpose)) return send_2d_hint_t();

    // VNNI is supported with D8 and D16 only.
    if (vnni && !utils::one_of(type.size(), 1, 2)) return send_2d_hint_t();

    // Transpose is supported with D32 only.
    if (transpose && type.size() != 4) return send_2d_hint_t();

    int w_min = (transpose ? 1 : 4 / type.size());
    int w_max = (transpose ? 8 : (vnni ? 16 : 64 / type.size()));
    int h_min = (vnni ? (4 / type.size()) : 1);
    int h_max = (is_load_or_prefetch ? 32 : 8);

    if (w_blk > 0 && (w_blk < w_min || w_blk > w_max)) return send_2d_hint_t();
    if (h_blk > 0 && (h_blk < h_min || h_blk > h_max)) return send_2d_hint_t();

    auto find_block = [&](int dim, int min, int max) {
        for (int b = max; b >= min; b /= 2) {
            if (dim % b == 0) return b;
        }
        return 0;
    };

    if (w_blk == 0) w_blk = find_block(w_tile, w_min, w_max);
    if (h_blk == 0) h_blk = find_block(h_tile, h_min, h_max);
    if (w_blk == 0 || h_blk == 0) return send_2d_hint_t();

    if (orig_vnni && h_blk > 0) h_blk = find_block(h_tile, h_blk, h_max);
    if (orig_transpose && w_blk > 0) w_blk = find_block(w_tile, w_blk, w_max);

    send_2d_hint_t hint;
    hint.type = type;
    hint.enable = true;
    hint.width = w_blk;
    hint.height = h_blk;
    hint.vnni = vnni;
    hint.transpose = transpose;
    return hint;
}

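// Returns a send hint for the given view. 2D block messages are considered
// only on XeHPC+ for load/prefetch/store; for dpas sources the block sizes
// are chosen to match the expected dpas layouts (with VNNI/transpose as
// needed).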
send_hint_t get_send_hint(const exec_config_t &exec_cfg, send_op_t send_op,
        fma_kind_t fma_kind, abc_kind_t abc_kind, const view_t &view,
        const gemm_schedule_t &gemm_schedule, bool allow_2d) {
    if (!allow_2d) return send_hint_t();
    if (exec_cfg.hw() < ngen::HW::XeHPC) return send_hint_t();
    if (!utils::one_of(send_op, send_op_t::load, send_op_t::prefetch,
                send_op_t::store))
        return send_hint_t();

    auto vlayout = view.create_pseudo_vlayout();
    auto blocks = vlayout.blocks();
    if (blocks.size() < 2) return send_hint_t();

    auto &bmnk_mapper = gemm_schedule.bmnk_mapper();
    auto &b0 = blocks[0];
    auto &b1 = blocks[1];
    if (b0.dim_idx == b1.dim_idx) return send_hint_t();
    if (b0.stride != stride_t(1)) return send_hint_t();
    if (b1.stride.is_unknown()) return send_hint_t();

    send_hint_t hint;
    if (send_op == send_op_t::load && fma_kind == fma_kind_t::dpas
            && utils::one_of(abc_kind, abc_kind_t::a, abc_kind_t::b)) {
        bool is_dpas_src1 = (abc_kind == abc_kind_t::b);
        int mn_blk = (is_dpas_src1 ? exec_cfg.simd() : 8);
        int k_blk = 32 / view.type().size();

        bool is_b0_k = (bmnk_mapper.bmnk_kind(abc_kind, b0.dim_idx)
                == bmnk_kind_t::k);

        // Handle 4 cases (consider bf16):
        // src1, MxK: 16a16b -> 8a16b2a (VNNI)
        // src1, KxM: 16a16b -> 16b16a -> 8b16a2b (transpose + VNNI)
        // src2, KxN: 16a16b -> 16b16a (transpose)
        // src2, NxK: 16a16b -> 16a16b ()
        bool vnni = is_dpas_src1;
        bool transpose = (is_dpas_src1 == is_b0_k);
        int b0_blk = is_b0_k ? k_blk : mn_blk;
        int b1_blk = !is_b0_k ? k_blk : mn_blk;
        if (b0.block % b0_blk != 0) return send_hint_t();
        if (b1.block % b1_blk != 0) return send_hint_t();
        hint.hint_2d = get_send_2d_hint(send_op, view.type(), vnni, transpose,
                b0.block, b1.block, b0_blk, b1_blk);
    } else {
        if (b0.block >= 128) return hint;
        hint.hint_2d = get_send_2d_hint(
                send_op, view.type(), false, false, b0.block, b1.block);
    }

    return hint;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl