/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/codegen/bank_conflict_allocation.hpp"

#include "common/verbose.hpp"
#include "gpu/jit/ir/fma.hpp"
#include "gpu/jit/utils/utils.hpp"

#include <initializer_list>
#include <sstream>
#include <string>
#include <vector>

#if defined(__GNUC__) && __GNUC__ == 7
// GCC 7.x issues a false positive 'array subscript is above array bounds'
// warning.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

namespace {

// Helper structure to access HW-specific information.
struct hw_context_t {
    hw_context_t(ngen::HW hw, int regs)
        : hw(hw), regs(regs), reg_size(ngen::GRF::bytes(hw)) {
        int bank0 = reg_bank(0);
        for (int i = 1; i < regs; i++)
            if (reg_bank(i) != bank0) {
                reg_bank_stride = i;
                break;
            }
        ir_assert(reg_bank_stride != -1);

        bank_masks.resize(ngen::Bundle::bank_count(hw));
        bundle_masks.resize(ngen::Bundle::bundle_count(hw));

        for (int i = 0; i < regs; i++) {
            int bank = reg_bank(i);
            int bundle = reg_bundle(i);
            if (i < 64) {
                bank_masks[bank] |= (1ull << i);
                bundle_masks[bundle] |= (1ull << i);
            } else {
                // Ensure the bank/bundle pattern repeats with a 64-register
                // period.
                int j = (i % 64);
                ir_assert((bank_masks[bank] & (1ull << j)) != 0);
                ir_assert((bundle_masks[bundle] & (1ull << j)) != 0);
            }
        }
    }

    int reg_bank(int reg) const {
        auto bundle = ngen::Bundle::locate(hw, ngen::GRF(reg));
        return bundle.bank_id;
    }

    int reg_bundle(int reg) const {
        auto bundle = ngen::Bundle::locate(hw, ngen::GRF(reg));
        return bundle.bundle_id;
    }

    int hw_simd() const {
        switch (hw) {
            case ngen::HW::Gen9:
            case ngen::HW::Gen10:
            case ngen::HW::Gen11:
            case ngen::HW::XeLP:
            case ngen::HW::XeHP:
            case ngen::HW::XeHPG: return 8;
            case ngen::HW::XeHPC: return 16;
            default: ir_error_not_expected();
        }
        return -1;
    }

    ngen::HW hw;
    int regs; // Number of registers.
    int reg_size; // Size of a register in bytes.

    // Stride in registers between different GRF banks.
    int reg_bank_stride = -1;

    // 64-bit bitmasks for each bank/bundle. If the i-th bit is set in the
    // mask for bank/bundle B, then the i-th register belongs to B. The
    // pattern is assumed to repeat for every subsequent group of 64
    // registers.
    std::vector<uint64_t> bank_masks;
    std::vector<uint64_t> bundle_masks;
};
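
// Illustration (hypothetical layout, for exposition only): if banks alternate
// every two registers (reg_bank_stride == 2), the per-bank masks over one
// 64-register period would look like:
//
//   bank_masks[0] = 0x3333333333333333; // r0-r1, r4-r5, r8-r9, ...
//   bank_masks[1] = 0xcccccccccccccccc; // r2-r3, r6-r7, r10-r11, ...
//
// The actual stride and bundle layout are whatever ngen::Bundle::locate()
// reports for the target HW; the constructor above only verifies that the
// layout repeats with a 64-register period.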

// Bitmask for registers, one bit per register. Many interfaces are named
// after the std::bitset API.
struct reg_mask_t {
    reg_mask_t() = default;

    reg_mask_t(const hw_context_t *hw_ctx, uint64_t chunk_mask = -1)
        : hw_ctx(hw_ctx), nchunks(hw_ctx->regs / chunk_bits) {
        for (int i = 0; i < nchunks; i++)
            chunks[i] = chunk_mask;
    }

    bool none() const {
        uint64_t mask = 0;
        for (int i = 0; i < nchunks; i++)
            mask |= chunks[i];
        return mask == 0;
    }

    bool test(int i) const {
        int ichunk = i / chunk_bits;
        int bit = i % chunk_bits;
        return (chunks[ichunk] >> bit) & 0x1;
    }

    // Returns true if all bits in the [off, off + len - 1] range are set.
    bool is_unset(int off, int len) const;

    void set(int i, bool value = true) {
        int ichunk = i / chunk_bits;
        int bit = i % chunk_bits;
        if (value)
            chunks[ichunk] |= (1ull << bit);
        else
            chunks[ichunk] &= ~(1ull << bit);
    }

    void set(int off, int len, bool value = true) {
        for (int i = off; i < off + len; i++)
            set(i, value);
    }

    void reset() {
        for (int i = 0; i < nchunks; i++)
            chunks[i] = 0;
    }

    // Returns the number of set register bits.
    int count() const {
        int ret = 0;
        for (int i = 0; i < nchunks; i++)
            ret += ngen::utils::popcnt(chunks[i]);
        return ret;
    }

    // Returns the index of the first set register bit.
    int bsf() const {
        for (int i = 0; i < hw_ctx->regs; i++) {
            if (test(i)) return i;
        }
        return -1;
    }

    // Returns the index of the last set register bit.
    int bsr() const {
        UNUSED(&reg_mask_t::bsr);
        for (int i = hw_ctx->regs - 1; i >= 0; i--) {
            if (test(i)) return i;
        }
        return -1;
    }

    // Returns a mask where all bits in the [off, off + len - 1] range are set
    // and all other bits are unset.
    reg_mask_t range_mask(int off, int len) const {
        reg_mask_t ret(hw_ctx);
        ret = ret << (hw_ctx->regs - len);
        ret = ret >> (hw_ctx->regs - len - off);
        return ret;
    }

    // Returns the GRF bank shared by all set register bits, or -1 if they do
    // not all belong to the same bank.
    int bank() const;

    void subtract(const reg_mask_t &other) { *this &= ~other; }

    reg_mask_t &operator&=(const reg_mask_t &other) {
        for (int i = 0; i < nchunks; i++)
            chunks[i] = chunks[i] & other.chunks[i];
        return *this;
    }

    reg_mask_t &operator|=(const reg_mask_t &other) {
        UNUSED(&reg_mask_t::operator|=);
        for (int i = 0; i < nchunks; i++)
            chunks[i] = chunks[i] | other.chunks[i];
        return *this;
    }

    reg_mask_t operator<<(int shift) const {
        int idx = shift / chunk_bits;
        int bit = shift % chunk_bits;
        reg_mask_t ret(hw_ctx, 0);
        for (int i = idx + 1; i < nchunks; i++) {
            auto c0 = (chunks[i - idx] << bit);
            auto c1 = (bit == 0 ? 0
                                : (chunks[i - idx - 1] >> (chunk_bits - bit)));
            ret.chunks[i] = c0 | c1;
        }
        ret.chunks[idx] = (chunks[0] << bit);
        return ret;
    }

    reg_mask_t operator>>(int shift) const {
        int idx = shift / chunk_bits;
        int bit = shift % chunk_bits;
        reg_mask_t ret(hw_ctx, 0);
        for (int i = 0; i + idx + 1 < nchunks; i++) {
            auto c0 = (chunks[i + idx] >> bit);
            auto c1 = (bit == 0 ? 0
                                : (chunks[i + idx + 1] << (chunk_bits - bit)));
            ret.chunks[i] = c0 | c1;
        }
        ret.chunks[nchunks - idx - 1] = (chunks[nchunks - 1] >> bit);
        return ret;
    }

    bool operator==(const reg_mask_t &other) const {
        for (int i = 0; i < nchunks; i++)
            if (chunks[i] != other.chunks[i]) return false;
        return true;
    }

    reg_mask_t operator~() const {
        reg_mask_t ret(hw_ctx);
        for (int i = 0; i < nchunks; i++)
            ret.chunks[i] = ~chunks[i];
        return ret;
    }

    std::string str() const {
        UNUSED(&reg_mask_t::str);
        std::ostringstream oss;
        for (int i = hw_ctx->regs - 1; i >= 0; i--) {
            oss << (test(i) ? "1" : "0");
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

    static const int chunk_bits = 64;
    static const int max_regs = 256;
    static const int max_nchunks = max_regs / chunk_bits;

    const hw_context_t *hw_ctx = nullptr;
    int nchunks = 0;
    uint64_t chunks[max_nchunks] = {0};
};

inline reg_mask_t operator&(const reg_mask_t &a, const reg_mask_t &b) {
    auto ret = a;
    return ret &= b;
}

inline bool reg_mask_t::is_unset(int off, int len) const {
    auto m = range_mask(off, len);
    return m == (*this & m);
}

inline int reg_mask_t::bank() const {
    if (*this == (*this & reg_mask_t(hw_ctx, hw_ctx->bank_masks[0]))) return 0;
    if (*this == (*this & reg_mask_t(hw_ctx, hw_ctx->bank_masks[1]))) return 1;
    return -1;
}
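
// Usage sketch (illustrative only). With set bits denoting free registers:
//
//   reg_mask_t m(hw_ctx);        // all registers free
//   m.set(4, /*len=*/2, false);  // claim r4-r5
//   bool a = m.is_unset(0, 4);   // true: r0-r3 are all set (free)
//   bool b = m.is_unset(3, 2);   // false: r4 was claimed
//
// Note that, despite its name, is_unset() returns true when every bit in the
// range is set; the search below relies on this "all registers available"
// semantics.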

// Represents a compound mask for a contiguous block of registers. For each
// register in the block, its mask describes the set of candidate physical
// registers.
struct reg_block_mask_t {
    reg_block_mask_t() = default;

    reg_block_mask_t(const hw_context_t *hw_ctx, int regs) : regs(regs) {
        masks.reserve(regs);
        for (int i = 0; i < regs; i++)
            masks.emplace_back(hw_ctx);

        auto &mask0 = masks[0];

        // Align all blocks to a GRF bank boundary.
        int step = hw_ctx->reg_bank_stride;
        for (int i = 0; i < hw_ctx->regs; i += step) {
            for (int j = i + 1; j < i + step; j++) {
                mask0.set(j, false);
            }
        }
        // Exclude base registers that would make the block extend past the
        // last register.
        for (int i = hw_ctx->regs - regs + 1; i < hw_ctx->regs; i++) {
            mask0.set(i, false);
        }
        // Update the remaining masks.
        propagate_masks();
    }

    void exclude(const reg_mask_t &mask) {
        for (auto &m : masks)
            m.subtract(mask);
    }

    bool can_be_assigned() const {
        for (auto &m : masks)
            if (m.none()) return false;
        return true;
    }

    bool is_assigned() const { return masks[0].count() == 1; }

    void propagate_masks() {
        // Limit the first register mask based on the other register masks.
        for (int j = 1; j < regs; j++) {
            masks[0] &= (masks[j] >> j);
        }
        // Propagate back.
        for (int j = 1; j < regs; j++) {
            masks[j] &= (masks[0] << j);
        }
    }

    std::string str() const {
        UNUSED(&reg_block_mask_t::str);
        std::ostringstream oss;
        for (int i = 0; i < regs; i++) {
            oss << "#" << i << " mask: " << masks[i].str();
            if (i != regs - 1) oss << std::endl;
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

    int regs;
    std::vector<reg_mask_t> masks;
};
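
// Propagation example (illustrative, a 2-register block): if masks[1] allows
// only {r5, r9}, then masks[0] &= (masks[1] >> 1) leaves at most {r4, r8} as
// base candidates; the backward pass masks[1] &= (masks[0] << 1) then drops
// any second-register candidate whose base was excluded for other reasons
// (e.g. bank alignment).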

// Represents a single register in a register block.
struct reg_t {
    reg_t() = default;

    reg_t(reg_block_mask_t *block, int off) : block(block), off(off) {}

    bool is_empty() const { return !block; }

    int bank() const {
        if (is_empty()) return -1;
        return block->masks[off].bank();
    }

    void exclude(const reg_mask_t &mask) {
        if (is_empty()) return;
        block->masks[off].subtract(mask);
    }

    bool operator==(const reg_t &other) const {
        return (other.block == block) && (other.off == off);
    }

    std::string str() const {
        UNUSED(&reg_t::str);
        if (is_empty()) return "null";
        std::ostringstream oss;
        if (block->is_assigned()) {
            int reg = block->masks[off].bsf();
            oss << "r" << reg;
        } else {
            oss << "R" << off;
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

    reg_block_mask_t *block = nullptr;
    int off = -1;
};

// Mask for a blocked register buffer. The buffer consists of blocks B0, B1,
// ... All blocks have the same size; each block is contiguous internally, but
// B(i) and B(i+1) are not necessarily contiguous with each other.
struct reg_buf_mask_t {
    reg_buf_mask_t(const hw_context_t *hw_ctx, int regs, int block_regs = 0)
        : hw_ctx(hw_ctx), regs(regs), block_regs(block_regs) {
        if (block_regs == 0) this->block_regs = regs;
        ir_assert(regs % this->block_regs == 0);
        for (int i = 0; i < nblocks(); i++) {
            blocks.emplace_back(hw_ctx, this->block_regs);
        }
    }

    // Size in bytes.
    int size() const { return regs * hw_ctx->reg_size; }

    int nblocks() const { return regs / block_regs; }

    reg_t get_reg(int off_bytes) {
        ir_assert(off_bytes < size());
        off_bytes /= hw_ctx->reg_size;
        int block_idx = off_bytes / block_regs;
        int reg_idx = off_bytes % block_regs;
        return reg_t(&blocks[block_idx], reg_idx);
    }

    const hw_context_t *hw_ctx;
    int regs; // Number of registers in the buffer.
    int block_regs; // Number of registers in one block.
    std::vector<reg_block_mask_t> blocks;
};
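
// Example (illustrative, assuming 32-byte GRFs): a 128-byte buffer with
// block_regs = 2 covers regs = 4 registers in nblocks() = 2 blocks, and
// get_reg(96) resolves to register 96 / 32 = 3, i.e. block 1, offset 1:
// reg_t(&blocks[1], 1).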

// Represents a three-source instruction.
struct instruction_t {
    instruction_t(const reg_t &src0, const reg_t &src1, const reg_t &src2)
        : src0(src0), src1(src1), src2(src2) {}

    bool has(const reg_t &reg) const {
        return reg == src0 || reg == src1 || reg == src2;
    }

    reg_t src0;
    reg_t src1;
    reg_t src2;
};

// Helper structure for the GRF assignment search.
struct search_context_t {
    search_context_t(const hw_context_t *hw_ctx, const reg_mask_t &reg_mask,
            std::vector<reg_block_mask_t *> &blocks,
            const std::vector<instruction_t> &instructions)
        : hw_ctx(hw_ctx)
        , reg_mask(reg_mask)
        , blocks(blocks)
        , instructions(instructions) {
        saved_blocks.resize(nblocks() * nblocks());
    }

    int nblocks() { return int(blocks.size()); }

    void set_check_bundles(bool value = true) { check_bundles = value; }

    void set_check_diff_banks_src02(bool value = true) {
        check_diff_banks_src02 = value;
    }

    // Saves block masks for the current recursion level.
    void save_blocks() {
        ir_assert(saved_block_idx + nblocks() <= int(saved_blocks.size()));
        for (int i = 0; i < nblocks(); i++) {
            saved_blocks[saved_block_idx + i] = *blocks[i];
        }
        saved_block_idx += nblocks();
        steps++;
    }

    // Restores the saved block masks.
    void restore_blocks() {
        saved_block_idx -= nblocks();
        ir_assert(saved_block_idx >= 0);
        for (int i = 0; i < nblocks(); i++) {
            *blocks[i] = saved_blocks[saved_block_idx + i];
        }
    }

    bool should_stop() const {
        int max_steps = 250;
        return steps > max_steps;
    }

    void reset_steps() { steps = 0; }

    const hw_context_t *hw_ctx;

    int steps = 0;
    reg_mask_t reg_mask;
    std::vector<reg_block_mask_t *> blocks;
    std::vector<instruction_t> instructions;

    int saved_block_idx = 0;
    std::vector<reg_block_mask_t> saved_blocks;

    // Whether to require the bundle check.
    bool check_bundles = false;

    // Whether to require src0 and src2 to be in different banks
    // (dpas-specific).
    bool check_diff_banks_src02 = false;
};
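
// Backtracking search over block base registers. Each recursion level picks a
// base register for one block, prunes the masks of the remaining blocks, and
// recurses; save_blocks()/restore_blocks() act as the undo trail when a
// candidate is rejected. The steps limit in should_stop() bounds the
// worst-case exponential search.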

bool search(search_context_t &ctx, int block_idx = 0) {
    // All blocks are assigned: success.
    if (block_idx >= ctx.nblocks()) return true;

    auto *hw_ctx = ctx.hw_ctx;

    auto &block = *ctx.blocks[block_idx];
    auto &mask0 = block.masks[0];

    // 1. Assign the i-th register as the base of the current block.
    // 2. Update the register constraints for the other blocks.
    // 3. If the remaining blocks can still be assigned, move to the next
    //    block. Otherwise try the next register in step 1.
    for (int i = 0; i < ctx.hw_ctx->regs; i++) {
        if (!mask0.test(i)) continue;
        // Require all registers in the range to be available.
        if (!ctx.reg_mask.is_unset(i, block.regs)) continue;
        // Stop the search if it takes too many steps.
        if (ctx.should_stop()) return false;

        // Try to assign the current block to the i-th register.
        ctx.save_blocks();

        // Claim the register block.
        ctx.reg_mask.set(i, block.regs, false);
        reg_mask_t i_reg_mask(hw_ctx, 0);
        for (int j = 0; j < block.regs; j++) {
            block.masks[j].reset();
            block.masks[j].set(i + j);
            i_reg_mask.set(i + j);
        }

        bool conflicts_ok = true;

        // Exclude the claimed region from the remaining block masks.
        for (int j = block_idx + 1; j < ctx.nblocks(); j++) {
            ctx.blocks[j]->exclude(i_reg_mask);
            if (!ctx.blocks[j]->can_be_assigned()) {
                conflicts_ok = false;
                break;
            }
        }

        // Update constraints according to register usage in instructions.
        std::vector<reg_mask_t> bundle_masks;
        bundle_masks.reserve(block.regs);
        for (int j = 0; j < block.regs; j++) {
            int bundle = hw_ctx->reg_bundle(i + j);
            bundle_masks.emplace_back(hw_ctx, hw_ctx->bundle_masks[bundle]);
        }

        for (auto &insn : ctx.instructions) {
            for (int j = 0; j < block.regs; j++) {
                reg_t j_reg(&block, j);
                if (!insn.has(j_reg)) continue;

                int bank0 = insn.src0.bank();
                int bank1 = insn.src1.bank();
                int bank2 = insn.src2.bank();

                if (ctx.check_diff_banks_src02) {
                    if (bank0 != -1 && bank0 == bank2) {
                        conflicts_ok = false;
                        break;
                    }
                }

                // Handle the bank conflict condition.
                if (bank0 != -1 && bank0 == bank1 && bank1 == bank2) {
                    conflicts_ok = false;
                    break;
                }

                // Handle the bundle conflict condition.
                if (ctx.check_bundles) {
                    for (auto *reg : {&insn.src0, &insn.src1, &insn.src2}) {
                        if (*reg == j_reg) continue;
                        reg->exclude(bundle_masks[j]);
                    }
                }
                break;
            }
        }

        if (conflicts_ok) {
            for (auto *b : ctx.blocks) {
                b->propagate_masks();
                if (!b->can_be_assigned()) {
                    conflicts_ok = false;
                    break;
                }
            }
        }

        if (conflicts_ok) {
            bool ok = search(ctx, block_idx + 1);
            if (ok) return true;
        }

        // Release the register block and move to the next candidate.
        ctx.reg_mask.set(i, block.regs, true);
        ctx.restore_blocks();
    }
    return false;
}

reg_mask_t create_available_reg_mask(
        reg_allocator_t &ra, const hw_context_t *hw_ctx) {
    reg_mask_t reg_mask(hw_ctx, 0);
    ra.start_speculate();

    // Query the allocator for the set of free registers.
    for (;;) {
        auto grf = ra.try_alloc();
        if (grf.isInvalid()) break;
        reg_mask.set(grf.getBase());
    }

    for (int i = 0; i < hw_ctx->regs; i++) {
        if (reg_mask.test(i)) {
            ngen::GRF grf(i);
            ra.safeRelease(grf);
        }
    }

    ra.finish_speculate();
    return reg_mask;
}
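
// The speculative alloc/release round-trip above probes the allocator state:
// try_alloc() is called until it fails, which enumerates the currently free
// GRFs; those registers are then all released again, so the probe leaves the
// allocator unchanged.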

} // namespace

bank_conflict_allocation_t bank_conflict_allocation_t::create(
        reg_allocator_t &ra, int regs, const bank_conflict_attr_t &attr) {
    hw_context_t hw_ctx(ra.hardware(), regs);

    bool is_dpas = false;
    bool is_dp4a = false;
    expr_t dst_base;
    if (!attr.instructions.empty()) {
        auto &s = attr.instructions[0];
        auto &func = s.as<func_call_t>().func;
        if (func.is<dpas_t>()) {
            is_dp4a = func.as<dpas_t>().is_dp4a();
            is_dpas = !is_dp4a;
            dst_base = get_base(dpas_t::arg_dst(s));
        } else if (func.is<mad_t>()) {
            dst_base = get_base(mad_t::arg_dst(s));
        } else {
            ir_error_not_expected();
        }
    }

    // Heuristics for src/dst block sizes.
    int dst_block_regs = (is_dpas ? 0 : 2);
    int src_block_regs = (is_dpas || is_dp4a ? 0 : 16);

    std::vector<expr_t> bufs = attr.bufs;
    std::vector<reg_buf_mask_t> buf_masks;
    std::vector<int> buf_src_idx(bufs.size(), -1);
    for (int i = 0; i < int(bufs.size()); i++) {
        int buf_size = attr.buf_sizes[i];
        int min_block_size = attr.buf_min_block_sizes[i];
        int regs = utils::div_up(buf_size, hw_ctx.reg_size);
        int block_regs = (bufs[i].is_equal(dst_base) ? dst_block_regs
                                                     : src_block_regs);
        // Ensure that blocks are uniform; otherwise allocate as a single
        // block.
        if (block_regs != 0 && regs % block_regs != 0) block_regs = 0;
        // Ensure that the block size is allowed (to avoid unhandled boundary
        // crossing); otherwise allocate as a single block.
        if (block_regs != 0 && block_regs * hw_ctx.reg_size < min_block_size)
            block_regs = 0;
        buf_masks.emplace_back(&hw_ctx, regs, block_regs);
    }
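
    // Worked example of the heuristics above (illustrative, 32-byte GRFs): a
    // 2048-byte mad source buffer gives regs = 64 and, with
    // src_block_regs = 16, is split into four 16-register blocks that the
    // search can place independently. A 96-byte buffer gives regs = 3, which
    // is not divisible by 16, so it is allocated as a single 3-register
    // block.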

    auto create_reg = [&](const expr_t &e, int src_idx, int off_bytes) {
        if (is_zero(e)) return reg_t();
        auto base = get_base(e);
        int off = 0;
        if (!is_var(e)) off = to_cpp<int>(e.as<ptr_t>().off);
        off += off_bytes;
        for (size_t i = 0; i < bufs.size(); i++) {
            if (base.is_same(bufs[i])) {
                buf_src_idx[i] = src_idx;
                return buf_masks[i].get_reg(off);
            }
        }
        ir_error_not_expected();
        return reg_t();
    };

    int hw_simd = hw_ctx.hw_simd();
    std::vector<instruction_t> instructions;
    for (auto &s : attr.instructions) {
        auto &call = s.as<func_call_t>();
        expr_t src0, src1, src2;
        int simd = 0;
        int src0_stride_bytes = 0;
        int src1_stride_bytes = 0;
        int src2_stride_bytes = 0;
        if (call.func.is<dpas_t>()) {
            auto &dpas = call.func.as<dpas_t>();
            simd = dpas.exec_size;
            src0_stride_bytes = dpas.dst_type.size();
            src1_stride_bytes = dpas.src1_type.size();
            src2_stride_bytes = dpas.is_dp4a() ? 0 : dpas.src2_type.size();
            src0 = dpas_t::arg_src0(call);
            src1 = dpas_t::arg_src1(call);
            src2 = dpas_t::arg_src2(call);
            if (!dpas.is_dp4a()) ir_assert(simd == hw_simd);
        } else if (call.func.is<mad_t>()) {
            auto &mad = call.func.as<mad_t>();
            simd = mad.exec_size;
            src0_stride_bytes = mad.dst_type.size();
            src1_stride_bytes = mad.src1_stride * mad.src1_type.size();
            src2_stride_bytes = mad.src2_stride * mad.src2_type.size();
            src0 = mad_t::arg_src0(call);
            src1 = mad_t::arg_src1(call);
            src2 = mad_t::arg_src2(call);
        } else {
            ir_error_not_expected();
        }
        for (int off = 0; off < simd; off += hw_simd) {
            auto _src0 = create_reg(src0, 0, off * src0_stride_bytes);
            auto _src1 = create_reg(src1, 1, off * src1_stride_bytes);
            auto _src2 = create_reg(src2, 2, off * src2_stride_bytes);
            instructions.emplace_back(_src0, _src1, _src2);
        }
    }
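
    // Example (illustrative): a mad with exec_size = 16 on HW where
    // hw_simd() == 8 is recorded as two instruction_t entries; the second
    // entry offsets each source by 8 * stride_bytes so that the bank checks
    // see the registers actually read by each half of the instruction.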

    std::vector<reg_block_mask_t *> blocks;

    for (size_t i = 0; i < bufs.size(); i++)
        ir_assert(buf_src_idx[i] != -1)
                << "Buffer is not referenced: " << bufs[i];

    // Heuristic: search for register blocks in this order: src1, src2, src0.
    for (int i : {1, 2, 0}) {
        for (size_t j = 0; j < bufs.size(); j++) {
            if (buf_src_idx[j] == i) {
                for (auto &block : buf_masks[j].blocks)
                    blocks.push_back(&block);
            }
        }
    }

    auto reg_mask = create_available_reg_mask(ra, &hw_ctx);
    search_context_t ctx(&hw_ctx, reg_mask, blocks, instructions);

    if (is_dpas) ctx.set_check_diff_banks_src02();

    bool found = false;

    // First try to find an allocation with the bundle check enabled; if that
    // fails, check only for bank conflicts.
    for (bool check_bundles : {true, false}) {
        // dpas doesn't need the bundle check.
        if (is_dpas && check_bundles) continue;

        ctx.reset_steps();
        ctx.set_check_bundles(check_bundles);

        ir_assert(ctx.saved_block_idx == 0);
        ir_assert(ctx.reg_mask == reg_mask);

#ifdef GEN_CONV_DEBUG
        double search_time = get_msec();
#endif
        found = search(ctx);
#ifdef GEN_CONV_DEBUG
        search_time = get_msec() - search_time;
        ir_trace() << "Bank conflict allocation:" << std::endl;
        ir_trace() << "  Search time: " << search_time << " ms" << std::endl;
        ir_trace() << "  Status: " << (found ? "OK" : "FAIL") << std::endl;
        ir_trace() << "  Steps: " << ctx.steps << std::endl;
        ir_trace() << "  Bundle check: "
                   << ir_utils::to_string(ctx.check_bundles) << std::endl;
#endif
        if (found) break;
    }

    bool was_claimed = false;
    if (!found) {
        // No conflict-free allocation was found; use the fallback scheme:
        // place src0 and src2 in different banks.
        int bank = -1;
        for (size_t i = 0; i < bufs.size(); i++) {
            bool is_src02 = utils::one_of(buf_src_idx[i], 0, 2);
            int regs = buf_masks[i].regs;
            // Always use a single-block buffer.
            buf_masks[i] = reg_buf_mask_t(&hw_ctx, regs);
            ngen::Bundle bundle;
            // Once one of src0/src2 has a bank, put the other in the
            // opposite bank.
            if (is_src02 && bank != -1)
                bundle = ngen::Bundle(1 - bank, ngen::Bundle::any);
            auto &mask = buf_masks[i].blocks[0].masks[0];
            auto range = ra.alloc_range(regs, bundle);
            int base = range[0].getBase();
            if (is_src02 && bank == -1) bank = hw_ctx.reg_bank(base);
            mask.reset();
            mask.set(base);
        }
        was_claimed = true;
    }

    // Initialize register buffers with the found assignment.
    bank_conflict_allocation_t bca(ra);
    for (size_t i = 0; i < bufs.size(); i++) {
        int nblocks = buf_masks[i].nblocks();
        std::vector<int> block_bases(nblocks);
        for (int j = 0; j < nblocks; j++) {
            int reg = buf_masks[i].blocks[j].masks[0].bsf();
            block_bases[j] = reg;
        }
        reg_buf_t reg_buf(ra.hardware(), buf_masks[i].block_regs, block_bases);
        if (!was_claimed) reg_buf.claim(ra);
        bca.set_reg_buf(bufs[i], reg_buf);
    }
    return bca;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#if defined(__GNUC__) && __GNUC__ == 7
#pragma GCC diagnostic pop
#endif