1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/ir/message.hpp" |
18 | |
19 | #include "gpu/jit/ir/block_2d_utils.hpp" |
20 | #include "gpu/jit/ir/ir.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | std::ostream &operator<<(std::ostream &out, const send_op_t op) { |
28 | const char *s = nullptr; |
29 | switch (op) { |
        case send_op_t::atomic_fadd: s = "atomic_fadd"; break;
        case send_op_t::load: s = "load"; break;
        case send_op_t::load_2d: s = "load_2d"; break;
        case send_op_t::prefetch: s = "prefetch"; break;
        case send_op_t::prefetch_2d: s = "prefetch_2d"; break;
        case send_op_t::store: s = "store"; break;
        case send_op_t::store_2d: s = "store_2d"; break;
        default: ir_error_not_expected(); s = "unknown";
38 | } |
39 | |
40 | return out << s; |
41 | } |
42 | |
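// Creates a store of the memory offset into the message header buffer. For
// A64 messages the full address (buffer base plus byte offset) is stored;
// for legacy SLM/BTS block messages the byte offset is first converted to
// data-element units and stored at the header position the message expects.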
stmt_t send_t::create_offset_store(const expr_t &header_buf,
        const expr_t &mem_buf, const expr_t &_mem_off,
        bool is_signed_offset) const {
46 | ir_assert(is_var(mem_buf)); |
    int header_off = 0;
48 | int unit_size = 1; |
49 | if (!is_lsc && is_block() && (is_slm() || is_bts())) { |
50 | header_off = 2 * address_type().size(); |
51 | // Convert byte offset to dwords/owords/hwords offset. |
52 | unit_size = type.scalar().size(); |
53 | } |
54 | |
55 | expr_t mem_off = _mem_off; |
56 | if (unit_size != 1) mem_off /= unit_size; |
57 | |
    expr_t header_sub_buf = header_buf[header_off];
59 | |
60 | expr_t off; |
61 | if (is_a64()) { |
62 | off = cast(mem_buf, address_type()); |
63 | if (mem_off.type().is_vector()) { |
64 | off = shuffle_t::make_broadcast(off, mem_off.type().elems()); |
65 | } |
66 | off += mem_off; |
67 | } else { |
68 | off = mem_off; |
69 | } |
70 | off = cast(off, address_type(is_signed_offset, off.type().elems())); |
71 | return store_t::make(header_sub_buf, 0, off); |
72 | } |
73 | |
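// Returns true if this message configuration (operation, data type, slot
// count and address model) is supported on the target hardware.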
74 | bool send_t::is_supported() const { |
75 | int max_access_size |
76 | = (is_2d() && !is_store_2d()) ? 32 * grf_size() : 8 * grf_size(); |
77 | if (access_size() > max_access_size) return false; |
78 | |
79 | // Block messages imply one slot. |
80 | if (is_block() && slots != 1) return false; |
81 | |
82 | if (is_block() && !utils::one_of(type.elems(), 1, 2, 4, 8, 16)) |
83 | return false; |
84 | |
85 | // owordx8 is max supported unless accessing SLM. |
86 | if (type.is_oword() && !is_slm() && type.elems() > 8) return false; |
87 | |
88 | // hword is not supported with SLM. |
89 | if (is_slm() && type.is_hword()) return false; |
90 | |
91 | // Allow only block messages for SLM to reduce offset-related arithmetic. |
92 | if (is_slm() && !is_block()) return false; |
93 | |
94 | // Only load/store with SLM. |
95 | if (is_slm() && !is_load() && !is_store()) return false; |
96 | |
97 | // No hword stores before XeHPC. |
98 | if (is_store() && type.is_hword() && !is_xe_hpc_plus()) return false; |
99 | |
100 | // XXX: Half-GRF stores result in correctness issues on XeHPC. |
101 | if (is_store() && is_block() && is_xe_hpc_plus() |
102 | && type.size() % grf_size() != 0) |
103 | return false; |
104 | |
    // Skip transposing messages; they need additional logic in message
    // decomposition to handle layouts.
107 | if (type.is_dword() && type.elems() != 1) return false; |
108 | if (type.is_qword() && type.elems() != 1) return false; |
109 | |
110 | // XXX: Allow only hword x {1,2,4,8} prefetch for now. |
111 | if (is_prefetch() && !type.is_hword()) return false; |
112 | if (is_prefetch() && type.elems() > 8) return false; |
113 | |
114 | // Expect only float atomics. |
115 | if (is_atomic() && !(type.is_dword() || type.is_qword())) return false; |
116 | |
117 | if (is_atomic() && !is_xe_hpc_plus() && is_a64() && slots > 8) return false; |
118 | |
119 | // XXX: Tested only byte scattered messages. |
120 | if (is_scattered() && !is_atomic() && !type.is_byte() && !type.is_qword()) |
121 | return false; |
122 | |
123 | if (is_scattered() && !is_atomic() |
124 | && !utils::one_of(type.elems(), 1, 2, 4, 8)) |
125 | return false; |
126 | |
127 | return true; |
128 | } |
129 | |
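// Enumerates all supported send functions for the given operation and
// address model, sorted for greedy selection: block messages first, then by
// access size in descending order; block messages with duplicate sizes are
// removed.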
130 | std::vector<func_t> send_t::get_all(ngen::HW hw, send_op_t op, |
131 | send_address_t address, const type_t &mem_type, |
132 | send_cache_hint_t cache_hint) { |
133 | std::vector<func_t> filtered; |
134 | for (int slots : {1, 2, 4, 8, 16}) { |
135 | for (int elems : {1, 2, 4, 8, 16}) { |
136 | for (auto &type : {type_t::byte(), type_t::dword(), type_t::qword(), |
137 | type_t::oword(), type_t::hword()}) { |
                // Require an exact data type size match for atomic messages.
139 | if (op == send_op_t::atomic_fadd |
140 | && type.size() != mem_type.size()) |
141 | continue; |
142 | |
143 | auto f = send_t::make(hw, op, address, type.with_elems(elems), |
144 | slots, cache_hint); |
145 | if (!f.as<send_t>().is_supported()) continue; |
146 | filtered.push_back(f); |
147 | } |
148 | } |
149 | } |
150 | |
151 | // Sort by total size in descending order. |
152 | std::sort(filtered.begin(), filtered.end(), |
153 | [](const func_t &_a, const func_t &_b) { |
154 | auto &a = _a.as<send_t>(); |
155 | auto &b = _b.as<send_t>(); |
156 | size_t a_sz = a.access_size(); |
157 | size_t b_sz = b.access_size(); |
158 | // Put block messages first. |
159 | if (a.is_block() != b.is_block()) return a.is_block(); |
160 | // Prefer messages with a smaller type as they have less strict |
161 | // alignment requirements. |
162 | if (a_sz == b_sz) |
163 | return a.type.scalar().size() < b.type.scalar().size(); |
164 | return a_sz > b_sz; |
165 | }); |
166 | |
167 | // Remove block messages with the same size (e.g. owordx4 and hwordx2). |
168 | std::vector<func_t> ret; |
169 | for (size_t i = 0; i < filtered.size(); i++) { |
170 | if (i > 0) { |
171 | auto &s_prev = filtered[i - 1].as<send_t>(); |
172 | auto &s_cur = filtered[i].as<send_t>(); |
173 | if (s_prev.is_block() && s_cur.is_block() |
174 | && (s_prev.type.size() == s_cur.type.size())) |
175 | continue; |
176 | } |
177 | ret.push_back(filtered[i]); |
178 | } |
179 | |
180 | return ret; |
181 | } |
182 | |
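// Translates the send's cache hint into LSC cache settings, applying
// per-hardware defaults when no hint is set.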
183 | ngen::CacheSettingsLSC get_cache_settings(const send_t &send) { |
184 | auto ret = ngen::CacheSettingsLSC::Default; |
185 | bool is_load = send.is_load() || send.is_load_2d(); |
186 | bool is_store = send.is_store() || send.is_store_2d(); |
187 | bool is_prefetch = send.is_prefetch() || send.is_prefetch_2d(); |
188 | switch (send.cache_hint) { |
189 | case send_cache_hint_t::undef: |
190 | switch (send.hw) { |
191 | case ngen::HW::XeHPG: |
192 | if (is_store) ret = ngen::CacheSettingsLSC::L1WB_L3WB; |
193 | break; |
194 | case ngen::HW::XeHPC: |
195 | if (is_store) { |
196 | ret = ngen::CacheSettingsLSC::L1UC_L3WB; |
197 | } else if (is_load || is_prefetch) { |
198 | ret = ngen::CacheSettingsLSC::L1C_L3C; |
199 | } |
200 | break; |
201 | default: break; |
202 | } |
203 | break; |
204 | case send_cache_hint_t::load_once: |
205 | switch (send.hw) { |
206 | case ngen::HW::XeHPG: |
207 | ret = ngen::CacheSettingsLSC::L1C_L3UC; |
208 | break; |
209 | case ngen::HW::XeHPC: |
210 | ret = ngen::CacheSettingsLSC::L1C_L3C; |
211 | break; |
212 | default: break; |
213 | } |
214 | break; |
215 | } |
216 | return ret; |
217 | } |
218 | |
// Helper class to iterate through global memory offsets; provides queries to
// check whether given blocks are dense, properly aligned, etc.
221 | class memory_walker_t { |
222 | public: |
223 | memory_walker_t(const constraint_set_t &cset, const view_t &view) |
224 | : view_(view) |
225 | , type_size_(view.type().size()) |
226 | , mask_tensor_(view.create_mask_tensor(cset).reinterpret(view.type())) |
227 | , full_size_(view.velems() * type_size_) { |
228 | init_dense_blocks(cset); |
229 | reset(); |
230 | } |
231 | |
232 | void reset() { |
233 | cur_off_ = 0; |
234 | remaining_size_ = full_size_; |
235 | } |
236 | |
237 | const mask_tensor_t &mask_tensor() const { return mask_tensor_; } |
238 | |
239 | bool has_next() const { return cur_off_ < full_size_; } |
240 | |
241 | int remaining_size() const { return remaining_size_; } |
242 | |
243 | int remaining_elems() const { return remaining_size_ / type_size_; } |
244 | |
245 | bool is_dense_and_aligned(int off, int size, int alignment) const { |
246 | if (off + size > remaining_size_) return false; |
247 | if (size == 0) return true; |
248 | int beg = cur_off_ + off; |
249 | int end = cur_off_ + off + size; |
250 | if (get_block_index(beg) != get_block_index(end - 1)) return false; |
251 | if (alignment != 0 && get_alignment(beg) < alignment) return false; |
252 | return true; |
253 | } |
254 | |
255 | // Returns true if each of the given slot regions is dense and aligned. |
256 | bool check_region(int off, int slots, int slot_size, int alignment) const { |
257 | for (int i = 0; i < slots; i++) { |
258 | int off = i * slot_size; |
            // Overflow is fine; it is expected to be handled by proper masking.
260 | if (off >= remaining_size_) return true; |
261 | if ((slot_size * slots) % type_size_ != 0) return false; |
262 | if (!is_dense_and_aligned(off, slot_size, alignment)) return false; |
263 | } |
264 | return true; |
265 | } |
266 | |
267 | // Returns true if the given region can be masked with `mask_size` |
268 | // granularity and `nmasks` number of masks. |
269 | bool check_mask_size(int off, int size, int mask_size, int nmasks) const { |
270 | auto mask = get_mask(off, size, mask_size, nmasks, /*allow_fail=*/true); |
271 | return !mask.is_empty(); |
272 | } |
273 | |
274 | expr_t get_offset(int off, expr_t &base, int &off_const) const { |
275 | if (off >= remaining_size_) { |
276 | base = expr_t(0); |
277 | off_const = 0; |
278 | return base; |
279 | } |
280 | int block_idx = get_block_index(cur_off_ + off); |
281 | ir_assert(block_idx >= 0 && block_idx < int(block_offs_.size())); |
282 | base = block_offs_[block_idx]; |
283 | auto prev_base = block_offs_[block_idx == 0 ? 0 : block_idx - 1]; |
284 | auto get_const_summand = [&](expr_t expr) -> int64_t { |
285 | if (!expr.type().is_int()) return 0; |
286 | auto binary_op = expr.as_ptr<binary_op_t>(); |
287 | if (binary_op && binary_op->op_kind == op_kind_t::_add |
288 | && is_const(binary_op->b)) |
289 | return to_cpp<int64_t>(binary_op->b); |
290 | return 0; |
291 | }; |
292 | |
293 | auto const_summand = get_const_summand(base); |
294 | auto base1 = simplify(base - const_summand); |
295 | auto base2 = simplify(prev_base - get_const_summand(prev_base)); |
296 | bool same_base = base1.is_equal(base2); |
297 | off_const = (cur_off_ + off) % dense_block_size_; |
298 | if (!same_base || const_summand == 0) return base + off_const; |
299 | base = base1; |
300 | off_const += const_summand; |
301 | return base + off_const; |
302 | } |
303 | |
304 | // Returns a boolean mask expression for the given region to access. |
305 | expr_t get_mask(int off, int size, int mask_size, int nmasks, |
306 | bool allow_fail = false) const { |
        ir_assert(size % mask_size == 0) << "Incompatible mask size.";
308 | auto sub_mask_tensor = create_sub_mask_tensor(off, size); |
309 | sub_mask_tensor = sub_mask_tensor.reinterpret(type_t::u8(mask_size)); |
310 | if (sub_mask_tensor.is_empty()) { |
311 | if (allow_fail) return expr_t(); |
312 | ir_error_not_expected(); |
313 | } |
314 | auto ret = sub_mask_tensor.to_expr(nmasks); |
315 | if (ret.is_empty()) { |
316 | if (allow_fail) return expr_t(); |
            ir_error_not_expected() << "Can't create mask.";
318 | } |
319 | return ret; |
320 | } |
321 | |
322 | // Moves the current position `size` bytes ahead. |
323 | void advance(int size) { |
324 | ir_assert(size % type_size_ == 0); |
325 | size = std::min(size, remaining_size_); |
326 | cur_off_ += size; |
327 | remaining_size_ -= size; |
328 | } |
329 | |
330 | private: |
331 | void init_dense_blocks(const constraint_set_t &cset) { |
332 | auto l = view_.create_pseudo_vlayout(); |
333 | // Find the maximum innermost dense tile. |
334 | stride_t stride = 1; |
335 | std::vector<dim_t> dims(l.ndims(), 1); |
336 | for (auto &b : l.blocks()) { |
337 | if (b.stride != stride) break; |
338 | dims[b.dim_idx] *= b.block; |
339 | stride = b.block * b.stride; |
340 | } |
341 | tensor_t tile(dims); |
342 | dense_block_size_ = tile.elems() * type_size_; |
343 | // Split the memory view into dense blocks and precompute block offsets |
344 | // and alignments. |
345 | view_.for_each_tile(tile, [&](const std::vector<dim_t> &start) { |
346 | auto off = view_.offset_in_bytes(expr_cast<expr_t>(start)); |
347 | off = simplify(off, cset); |
348 | |
349 | const int base_alignment = 128; |
350 | int64_t f = get_max_const_factor(off, cset); |
351 | int alignment = f ? ir_utils::max_pow2_divisor(f) : base_alignment; |
352 | |
353 | block_offs_.push_back(off); |
354 | block_alignments_.push_back(alignment); |
355 | }); |
356 | } |
357 | |
358 | mask_tensor_t create_sub_mask_tensor(int off, int size) const { |
359 | ir_assert(off % type_size_ == 0); |
360 | ir_assert(size % type_size_ == 0); |
361 | |
362 | std::vector<dim_t> sub_dims = {size / type_size_}; |
363 | layout_t sub_layout(view_.type(), 0, sub_dims); |
364 | mask_tensor_t sub_mask_tensor(sub_layout); |
365 | int beg = (cur_off_ + off) / type_size_; |
366 | int end = (cur_off_ + off + size) / type_size_; |
367 | for (int i = beg; i < end; i++) { |
368 | auto mask = (i < mask_tensor_.elems()) ? mask_tensor_.mask(i) |
369 | : expr_t(false); |
370 | sub_mask_tensor.set_mask(i - beg, mask); |
371 | } |
372 | return sub_mask_tensor; |
373 | } |
374 | |
375 | int get_block_index(int off) const { return off / dense_block_size_; } |
376 | |
377 | int get_alignment(int off) const { |
378 | int block_alignment = block_alignments_[off / dense_block_size_]; |
379 | return ir_utils::max_pow2_divisor( |
380 | block_alignment + off % dense_block_size_); |
381 | } |
382 | |
383 | view_t view_; |
384 | int type_size_; |
385 | mask_tensor_t mask_tensor_; |
386 | std::vector<expr_t> block_offs_; |
387 | std::vector<int> block_alignments_; |
388 | int cur_off_ = 0; |
389 | int full_size_ = 0; |
390 | int remaining_size_ = 0; |
391 | int dense_block_size_ = 0; |
392 | }; |
393 | |
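// Helper class to iterate through the register payload layout; provides
// queries to check whether the next elements are uniformly strided and
// properly aligned to GRF boundaries.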
394 | class layout_walker_t { |
395 | public: |
396 | layout_walker_t() = default; |
397 | layout_walker_t(const layout_t &layout, int grf_size) |
398 | : layout_(layout) |
399 | , grf_size_(grf_size) |
400 | , type_size_(layout.type().size()) |
401 | , idxs_(layout.blocks().size()) {} |
402 | |
403 | int offset_bytes() const { return off_bytes_; } |
404 | |
405 | bool can_access(int size) const { |
406 | int off = off_bytes_ + size; |
407 | return off <= max_offset_bytes(); |
408 | } |
409 | |
410 | // Returns true if the next `elems` elements can be stored in the layout |
411 | // given the following requirements: |
412 | // - They must be uniformly strided with `stride` (specified in elements) |
413 | // - The last element must be GRF boundary aligned (unless `is_last_region` |
414 | // is true) |
415 | // - The last element must not cross the layout boundary |
416 | bool can_advance(int stride, int elems, bool is_last_region = false) { |
417 | if (is_last_region) elems = std::min(elems, remaining_elems()); |
418 | auto cur_idxs = idxs_; |
419 | int cur_off_bytes = off_bytes_; |
420 | for (int i = 0; i < elems - 1; i++) { |
421 | int next_off_bytes = advance(cur_idxs, cur_off_bytes); |
422 | if (next_off_bytes - cur_off_bytes != stride * type_size_) |
423 | return false; |
424 | cur_off_bytes = next_off_bytes; |
425 | } |
426 | cur_off_bytes = advance(cur_idxs, cur_off_bytes); |
427 | if (cur_off_bytes > max_offset_bytes()) return false; |
428 | if (!is_last_region && cur_off_bytes % grf_size_ != 0) return false; |
429 | return true; |
430 | } |
431 | |
432 | // Moves the current position `elems` elements ahead. |
433 | void advance(int elems) { |
434 | elems = std::min(elems, remaining_elems()); |
435 | for (int i = 0; i < elems; i++) { |
436 | off_bytes_ = advance(idxs_, off_bytes_); |
437 | elems_++; |
438 | } |
439 | } |
440 | |
441 | private: |
442 | int max_offset_bytes() const { |
443 | return utils::rnd_up((int)layout_.size(), grf_size_); |
444 | } |
445 | |
446 | int remaining_elems() const { return layout_.elems() - elems_; } |
447 | |
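    // Increments the per-block indices `idxs` by one element, odometer-style,
    // and returns the resulting byte offset within the layout.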
448 | int advance(std::vector<int> &idxs, int off_bytes) const { |
449 | for (size_t i = 0; i < idxs.size(); i++) { |
450 | if (++idxs[i] < layout_.blocks()[i].block) break; |
451 | idxs[i] = 0; |
452 | } |
453 | int off = 0; |
454 | for (size_t i = 0; i < idxs.size(); i++) { |
455 | int stride = (int)layout_.blocks()[i].stride; |
456 | off += idxs[i] * stride; |
457 | } |
458 | return off * type_size_; |
459 | } |
460 | |
461 | layout_t layout_; |
462 | int grf_size_; |
463 | int type_size_; |
464 | |
465 | std::vector<int> idxs_; |
466 | int elems_ = 0; |
467 | int off_bytes_ = 0; |
468 | }; |
469 | |
470 | access_builder_t::access_builder_t(ir_context_t &ir_ctx, const view_t &mem_view, |
471 | const expr_t &mem_buf, const expr_t ®_buf, send_op_t send_op, |
472 | send_address_t send_address, send_cache_hint_t send_cache_hint, |
473 | send_hint_t &send_hint) |
474 | : ir_ctx_(&ir_ctx) |
475 | , mem_view_(mem_view) |
476 | , mem_buf_(mem_buf) |
477 | , reg_buf_(reg_buf) |
478 | , send_op_(send_op) |
479 | , send_address_(send_address) |
480 | , send_cache_hint_(send_cache_hint) |
481 | , mem_type_(mem_view.type()) |
482 | , mem_walker_( |
483 | utils::make_unique<memory_walker_t>(ir_ctx.cset(), mem_view)) { |
484 | if (send_hint.hint_2d.enable) { |
485 | if (try_build_2d(send_hint)) return; |
486 | } |
487 | send_hint.hint_2d = send_2d_hint_t(); |
488 | build(); |
489 | } |
490 | |
491 | access_builder_t::access_builder_t(access_builder_t &&) = default; |
492 | access_builder_t::~access_builder_t() = default; |
493 | |
494 | void access_builder_t::build() { |
495 | bool ok = false; |
496 | for (auto &l : candidate_payload_layouts()) { |
497 | // Try to find send decomposition with the given GRF payload layout. |
498 | if (try_build(l)) { |
499 | ok = true; |
500 | break; |
501 | } |
502 | } |
503 | if (!ok && send_op_ == send_op_t::prefetch) { |
        // Do not treat this as an error; skip prefetch messages during
        // generation.
505 | ir_warning() << "Can't generate send decomposition for prefetch." |
506 | << std::endl; |
507 | return; |
508 | } |
    ir_assert(ok) << "Can't generate send decomposition.";
510 | } |
511 | |
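// Returns true if, at the given tile start, the tdim offset comes only from
// the strided vdim: the tdim expression must evaluate to zero once that vdim
// is zeroed out. This guarantees that dividing the corresponding coordinate
// by the stride is exact.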
512 | static bool stride_dimension_ok(const view_t &view, int stride_tidx, |
513 | int stride_vidx, const std::vector<expr_t> &vstart) { |
514 | auto &tdim = view.tdim(stride_tidx); |
515 | auto e = tdim.expr(); |
516 | for (int i = 0; i < tdim.nvargs(); i++) { |
517 | int vidx = tdim.vidx(i); |
518 | auto &vvar = view.vvars()[vidx]; |
519 | if (vidx == stride_vidx) { |
520 | e = substitute(e, vvar, expr_t(0)); |
521 | } else { |
522 | e = substitute(e, vvar, vstart[vidx]); |
523 | } |
524 | } |
525 | e = simplify(e); |
526 | return is_zero(e); |
527 | } |
528 | |
529 | static expr_t try_scalarize(const expr_t &e) { |
530 | if (e.type().is_scalar()) return e; |
531 | |
532 | if (auto *shuffle = e.as_ptr<shuffle_t>()) { |
533 | if (shuffle->is_broadcast()) return try_scalarize(shuffle->vec[0]); |
534 | return expr_t(); |
535 | } |
536 | |
537 | if (auto *binary = e.as_ptr<binary_op_t>()) { |
538 | auto a = try_scalarize(binary->a); |
539 | auto b = try_scalarize(binary->b); |
540 | if (a.is_empty() || b.is_empty()) return expr_t(); |
541 | return binary_op_t::make(binary->op_kind, a, b); |
542 | } |
543 | |
544 | ir_error_not_expected() << e; |
545 | return expr_t(); |
546 | } |
547 | |
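// Tries to rewrite a legacy A64 block send as an equivalent LSC message
// (XeHPG and later). LSC block messages take a scalar mask, so the vector
// mask must be reducible to a scalar; otherwise the original call is kept.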
548 | static stmt_t try_promote_to_lsc(const stmt_t &_call) { |
549 | if (_call.is_empty()) return _call; |
550 | auto &call = _call.as<func_call_t>(); |
551 | auto &send = call.func.as<send_t>(); |
552 | if (send.is_lsc || send.is_2d()) return call; |
553 | if (send.hw < ngen::HW::XeHPG) return call; |
554 | if (send.is_slm() || send.is_bts()) return call; |
555 | if (!send.is_block()) return call; |
556 | |
557 | auto mask = try_scalarize(send_t::arg_mask(call)); |
558 | if (mask.is_empty()) return call; |
559 | |
560 | auto new_args = call.args; |
561 | send_t::arg_mask(new_args) = mask; |
562 | |
563 | auto lsc_send = send_t::make(send.hw, send.op, send.address, send.type, |
564 | send.slots, /*is_lsc=*/true, send.cache_hint); |
565 | return lsc_send.call(new_args); |
566 | } |
567 | |
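// Tries to implement the access with 2D block messages based on the given
// hint. Returns false when the view violates 2D block requirements (layout,
// pitch, alignment or masking), in which case the caller falls back to the
// regular message decomposition.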
568 | bool access_builder_t::try_build_2d(send_hint_t &send_hint) { |
569 | auto vlayout = mem_view_.create_pseudo_vlayout(); |
570 | auto &hint = send_hint.hint_2d; |
571 | // The data may be loaded in a wider data type to get a proper GRF layout. |
572 | if (!hint.type.is_undef()) vlayout = vlayout.reinterpret(hint.type); |
573 | |
574 | bool is_store = (send_op_ == send_op_t::store); |
575 | auto send_type = type_t::u(vlayout.type().size() * 8); |
576 | auto blocks = vlayout.blocks(); |
577 | if (blocks.size() < 2) return false; |
578 | |
579 | auto &b0 = blocks[0]; |
580 | auto &b1 = blocks[1]; |
581 | ir_assert(b0.dim_idx != b1.dim_idx); |
582 | if (b0.stride != stride_t(1)) return false; |
583 | if (!b1.stride.is_fixed()) return false; |
584 | |
585 | auto get_tdim_idx = [&](int vdim_idx, int &stride) { |
586 | int ret = -1; |
587 | for (int i = 0; i < mem_view_.ntdims(); i++) { |
588 | auto &tdim = mem_view_.tdim(i); |
589 | for (int j = 0; j < tdim.nvargs(); j++) { |
590 | if (tdim.vidx(j) == vdim_idx) { |
591 | ir_assert(ret == -1); |
592 | stride = (int)tdim.vstride(j); |
593 | ret = i; |
594 | } |
595 | } |
596 | } |
597 | return ret; |
598 | }; |
599 | |
600 | int w_tstride = 0; |
601 | int h_tstride = 0; |
602 | int w_dim_idx = get_tdim_idx(b0.dim_idx, w_tstride); |
603 | int h_dim_idx = get_tdim_idx(b1.dim_idx, h_tstride); |
604 | |
605 | if (w_tstride != 1) return false; |
606 | |
607 | auto &tlayout = mem_view_.tlayout(); |
608 | auto get_2d_dim = [&](int tdim_idx) { |
609 | return tlayout.inner_block(tdim_idx, /*skip_outer=*/false); |
610 | }; |
611 | |
612 | int surface_width = 0; |
613 | int surface_height = 0; |
614 | int surface_pitch = b1.stride; |
615 | bool is_w_blocked = (get_2d_dim(w_dim_idx) != tlayout.dim(w_dim_idx)); |
616 | bool is_h_blocked = (get_2d_dim(h_dim_idx) != tlayout.dim(h_dim_idx)); |
    // A virtual surface means loading from the innermost block of a blocked
    // layout, which implies no bounds checks embedded into the 2D block
    // message.
619 | bool use_virtual_surface = is_w_blocked || is_h_blocked; |
620 | if (use_virtual_surface) { |
621 | if (h_tstride != 1) return false; |
622 | surface_width = b0.block; |
623 | surface_height = b1.block; |
624 | } else { |
625 | surface_width = tlayout.dim(w_dim_idx); |
626 | surface_height = tlayout.dim(h_dim_idx); |
627 | if (surface_height % h_tstride != 0) return false; |
628 | surface_height = surface_height / h_tstride; |
629 | } |
630 | int type_factor = ir_utils::safe_divide(send_type.size(), mem_type_.size()); |
631 | surface_width /= type_factor; |
632 | |
633 | int width = hint.width; |
634 | int height = hint.height; |
635 | int count = 1; |
636 | bool vnni = hint.vnni; |
637 | bool transpose = hint.transpose; |
638 | |
639 | // Try to reduce the number of messages by increasing count per message. |
640 | int try_count = count * 2; |
641 | int max_count |
642 | = block_2d_max_count(is_store, transpose, width, mem_type_.size()); |
643 | while (try_count <= max_count) { |
644 | if (b0.block % (try_count * width) != 0) break; |
645 | count = try_count; |
646 | try_count *= 2; |
647 | } |
648 | |
649 | int W = surface_width; |
650 | int H = surface_height; |
651 | int P = surface_pitch; |
652 | int w = width; |
653 | int h = height; |
654 | int c = count; |
655 | if (!fixup_send_2d_params(send_type, vnni, transpose, |
656 | /*use_xy=*/!use_virtual_surface, W, H, P, w, h, c, |
657 | hint.vnni_permute_factor)) |
658 | return false; |
659 | |
660 | std::vector<dim_t> dims(vlayout.ndims(), 1); |
661 | dims[b0.dim_idx] = count * width; |
662 | dims[b1.dim_idx] = height; |
663 | tensor_t tile(dims); |
664 | |
665 | reg_layout_ = layout_t(type_factor == 1 ? mem_type_ : send_type, 0, |
666 | std::vector<dim_t>(vlayout.ndims(), 1)); |
667 | int h_inner = vnni ? 4 / send_type.size() : 1; |
668 | int h_outer = ir_utils::safe_divide(height, h_inner); |
669 | reg_layout_ = reg_layout_.add_outer_block(b1.dim_idx, h_inner); |
670 | if (transpose) { |
671 | reg_layout_ = reg_layout_.add_outer_block(b1.dim_idx, h_outer); |
672 | reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, width); |
673 | } else { |
674 | reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, width); |
675 | reg_layout_ = reg_layout_.add_outer_block(b1.dim_idx, h_outer); |
676 | } |
677 | reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, count); |
678 | |
679 | int w_outermost |
680 | = ir_utils::safe_divide(vlayout.dim(b0.dim_idx), count * width); |
681 | int h_outermost = ir_utils::safe_divide(vlayout.dim(b1.dim_idx), height); |
682 | reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, w_outermost); |
683 | reg_layout_ = reg_layout_.add_outer_block(b1.dim_idx, h_outermost); |
684 | |
685 | if (type_factor != 1) { |
686 | auto blocks = reg_layout_.blocks(); |
687 | reg_layout_ = layout_t( |
688 | mem_type_, 0, std::vector<dim_t>(vlayout.ndims(), 1)); |
689 | reg_layout_ = reg_layout_.add_outer_block(b0.dim_idx, type_factor); |
690 | for (auto &b : blocks) |
691 | reg_layout_ = reg_layout_.add_outer_block(b.dim_idx, b.block); |
692 | } |
693 | |
694 | for (auto &b : blocks) { |
695 | if (utils::one_of(b.dim_idx, b0.dim_idx, b1.dim_idx)) continue; |
696 | reg_layout_ = reg_layout_.add_outer_block(b.dim_idx, b.block); |
697 | } |
698 | |
699 | reg_layout_walker_ |
700 | = utils::make_unique<layout_walker_t>(reg_layout_, grf_size()); |
701 | |
702 | // Update user hint. |
703 | hint.type = send_type; |
704 | hint.enable = true; |
705 | hint.vnni = vnni; |
706 | hint.transpose = transpose; |
707 | hint.width = w; |
708 | hint.height = h; |
709 | auto _send = send_t::make_2d(ir_ctx_->hw(), send_hint.convert(send_op_), |
710 | send_type, W, H, P, w, h, c, vnni, transpose, send_cache_hint_); |
711 | auto &send = _send.as<send_t>(); |
712 | |
713 | stmt_ = stmt_t(); |
714 | bool ok = true; |
715 | auto vstart0 = mem_view_.vstart(); |
716 | vlayout.for_each_tile(tile, [&](const std::vector<dim_t> &start) { |
717 | if (!ok) return; |
718 | |
719 | int access_size = send.access_size(); |
720 | int access_elems = access_size / mem_type_.size(); |
721 | |
722 | // Check mask requirements. |
723 | expr_t mask; |
724 | if (!check_2d_mask(tensor_t(tile.dims(), start), use_virtual_surface, |
725 | w_dim_idx, h_dim_idx, mask)) { |
726 | ok = false; |
727 | return; |
728 | } |
729 | |
730 | if (!send.is_prefetch_2d()) { |
731 | if (!reg_layout_walker_->can_advance(1, access_elems)) { |
732 | ok = false; |
733 | return; |
734 | } |
735 | |
736 | if (!reg_layout_walker_->can_access(send.payload_size())) { |
737 | ok = false; |
738 | return; |
739 | } |
740 | } |
741 | |
742 | auto vstart = vstart0; |
743 | for (int i = 0; i < vlayout.ndims(); i++) { |
744 | if (start[i] == 0) continue; |
745 | int factor = (i == b0.dim_idx ? type_factor : 1); |
746 | vstart[i] += factor * start[i]; |
747 | } |
748 | auto tstart |
749 | = mem_view_.cvt_vargs_to_targs(vstart, /*ignore_vstart=*/true); |
750 | |
751 | auto &_x = tstart[w_dim_idx]; |
752 | auto &_y = tstart[h_dim_idx]; |
753 | |
754 | expr_t x(0); |
755 | expr_t y(0); |
756 | |
757 | bool skip_send = false; |
758 | if (!use_virtual_surface) { |
759 | std::swap(x, _x); |
760 | std::swap(y, _y); |
761 | if (type_factor != 1) x /= type_factor; |
762 | |
763 | if (h_tstride != 1) { |
764 | if (!stride_dimension_ok( |
765 | mem_view_, h_dim_idx, b1.dim_idx, vstart)) { |
766 | if (send.is_prefetch_2d()) { |
767 | skip_send = true; |
768 | } else { |
769 | ok = false; |
770 | return; |
771 | } |
772 | } |
773 | y /= h_tstride; |
774 | } |
775 | } |
776 | |
777 | auto off = simplify( |
778 | mem_view_.tlayout().offset_in_bytes(tstart), ir_ctx_->cset()); |
779 | |
780 | // Check alignment requirements. |
781 | int64_t align = get_max_const_factor(off, ir_ctx_->cset()); |
782 | if (align % block_2d_base_alignment(ir_ctx_->hw_cfg()) != 0) { |
783 | ok = false; |
784 | return; |
785 | } |
786 | |
787 | if (!skip_send) { |
788 | if (!ir_ctx_->cset().can_prove( |
789 | x % block_2d_x_alignment(send_type.size()) == 0)) { |
790 | ok = false; |
791 | return; |
792 | } |
793 | auto reg_buf = (send.is_prefetch_2d() |
794 | ? expr_t() |
795 | : reg_buf_ + reg_layout_walker_->offset_bytes()); |
796 | auto send_stmt = send(mem_buf_, off, reg_buf, mask, x, y); |
797 | stmt_ = stmt_.append(send_stmt); |
798 | } |
799 | |
800 | reg_layout_walker_->advance(send.access_size() / mem_type_.size()); |
801 | }); |
802 | |
803 | return ok; |
804 | } |
805 | |
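// Validates and adjusts 2D block message parameters (surface width, height
// and pitch plus block sizes) against hardware requirements. For surfaces
// narrower than 64 bytes a VNNI reshape is attempted, as described below.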
806 | bool access_builder_t::fixup_send_2d_params(const type_t &send_type, bool vnni, |
807 | bool transpose, bool use_xy, int &W, int &H, int &P, int &w, int &h, |
808 | int &c, int &vnni_permute_factor) { |
809 | int surface_width_size = W * send_type.size(); |
810 | auto whp_ok = [&]() { |
811 | return block_2d_width_ok(W, send_type.size()) && block_2d_height_ok(H) |
812 | && block_2d_pitch_ok( |
813 | ir_ctx_->hw_cfg(), P, send_type.size(), use_xy); |
814 | }; |
815 | |
816 | // No VNNI permute by default. |
817 | vnni_permute_factor = 0; |
818 | |
    // Surface width must be >= 64 bytes. For a smaller width we can apply a
    // reshape, e.g. [16a] x [16b] -> [8a] x [2a16b], to get a block with a
    // larger width. Such a reshape impacts width/height handling with the
    // following implications:
    // - Reshape is applied only for the VNNI/no-transpose case. This yields
    //   the same GRF layout but with permuted height elements:
825 | // - Layout without reshape: 8a16b2a |
826 | // - Layout with reshape: 8a16b2a (a/height dimension is permuted) |
827 | // - Permutation is safe when it's done for the reduction dimension |
828 | // (doesn't matter in which order elements are accumulated). |
829 | // - Permutation pattern must be the same between A and B tensors |
830 | if (surface_width_size >= 64) return whp_ok(); |
831 | |
832 | // Reshape is only expected/supported with VNNI. |
833 | if (!vnni || transpose) return false; |
834 | |
835 | if (64 % surface_width_size != 0) return false; |
836 | |
837 | int factor = 64 / surface_width_size; |
838 | if (h % factor != 0) return false; |
839 | |
840 | int max_count = block_2d_max_count( |
841 | send_op_ == send_op_t::store, transpose, w, send_type.size()); |
842 | if (factor > max_count) return false; |
843 | |
844 | vnni_permute_factor = factor; |
845 | W *= factor; |
846 | P *= factor; |
847 | H /= factor; |
848 | h /= factor; |
849 | c = factor; |
850 | return whp_ok(); |
851 | } |
852 | |
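// Builds the access mask for a 2D block region. Bound conditions along the
// width/height dimensions may be dropped because 2D block messages perform
// their own out-of-bound checks, unless a virtual surface is used.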
853 | bool access_builder_t::check_2d_mask(const tensor_t &tile, |
854 | bool use_virtual_surface, int w_dim_idx, int h_dim_idx, |
855 | expr_t &mask) const { |
856 | auto sub_view = mem_view_.create_sub_view(tile); |
857 | auto mask_tensor = sub_view.create_mask_tensor(ir_ctx_->cset()); |
858 | mask = mask_tensor.to_expr(1); |
859 | if (!mask.is_empty()) return true; |
860 | |
861 | // Virtual surface implies no out-of-bound send checks. |
862 | if (use_virtual_surface) return false; |
863 | |
864 | // Remove bound conditions that are covered by out-of-bound send checks. |
865 | uint32_t tmask = 0xFFFFFFFF; |
866 | for (int i = 0; i < sub_view.nvdims(); i++) { |
867 | if (!utils::one_of(i, w_dim_idx, h_dim_idx)) continue; |
868 | for (int j = 0; j < sub_view.ntdims(); j++) { |
869 | auto &tdim = sub_view.tdim(j); |
870 | for (int k = 0; k < tdim.nvargs(); k++) { |
871 | if (tdim.vidx(k) == i) { |
872 | // TODO: Check if tdim mask is a bound mask. |
873 | tmask &= ~(1U << i); |
874 | } |
875 | } |
876 | } |
877 | } |
878 | mask_tensor = sub_view.create_mask_tensor(ir_ctx_->cset(), tmask); |
879 | mask = mask_tensor.to_expr(1); |
880 | if (!mask.is_empty()) return true; |
881 | |
882 | return false; |
883 | } |
884 | |
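// Tries to build a send decomposition for the given GRF payload layout,
// greedily picking the first (largest) message that satisfies the payload,
// alignment and mask requirements at each step. Returns false if no message
// fits, so that another payload layout can be tried.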
885 | bool access_builder_t::try_build(const layout_t &try_layout) { |
886 | auto &try_layout_blocks = try_layout.blocks(); |
887 | int reg_stride |
888 | = (try_layout_blocks.empty() ? 0 |
889 | : (int)try_layout_blocks[0].stride); |
890 | auto send_list = send_t::get_all(ir_ctx_->hw(), send_op_, send_address_, |
891 | mem_type_, send_cache_hint_); |
892 | reg_layout_walker_ |
893 | = utils::make_unique<layout_walker_t>(try_layout, grf_size()); |
894 | stmt_ = stmt_t(); |
895 | mem_walker_->reset(); |
    // Iterate through the memory view, greedily selecting messages according
    // to the sorted message list.
898 | while (mem_walker_->has_next()) { |
899 | func_t _send; |
900 | for (auto &_s : send_list) { |
901 | auto &s = _s.as<send_t>(); |
902 | |
903 | int slot_size = s.type.size(); |
904 | int alignment = s.alignment(); |
905 | int nmasks = s.nmasks(); |
906 | int payload_stride = s.payload_type_stride(); |
907 | int access_size = s.access_size(); |
908 | int access_elems = access_size / mem_type_.size(); |
909 | bool is_last_chunk = mem_walker_->remaining_size() <= access_size; |
910 | |
911 | if (reg_stride != 1 || payload_stride != slot_size) { |
                // Detected a strided GRF layout or a strided payload. In this
                // case require an exact data type and stride match.
914 | if (reg_stride != 0 |
915 | && payload_stride != reg_stride * mem_type_.size()) |
916 | continue; |
917 | if (s.type.size() != mem_type_.size()) continue; |
918 | } |
            // Prefetches don't have a payload, so skip these conditions for
            // prefetch.
921 | if (!s.is_prefetch()) { |
922 | if (!reg_layout_walker_->can_advance( |
923 | reg_stride, access_elems, is_last_chunk)) |
924 | continue; |
925 | |
926 | if (!reg_layout_walker_->can_access(s.payload_size())) continue; |
927 | } |
928 | |
929 | // Check if slots are contiguous and aligned. |
930 | if (!mem_walker_->check_region(0, s.slots, slot_size, alignment)) |
931 | continue; |
932 | |
933 | // Check mask requirements. |
            // XXX: Postpone the mask check for prefetch until send call
            // generation. If the mask cannot be generated, skip the prefetch.
936 | if (!s.is_prefetch() |
937 | && !mem_walker_->check_mask_size( |
938 | 0, access_size, s.mask_size(), nmasks)) |
939 | continue; |
940 | |
941 | _send = _s; |
942 | break; |
943 | } |
944 | // Can't find a message - try another GRF layout for payload. |
945 | if (_send.is_empty()) return false; |
946 | |
947 | auto &send = _send.as<send_t>(); |
948 | auto send_stmt = create_send_stmt(send); |
949 | send_stmt = try_promote_to_lsc(send_stmt); |
950 | stmt_ = stmt_.append(send_stmt); |
951 | |
952 | reg_layout_walker_->advance(send.access_size() / mem_type_.size()); |
953 | mem_walker_->advance(send.access_size()); |
954 | } |
955 | reg_layout_ = try_layout; |
956 | return true; |
957 | } |
958 | |
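// Returns the GRF payload layouts to try for the send decomposition: the
// dense layout first, followed by dword-strided variants matching the
// payload of byte x {1,2} scattered messages.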
959 | std::vector<layout_t> access_builder_t::candidate_payload_layouts() const { |
960 | int type_size = mem_type_.size(); |
961 | auto vlayout = mem_view_.create_dense_vlayout(); |
962 | |
963 | std::vector<layout_t> ret; |
964 | |
965 | // Dense payload layout directly mapping to the memory view. |
966 | ret.push_back(vlayout); |
967 | |
968 | // These payload layouts are to match payload for byte x {1,2} scattered |
969 | // messages (they are dword-strided). |
970 | if (type_size == 2) ret.push_back(vlayout.make_strided(2)); |
971 | if (type_size == 1) ret.push_back(vlayout.make_strided(4)); |
972 | |
973 | return ret; |
974 | } |
975 | |
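// Creates a send call for the current position of the memory/register
// walkers, building the per-slot offset vector and the access mask. Returns
// an empty statement when the mask can't be generated for a prefetch.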
976 | stmt_t access_builder_t::create_send_stmt(const send_t &send) { |
977 | std::vector<expr_t> off_vec; |
978 | // Try to detect a common base and const vector offset to reduce further |
979 | // arithmetic. |
980 | expr_t off_base0; |
981 | int off_const0 = -1; |
982 | bool is_same_base = true; |
983 | std::vector<expr_t> off_const_vec; |
984 | for (int i = 0; i < send.slots; i++) { |
985 | expr_t off_base; |
986 | int off_const; |
987 | auto off = mem_walker_->get_offset( |
988 | i * send.type.size(), off_base, off_const); |
989 | if (off_base0.is_empty()) { |
990 | off_base0 = off_base; |
991 | off_const0 = off_const; |
992 | } else if (!off_base.is_equal(off_base0)) { |
993 | is_same_base = false; |
994 | } |
995 | off_vec.push_back(off); |
996 | off_const_vec.push_back(off_const - off_const0); |
997 | } |
998 | expr_t off; |
999 | if (send.slots == 1 || !is_same_base) { |
1000 | off = shuffle_t::make(off_vec); |
1001 | } else { |
1002 | off = shuffle_t::make_broadcast(off_base0, send.slots) |
1003 | + shuffle_t::make_broadcast(off_const0, send.slots) |
1004 | + shuffle_t::make(off_const_vec); |
1005 | } |
1006 | bool allow_fail = send.is_prefetch(); |
1007 | auto _mask = mem_walker_->get_mask( |
1008 | 0, send.access_size(), send.mask_size(), send.nmasks(), allow_fail); |
1009 | if (_mask.is_empty()) return stmt_t(); |
1010 | |
1011 | auto _reg_buf = (send.is_prefetch() |
1012 | ? expr_t() |
1013 | : reg_buf_ + reg_layout_walker_->offset_bytes()); |
1014 | auto ret = send(mem_buf_, off, _reg_buf, _mask); |
1015 | return ret; |
1016 | } |
1017 | |
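// Computes a 2D block message hint for a (w_tile x h_tile) tile: adjusts the
// data type and the VNNI/transpose flags to what the hardware supports and
// picks the block width/height within hardware limits. Returns a disabled
// hint if no valid configuration is found.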
1018 | send_2d_hint_t get_send_2d_hint(send_op_t send_op, const type_t &_type, |
1019 | bool vnni, bool transpose, int w_tile, int h_tile, int w_blk = 0, |
1020 | int h_blk = 0) { |
1021 | auto type = _type; |
1022 | bool orig_vnni = vnni; |
1023 | bool orig_transpose = transpose; |
1024 | |
1025 | if (vnni && transpose) { |
        // This combination is not supported, but try to replace it with an
        // upconvert plus transpose.
1028 | if (type.size() == 2) { |
1029 | type = type_t::u32(); |
1030 | w_tile = ir_utils::safe_divide(w_tile, 2); |
1031 | w_blk = ir_utils::safe_divide(w_blk, 2); |
1032 | vnni = false; |
1033 | orig_vnni = false; |
1034 | } |
1035 | } |
1036 | |
    // XXX: Convert transpose to VNNI when transpose is not supported. This
    // requires an additional reorder, but a reorder from a "partially
    // transposed" VNNI-transformed layout is cheaper.
1041 | if (transpose && type.size() != 4) { |
1042 | vnni = true; |
1043 | transpose = false; |
1044 | } |
1045 | |
1046 | bool is_load_or_prefetch |
1047 | = utils::one_of(send_op, send_op_t::load, send_op_t::prefetch); |
1048 | bool is_store = (send_op == send_op_t::store); |
1049 | |
1050 | // Only D8, D16 and D32 are implemented. |
1051 | if (!utils::one_of(type.size(), 1, 2, 4)) return send_2d_hint_t(); |
1052 | |
1053 | // VNNI and transpose are mutually exclusive. |
1054 | if (vnni && transpose) return send_2d_hint_t(); |
1055 | |
1056 | // VNNI and transpose are supported with load only. |
1057 | if (is_store && (vnni || transpose)) return send_2d_hint_t(); |
1058 | |
1059 | // VNNI is supported with D8 and D16 only. |
1060 | if (vnni && !utils::one_of(type.size(), 1, 2)) return send_2d_hint_t(); |
1061 | |
1062 | // Transpose is supported with D32 only. |
1063 | if (transpose && type.size() != 4) return send_2d_hint_t(); |
1064 | |
1065 | int w_min = (transpose ? 1 : 4 / type.size()); |
1066 | int w_max = (transpose ? 8 : (vnni ? 16 : 64 / type.size())); |
1067 | int h_min = (vnni ? (4 / type.size()) : 1); |
1068 | int h_max = (is_load_or_prefetch ? 32 : 8); |
1069 | |
1070 | if (w_blk > 0 && (w_blk < w_min || w_blk > w_max)) return send_2d_hint_t(); |
1071 | if (h_blk > 0 && (h_blk < h_min || h_blk > h_max)) return send_2d_hint_t(); |
1072 | |
1073 | auto find_block = [&](int dim, int min, int max) { |
1074 | for (int b = max; b >= min; b /= 2) { |
1075 | if (dim % b == 0) return b; |
1076 | } |
1077 | return 0; |
1078 | }; |
1079 | |
1080 | if (w_blk == 0) w_blk = find_block(w_tile, w_min, w_max); |
1081 | if (h_blk == 0) h_blk = find_block(h_tile, h_min, h_max); |
1082 | if (w_blk == 0 || h_blk == 0) return send_2d_hint_t(); |
1083 | |
1084 | if (orig_vnni && h_blk > 0) h_blk = find_block(h_tile, h_blk, h_max); |
1085 | if (orig_transpose && w_blk > 0) w_blk = find_block(w_tile, w_blk, w_max); |
1086 | |
1087 | send_2d_hint_t hint; |
1088 | hint.type = type; |
1089 | hint.enable = true; |
1090 | hint.width = w_blk; |
1091 | hint.height = h_blk; |
1092 | hint.vnni = vnni; |
1093 | hint.transpose = transpose; |
1094 | return hint; |
1095 | } |
1096 | |
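// Returns a send hint for the given access, enabling 2D block messages on
// XeHPC+ when the view layout allows it. For dpas operands the 2D block
// sizes are chosen to produce dpas-compatible (VNNI/transposed) GRF layouts.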
1097 | send_hint_t get_send_hint(const exec_config_t &exec_cfg, send_op_t send_op, |
1098 | fma_kind_t fma_kind, abc_kind_t abc_kind, const view_t &view, |
1099 | const gemm_schedule_t &gemm_schedule, bool allow_2d) { |
1100 | if (!allow_2d) return send_hint_t(); |
1101 | if (exec_cfg.hw() < ngen::HW::XeHPC) return send_hint_t(); |
1102 | if (!utils::one_of(send_op, send_op_t::load, send_op_t::prefetch, |
1103 | send_op_t::store)) |
1104 | return send_hint_t(); |
1105 | |
1106 | auto vlayout = view.create_pseudo_vlayout(); |
1107 | auto blocks = vlayout.blocks(); |
1108 | if (blocks.size() < 2) return send_hint_t(); |
1109 | |
1110 | auto &bmnk_mapper = gemm_schedule.bmnk_mapper(); |
1111 | auto &b0 = blocks[0]; |
1112 | auto &b1 = blocks[1]; |
1113 | if (b0.dim_idx == b1.dim_idx) return send_hint_t(); |
1114 | if (b0.stride != stride_t(1)) return send_hint_t(); |
1115 | if (b1.stride.is_unknown()) return send_hint_t(); |
1116 | |
1117 | send_hint_t hint; |
1118 | if (send_op == send_op_t::load && fma_kind == fma_kind_t::dpas |
1119 | && utils::one_of(abc_kind, abc_kind_t::a, abc_kind_t::b)) { |
1120 | bool is_dpas_src1 = (abc_kind == abc_kind_t::b); |
1121 | int mn_blk = (is_dpas_src1 ? exec_cfg.simd() : 8); |
1122 | int k_blk = 32 / view.type().size(); |
1123 | |
1124 | bool is_b0_k = (bmnk_mapper.bmnk_kind(abc_kind, b0.dim_idx) |
1125 | == bmnk_kind_t::k); |
1126 | |
1127 | // Handle 4 cases (consider bf16): |
1128 | // src1, MxK: 16a16b -> 8a16b2a (VNNI) |
1129 | // src1, KxM: 16a16b -> 16b16a -> 8b16a2b (transpose + VNNI) |
1130 | // src2, KxN: 16a16b -> 16b16a (transpose) |
1131 | // src2, NxK: 16a16b -> 16a16b () |
1132 | bool vnni = is_dpas_src1; |
1133 | bool transpose = (is_dpas_src1 == is_b0_k); |
1134 | int b0_blk = is_b0_k ? k_blk : mn_blk; |
1135 | int b1_blk = !is_b0_k ? k_blk : mn_blk; |
1136 | if (b0.block % b0_blk != 0) return send_hint_t(); |
1137 | if (b1.block % b1_blk != 0) return send_hint_t(); |
1138 | hint.hint_2d = get_send_2d_hint(send_op, view.type(), vnni, transpose, |
1139 | b0.block, b1.block, b0_blk, b1_blk); |
1140 | } else { |
1141 | if (b0.block >= 128) return hint; |
1142 | hint.hint_2d = get_send_2d_hint( |
1143 | send_op, view.type(), false, false, b0.block, b1.block); |
1144 | } |
1145 | |
1146 | return hint; |
1147 | } |
1148 | |
1149 | } // namespace jit |
1150 | } // namespace gpu |
1151 | } // namespace impl |
1152 | } // namespace dnnl |
1153 | |