/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/codegen/bank_conflict_allocation.hpp"

#include "common/verbose.hpp"
#include "gpu/jit/ir/fma.hpp"
#include "gpu/jit/utils/utils.hpp"

#include <initializer_list>
#include <sstream>
#include <string>
#include <vector>

#if defined(__GNUC__) && __GNUC__ == 7
// GCC 7.x issues a false-positive 'array subscript is above array bounds'
// warning.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

namespace {

// Helper structure to access HW-specific information.
struct hw_context_t {
    hw_context_t(ngen::HW hw, int regs)
        : hw(hw), regs(regs), reg_size(ngen::GRF::bytes(hw)) {
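        // Find the smallest register index whose bank differs from the bank
        // of register 0; this gives the stride at which banks interleave.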
        int bank0 = reg_bank(0);
        for (int i = 1; i < regs; i++)
            if (reg_bank(i) != bank0) {
                reg_bank_stride = i;
                break;
            }
        ir_assert(reg_bank_stride != -1);

        bank_masks.resize(ngen::Bundle::bank_count(hw));
        bundle_masks.resize(ngen::Bundle::bundle_count(hw));

        for (int i = 0; i < regs; i++) {
            int bank = reg_bank(i);
            int bundle = reg_bundle(i);
            if (i < 64) {
                bank_masks[bank] |= (1ull << i);
                bundle_masks[bundle] |= (1ull << i);
            } else {
                // Check that the bank/bundle pattern repeats every 64
                // registers.
                int j = (i % 64);
                ir_assert((bank_masks[bank] & (1ull << j)) != 0);
                ir_assert((bundle_masks[bundle] & (1ull << j)) != 0);
            }
        }
    }

    int reg_bank(int reg) const {
        auto bundle = ngen::Bundle::locate(hw, ngen::GRF(reg));
        return bundle.bank_id;
    }

    int reg_bundle(int reg) const {
        auto bundle = ngen::Bundle::locate(hw, ngen::GRF(reg));
        return bundle.bundle_id;
    }

    int hw_simd() const {
        switch (hw) {
            case ngen::HW::Gen9:
            case ngen::HW::Gen10:
            case ngen::HW::Gen11:
            case ngen::HW::XeLP:
            case ngen::HW::XeHP:
            case ngen::HW::XeHPG: return 8;
            case ngen::HW::XeHPC: return 16;
            default: ir_error_not_expected();
        }
        return -1;
    }

    ngen::HW hw;
    int regs; // Number of registers.
    int reg_size; // Size of a register in bytes.

    // Stride in registers between different GRF banks.
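    // E.g., a stride of 2 means that registers {0, 1} are in one bank and
    // register 2 starts the other bank.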
    int reg_bank_stride = -1;

    // 64-bit bitmasks, one per bank/bundle. If the i-th bit is set in the
    // mask for bank/bundle B then the i-th register belongs to bank/bundle B.
    // The pattern is assumed to repeat for every next group of 64 registers.
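    // E.g., if banks alternated every two registers, bank 0's mask would be
    // 0x3333333333333333 (bits 0, 1, 4, 5, ... set).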
    std::vector<uint64_t> bank_masks;
    std::vector<uint64_t> bundle_masks;
};

// Bitmask for registers, one bit per register. Many of the methods are named
// after the std::bitset API.
struct reg_mask_t {
    reg_mask_t() = default;

    reg_mask_t(const hw_context_t *hw_ctx, uint64_t chunk_mask = -1)
        : hw_ctx(hw_ctx), nchunks(hw_ctx->regs / chunk_bits) {
        for (int i = 0; i < nchunks; i++)
            chunks[i] = chunk_mask;
    }

    bool none() const {
        uint64_t mask = 0;
        for (int i = 0; i < nchunks; i++)
            mask |= chunks[i];
        return mask == 0;
    }

    bool test(int i) const {
        int ichunk = i / chunk_bits;
        int bit = i % chunk_bits;
        return (chunks[ichunk] >> bit) & 0x1;
    }

    // Returns true if all bits in the [off, off + len - 1] range are set.
    // Callers keep a set bit for every register that is still free, so a
    // fully set range means the whole range is available.
    bool is_unset(int off, int len) const;

    void set(int i, bool value = true) {
        int ichunk = i / chunk_bits;
        int bit = i % chunk_bits;
        if (value)
            chunks[ichunk] |= (1ull << bit);
        else
            chunks[ichunk] &= ~(1ull << bit);
    }

    void set(int off, int len, bool value = true) {
        for (int i = off; i < off + len; i++)
            set(i, value);
    }

    void reset() {
        for (int i = 0; i < nchunks; i++)
            chunks[i] = 0;
    }

    // Returns the number of set register bits.
    int count() const {
        int ret = 0;
        for (int i = 0; i < nchunks; i++)
            ret += ngen::utils::popcnt(chunks[i]);
        return ret;
    }

    // Returns the index of the first set register bit.
    int bsf() const {
        for (int i = 0; i < hw_ctx->regs; i++) {
            if (test(i)) return i;
        }
        return -1;
    }

    // Returns the index of the last set register bit.
    int bsr() const {
        UNUSED(&reg_mask_t::bsr);
        for (int i = hw_ctx->regs - 1; i >= 0; i--) {
            if (test(i)) return i;
        }
        return -1;
    }

    // Returns a mask where all bits in the [off, off + len - 1] range are set
    // and all other bits are not set.
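    // E.g., with 128 registers, range_mask(2, 3) returns a mask with exactly
    // bits 2, 3 and 4 set.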
    reg_mask_t range_mask(int off, int len) const {
        reg_mask_t ret(hw_ctx);
        ret = ret << (hw_ctx->regs - len);
        ret = ret >> (hw_ctx->regs - len - off);
        return ret;
    }

    // Returns the GRF bank shared by all set register bits, or -1 if they
    // belong to different banks.
    int bank() const;

    void subtract(const reg_mask_t &other) { *this &= ~other; }

    reg_mask_t &operator&=(const reg_mask_t &other) {
        for (int i = 0; i < nchunks; i++)
            chunks[i] = chunks[i] & other.chunks[i];
        return *this;
    }

    reg_mask_t &operator|=(const reg_mask_t &other) {
        UNUSED(&reg_mask_t::operator|=);
        for (int i = 0; i < nchunks; i++)
            chunks[i] = chunks[i] | other.chunks[i];
        return *this;
    }

    reg_mask_t operator<<(int shift) const {
        int idx = shift / chunk_bits;
        int bit = shift % chunk_bits;
        reg_mask_t ret(hw_ctx, 0);
        for (int i = idx + 1; i < nchunks; i++) {
            auto c0 = (chunks[i - idx] << bit);
            auto c1 = (bit == 0 ? 0
                                : (chunks[i - idx - 1] >> (chunk_bits - bit)));
            ret.chunks[i] = c0 | c1;
        }
        ret.chunks[idx] = (chunks[0] << bit);
        return ret;
    }

    reg_mask_t operator>>(int shift) const {
        int idx = shift / chunk_bits;
        int bit = shift % chunk_bits;
        reg_mask_t ret(hw_ctx, 0);
        for (int i = 0; i + idx + 1 < nchunks; i++) {
            auto c0 = (chunks[i + idx] >> bit);
            auto c1 = (bit == 0 ? 0
                                : (chunks[i + idx + 1] << (chunk_bits - bit)));
            ret.chunks[i] = c0 | c1;
        }
        ret.chunks[nchunks - idx - 1] = (chunks[nchunks - 1] >> bit);
        return ret;
    }

    bool operator==(const reg_mask_t &other) const {
        for (int i = 0; i < nchunks; i++)
            if (chunks[i] != other.chunks[i]) return false;
        return true;
    }

    reg_mask_t operator~() const {
        reg_mask_t ret(hw_ctx);
        for (int i = 0; i < nchunks; i++)
            ret.chunks[i] = ~chunks[i];
        return ret;
    }

    std::string str() const {
        UNUSED(&reg_mask_t::str);
        std::ostringstream oss;
        for (int i = hw_ctx->regs - 1; i >= 0; i--) {
            oss << (test(i) ? "1" : "0");
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

    static const int chunk_bits = 64;
    static const int max_regs = 256;
    static const int max_nchunks = max_regs / chunk_bits;

    const hw_context_t *hw_ctx = nullptr;
    int nchunks = 0;
    uint64_t chunks[max_nchunks] = {0};
};

inline reg_mask_t operator&(const reg_mask_t &a, const reg_mask_t &b) {
    auto ret = a;
    return ret &= b;
}

inline bool reg_mask_t::is_unset(int off, int len) const {
    auto m = range_mask(off, len);
    return m == (*this & m);
}

inline int reg_mask_t::bank() const {
    if (*this == (*this & reg_mask_t(hw_ctx, hw_ctx->bank_masks[0]))) return 0;
    if (*this == (*this & reg_mask_t(hw_ctx, hw_ctx->bank_masks[1]))) return 1;
    return -1;
}

// Represents a compound mask for a contiguous block of registers. For each
// register in the block its mask describes the potential candidates for that
// register.
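// E.g., for a 4-register block, masks[2] lists every register that may still
// become the third register of the block.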
struct reg_block_mask_t {
    reg_block_mask_t() = default;

    reg_block_mask_t(const hw_context_t *hw_ctx, int regs) : regs(regs) {
        masks.reserve(regs);
        for (int i = 0; i < regs; i++)
            masks.emplace_back(hw_ctx);

        auto &mask0 = masks[0];

        // Align all blocks to a GRF bank boundary.
        int step = hw_ctx->reg_bank_stride;
        for (int i = 0; i < hw_ctx->regs; i += step) {
            for (int j = i + 1; j < i + step; j++) {
                mask0.set(j, false);
            }
        }
        // Exclude base registers for which the block would extend past the
        // last register.
        for (int i = hw_ctx->regs - regs + 1; i < hw_ctx->regs; i++) {
            mask0.set(i, false);
        }
        // Update the other masks accordingly.
        propagate_masks();
    }

    void exclude(const reg_mask_t &mask) {
        for (auto &m : masks)
            m.subtract(mask);
    }

    bool can_be_assigned() const {
        for (auto &m : masks)
            if (m.none()) return false;
        return true;
    }

    bool is_assigned() const { return masks[0].count() == 1; }

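    // Keep the per-register masks mutually consistent: register r may remain
    // a candidate in masks[j] only if r - j is still a candidate base in
    // masks[0], and vice versa.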
    void propagate_masks() {
        // Limit the first register mask based on the other register masks.
        for (int j = 1; j < regs; j++) {
            masks[0] &= (masks[j] >> j);
        }
        // Propagate back.
        for (int j = 1; j < regs; j++) {
            masks[j] &= (masks[0] << j);
        }
    }

    std::string str() const {
        UNUSED(&reg_block_mask_t::str);
        std::ostringstream oss;
        for (int i = 0; i < regs; i++) {
            oss << "#" << i << " mask: " << masks[i].str();
            if (i != regs - 1) oss << std::endl;
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

    int regs;
    std::vector<reg_mask_t> masks;
};

// Represents a single register in a register block.
struct reg_t {
    reg_t() = default;

    reg_t(reg_block_mask_t *block, int off) : block(block), off(off) {}

    bool is_empty() const { return !block; }

    int bank() const {
        if (is_empty()) return -1;
        return block->masks[off].bank();
    }

    void exclude(const reg_mask_t &mask) {
        if (is_empty()) return;
        block->masks[off].subtract(mask);
    }

    bool operator==(const reg_t &other) const {
        return (other.block == block) && (other.off == off);
    }

    std::string str() const {
        UNUSED(&reg_t::str);
        if (is_empty()) return "null";
        std::ostringstream oss;
        if (block->is_assigned()) {
            int reg = block->masks[off].bsf();
            oss << "r" << reg;
        } else {
            oss << "R" << off;
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

    reg_block_mask_t *block = nullptr;
    int off = -1;
};

// Mask for a blocked register buffer. The buffer consists of blocks B0, B1,
// ... All blocks have the same size and are contiguous inside; B(i) and
// B(i + 1) are not necessarily adjacent to each other.
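// E.g., a 32-register buffer with block_regs = 16 consists of two 16-register
// blocks that may be placed independently of each other.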
struct reg_buf_mask_t {
    reg_buf_mask_t(const hw_context_t *hw_ctx, int regs, int block_regs = 0)
        : hw_ctx(hw_ctx), regs(regs), block_regs(block_regs) {
        if (block_regs == 0) this->block_regs = regs;
        ir_assert(regs % this->block_regs == 0);
        for (int i = 0; i < nblocks(); i++) {
            blocks.emplace_back(hw_ctx, this->block_regs);
        }
    }

    // Size in bytes.
    int size() const { return regs * hw_ctx->reg_size; }

    int nblocks() const { return regs / block_regs; }

    reg_t get_reg(int off_bytes) {
        ir_assert(off_bytes < size());
        off_bytes /= hw_ctx->reg_size;
        int block_idx = off_bytes / block_regs;
        int reg_idx = off_bytes % block_regs;
        return reg_t(&blocks[block_idx], reg_idx);
    }

    const hw_context_t *hw_ctx;
    int regs; // Number of registers in the buffer.
    int block_regs; // Number of registers in one block.
    std::vector<reg_block_mask_t> blocks;
};

// Represents a 3-src instruction.
struct instruction_t {
    instruction_t(const reg_t &src0, const reg_t &src1, const reg_t &src2)
        : src0(src0), src1(src1), src2(src2) {}

    bool has(const reg_t &reg) const {
        return reg == src0 || reg == src1 || reg == src2;
    }

    reg_t src0;
    reg_t src1;
    reg_t src2;
};

// Helper structure for GRF assignment search.
struct search_context_t {
    search_context_t(const hw_context_t *hw_ctx, const reg_mask_t &reg_mask,
            std::vector<reg_block_mask_t *> &blocks,
            const std::vector<instruction_t> &instructions)
        : hw_ctx(hw_ctx)
        , reg_mask(reg_mask)
        , blocks(blocks)
        , instructions(instructions) {
        saved_blocks.resize(nblocks() * nblocks());
    }

    int nblocks() { return int(blocks.size()); }

    void set_check_bundles(bool value = true) { check_bundles = value; }

    void set_check_diff_banks_src02(bool value = true) {
        check_diff_banks_src02 = value;
    }

    // Saves block masks for the current recursion level.
    void save_blocks() {
        ir_assert(saved_block_idx + nblocks() <= int(saved_blocks.size()));
        for (int i = 0; i < nblocks(); i++) {
            saved_blocks[saved_block_idx + i] = *blocks[i];
        }
        saved_block_idx += nblocks();
        steps++;
    }

    // Restores saved block masks.
    void restore_blocks() {
        saved_block_idx -= nblocks();
        ir_assert(saved_block_idx >= 0);
        for (int i = 0; i < nblocks(); i++) {
            *blocks[i] = saved_blocks[saved_block_idx + i];
        }
    }

    bool should_stop() const {
        int max_steps = 250;
        return steps > max_steps;
    }

    void reset_steps() { steps = 0; }

    const hw_context_t *hw_ctx;

    int steps = 0;
    reg_mask_t reg_mask;
    std::vector<reg_block_mask_t *> blocks;
    std::vector<instruction_t> instructions;

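    // Stack of saved block masks: each recursion level of the search pushes
    // nblocks() entries (see save_blocks()/restore_blocks()).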
    int saved_block_idx = 0;
    std::vector<reg_block_mask_t> saved_blocks;

    // Whether to require the bundle check.
    bool check_bundles = false;

    // Whether to require src0 and src2 to be in different banks
    // (dpas-specific).
    bool check_diff_banks_src02 = false;
};

bool search(search_context_t &ctx, int block_idx = 0) {
    // All blocks are assigned, success.
    if (block_idx >= ctx.nblocks()) return true;

    auto *hw_ctx = ctx.hw_ctx;

    auto &block = *ctx.blocks[block_idx];
    auto &mask0 = block.masks[0];

    // 1. Assign the i-th register as the current block's base.
    // 2. Update the register constraints for the other blocks.
    // 3. If the remaining blocks can still be assigned, move on to the next
    //    block. Otherwise try the next register in step 1.
    for (int i = 0; i < hw_ctx->regs; i++) {
        if (!mask0.test(i)) continue;
        if (!ctx.reg_mask.is_unset(i, block.regs)) continue;
        // Stop the search if it takes too many steps.
        if (ctx.should_stop()) return false;

        // Try to assign the current block to the i-th register.
        ctx.save_blocks();

        // Claim the register block.
        ctx.reg_mask.set(i, block.regs, false);
        reg_mask_t i_reg_mask(hw_ctx, 0);
        for (int j = 0; j < block.regs; j++) {
            block.masks[j].reset();
            block.masks[j].set(i + j);
            i_reg_mask.set(i + j);
        }

        bool conflicts_ok = true;

        // Exclude the newly claimed region from the remaining masks.
        for (int j = block_idx + 1; j < ctx.nblocks(); j++) {
            ctx.blocks[j]->exclude(i_reg_mask);
            if (!ctx.blocks[j]->can_be_assigned()) {
                conflicts_ok = false;
                break;
            }
        }

        // Update constraints according to register usage in instructions.
        std::vector<reg_mask_t> bundle_masks;
        bundle_masks.reserve(block.regs);
        for (int j = 0; j < block.regs; j++) {
            int bundle = hw_ctx->reg_bundle(i + j);
            bundle_masks.emplace_back(hw_ctx, hw_ctx->bundle_masks[bundle]);
        }

        for (auto &insn : ctx.instructions) {
            for (int j = 0; j < block.regs; j++) {
                reg_t j_reg(&block, j);
                if (!insn.has(j_reg)) continue;

                int bank0 = insn.src0.bank();
                int bank1 = insn.src1.bank();
                int bank2 = insn.src2.bank();

                if (ctx.check_diff_banks_src02) {
                    if (bank0 != -1 && bank0 == bank2) {
                        conflicts_ok = false;
                        break;
                    }
                }

                // Handle the bank conflict condition.
                if (bank0 != -1 && bank0 == bank1 && bank1 == bank2) {
                    conflicts_ok = false;
                    break;
                }

                // Handle the bundle conflict condition.
                if (ctx.check_bundles) {
                    for (auto *reg : {&insn.src0, &insn.src1, &insn.src2}) {
                        if (*reg == j_reg) continue;
                        reg->exclude(bundle_masks[j]);
                    }
                }
                break;
            }
        }

        if (conflicts_ok) {
            for (auto *b : ctx.blocks) {
                b->propagate_masks();
                if (!b->can_be_assigned()) {
                    conflicts_ok = false;
                    break;
                }
            }
        }

        if (conflicts_ok) {
            bool ok = search(ctx, block_idx + 1);
            if (ok) return true;
        }

        // Release the register block, move to the next candidate.
        ctx.reg_mask.set(i, block.regs, true);
        ctx.restore_blocks();
    }
    return false;
}

reg_mask_t create_available_reg_mask(
        reg_allocator_t &ra, const hw_context_t *hw_ctx) {
    reg_mask_t reg_mask(hw_ctx, 0);
    ra.start_speculate();

    // Query the allocator to get information about free registers.
    for (;;) {
        auto grf = ra.try_alloc();
        if (grf.isInvalid()) break;
        reg_mask.set(grf.getBase());
    }

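    // Release all registers that were allocated during the probing above.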
    for (int i = 0; i < hw_ctx->regs; i++) {
        if (reg_mask.test(i)) {
            ngen::GRF grf(i);
            ra.safeRelease(grf);
        }
    }

    ra.finish_speculate();
    return reg_mask;
}

} // namespace

bank_conflict_allocation_t bank_conflict_allocation_t::create(
        reg_allocator_t &ra, int regs, const bank_conflict_attr_t &attr) {
    hw_context_t hw_ctx(ra.hardware(), regs);

    bool is_dpas = false;
    bool is_dp4a = false;
    expr_t dst_base;
    if (!attr.instructions.empty()) {
        auto &s = attr.instructions[0];
        auto &func = s.as<func_call_t>().func;
        if (func.is<dpas_t>()) {
            is_dp4a = func.as<dpas_t>().is_dp4a();
            is_dpas = !is_dp4a;
            dst_base = get_base(dpas_t::arg_dst(s));
        } else if (func.is<mad_t>()) {
            dst_base = get_base(mad_t::arg_dst(s));
        } else {
            ir_error_not_expected();
        }
    }

    // Heuristics for src/dst block sizes.
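    // (Sizes are in registers; 0 means the buffer is allocated as a single
    // contiguous block.)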
    int dst_block_regs = (is_dpas ? 0 : 2);
    int src_block_regs = (is_dpas || is_dp4a ? 0 : 16);

    std::vector<expr_t> bufs = attr.bufs;
    std::vector<reg_buf_mask_t> buf_masks;
    std::vector<int> buf_src_idx(bufs.size(), -1);
    for (int i = 0; i < int(bufs.size()); i++) {
        int buf_size = attr.buf_sizes[i];
        int min_block_size = attr.buf_min_block_sizes[i];
        int regs = utils::div_up(buf_size, hw_ctx.reg_size);
        int block_regs = (bufs[i].is_equal(dst_base) ? dst_block_regs
                                                     : src_block_regs);
        // Ensure that the blocks are uniform, otherwise allocate the buffer
        // as a single block.
        if (block_regs != 0 && regs % block_regs != 0) block_regs = 0;
        // Ensure that the block size is allowed (to avoid unhandled boundary
        // crossing), otherwise allocate the buffer as a single block.
        if (block_regs != 0 && block_regs * hw_ctx.reg_size < min_block_size)
            block_regs = 0;
        buf_masks.emplace_back(&hw_ctx, regs, block_regs);
    }

    auto create_reg = [&](const expr_t &e, int src_idx, int off_bytes) {
        if (is_zero(e)) return reg_t();
        auto base = get_base(e);
        int off = 0;
        if (!is_var(e)) off = to_cpp<int>(e.as<ptr_t>().off);
        off += off_bytes;
        for (size_t i = 0; i < bufs.size(); i++) {
            if (base.is_same(bufs[i])) {
                buf_src_idx[i] = src_idx;
                return buf_masks[i].get_reg(off);
            }
        }
        ir_error_not_expected();
        return reg_t();
    };

    int hw_simd = hw_ctx.hw_simd();
    std::vector<instruction_t> instructions;
    for (auto &s : attr.instructions) {
        auto &call = s.as<func_call_t>();
        expr_t src0, src1, src2;
        int simd = 0;
        int src0_stride_bytes = 0;
        int src1_stride_bytes = 0;
        int src2_stride_bytes = 0;
        if (call.func.is<dpas_t>()) {
            auto &dpas = call.func.as<dpas_t>();
            simd = dpas.exec_size;
            src0_stride_bytes = dpas.dst_type.size();
            src1_stride_bytes = dpas.src1_type.size();
            src2_stride_bytes = dpas.is_dp4a() ? 0 : dpas.src2_type.size();
            src0 = dpas_t::arg_src0(call);
            src1 = dpas_t::arg_src1(call);
            src2 = dpas_t::arg_src2(call);
            if (!dpas.is_dp4a()) ir_assert(simd == hw_simd);
        } else if (call.func.is<mad_t>()) {
            auto &mad = call.func.as<mad_t>();
            simd = mad.exec_size;
            src0_stride_bytes = mad.dst_type.size();
            src1_stride_bytes = mad.src1_stride * mad.src1_type.size();
            src2_stride_bytes = mad.src2_stride * mad.src2_type.size();
            src0 = mad_t::arg_src0(call);
            src1 = mad_t::arg_src1(call);
            src2 = mad_t::arg_src2(call);
        } else {
            ir_error_not_expected();
        }
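        // Split a wider instruction into hw_simd-sized pieces; each piece
        // references its operands at the corresponding per-source byte
        // offsets.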
        for (int off = 0; off < simd; off += hw_simd) {
            auto _src0 = create_reg(src0, 0, off * src0_stride_bytes);
            auto _src1 = create_reg(src1, 1, off * src1_stride_bytes);
            auto _src2 = create_reg(src2, 2, off * src2_stride_bytes);
            instructions.emplace_back(_src0, _src1, _src2);
        }
    }

    std::vector<reg_block_mask_t *> blocks;

    for (size_t i = 0; i < bufs.size(); i++)
        ir_assert(buf_src_idx[i] != -1)
                << "Buffer is not referenced: " << bufs[i];

    // Heuristic: search for register blocks in this order: src1, src2, src0.
    for (int i : {1, 2, 0}) {
        for (size_t j = 0; j < bufs.size(); j++) {
            if (buf_src_idx[j] == i) {
                for (auto &block : buf_masks[j].blocks)
                    blocks.push_back(&block);
            }
        }
    }

    auto reg_mask = create_available_reg_mask(ra, &hw_ctx);
    search_context_t ctx(&hw_ctx, reg_mask, blocks, instructions);

    if (is_dpas) ctx.set_check_diff_banks_src02();

    bool found = false;

    // First try to find an allocation with the bundle check; if that fails,
    // check only for bank conflicts.
    for (bool check_bundles : {true, false}) {
        // dpas doesn't need the bundle check.
        if (is_dpas && check_bundles) continue;

        ctx.reset_steps();
        ctx.set_check_bundles(check_bundles);

        ir_assert(ctx.saved_block_idx == 0);
        ir_assert(ctx.reg_mask == reg_mask);

#ifdef GEN_CONV_DEBUG
        double search_time = get_msec();
#endif
        found = search(ctx);
#ifdef GEN_CONV_DEBUG
        search_time = get_msec() - search_time;
        ir_trace() << "Bank conflict allocation:" << std::endl;
        ir_trace() << "  Search time: " << search_time << " ms" << std::endl;
        ir_trace() << "  Status: " << (found ? "OK" : "FAIL") << std::endl;
        ir_trace() << "  Steps: " << ctx.steps << std::endl;
        ir_trace() << "  Bundle check: "
                   << ir_utils::to_string(ctx.check_bundles) << std::endl;
#endif
        if (found) break;
    }

    bool was_claimed = false;
    if (!found) {
        // No conflict-free allocation was found; fall back to a simpler
        // scheme: place src0 and src2 in different banks.
        int bank = -1;
        for (size_t i = 0; i < bufs.size(); i++) {
            bool is_src02 = utils::one_of(buf_src_idx[i], 0, 2);
            int regs = buf_masks[i].regs;
            // Always use a single-block buffer.
            buf_masks[i] = reg_buf_mask_t(&hw_ctx, regs);
            ngen::Bundle bundle;
            // Choose the opposite bank for src0 or src2.
            if (is_src02 && bank != -1)
                bundle = ngen::Bundle(1 - bank, ngen::Bundle::any);
            auto &mask = buf_masks[i].blocks[0].masks[0];
            auto range = ra.alloc_range(regs, bundle);
            int base = range[0].getBase();
            if (is_src02 && bank == -1) bank = hw_ctx.reg_bank(base);
            mask.reset();
            mask.set(base);
        }
        was_claimed = true;
    }

    // Initialize register buffers with the found assignment.
    bank_conflict_allocation_t bca(ra);
    for (size_t i = 0; i < bufs.size(); i++) {
        int nblocks = buf_masks[i].nblocks();
        std::vector<int> block_bases(nblocks);
        for (int j = 0; j < nblocks; j++) {
            int reg = buf_masks[i].blocks[j].masks[0].bsf();
            block_bases[j] = reg;
        }
        reg_buf_t reg_buf(ra.hardware(), buf_masks[i].block_regs, block_bases);
        if (!was_claimed) reg_buf.claim(ra);
        bca.set_reg_buf(bufs[i], reg_buf);
    }
    return bca;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#if defined(__GNUC__) && __GNUC__ == 7
#pragma GCC diagnostic pop
#endif