/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/conv/grf_usage.hpp"

#include <sstream>

#include "gpu/jit/codegen/register_allocator.hpp"
#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/ir/reorder.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

std::string to_string(grf_usage_label_t label) {
    switch (label) {
#define CASE(l) \
    case grf_usage_label_t::l: return #l;
        CASE(unknown)
        CASE(gmem_load)
        CASE(out_buf)
        CASE(reorder)
        CASE(reserved)
        CASE(reused_headers)
        CASE(slm_load)
        CASE(slm_store)
        CASE(tmp_vars)
        CASE(zero_points)
#undef CASE
        default: ir_error_not_expected();
    }
    return "";
}

std::ostream &operator<<(std::ostream &out, grf_usage_label_t label) {
    out << to_string(label);
    return out;
}

std::string grf_buf_usage_t::str() const {
    std::ostringstream oss;
    oss << "Buffers:";
    for (auto label : all_grf_usage_labels()) {
        int regs = total_regs(label);
        if (regs == 0) continue;
        oss << std::endl << " " << label << " (" << regs << "): ";
        bool is_first = true;
        for (auto &buf : sorted_bufs()) {
            if (get_label(buf) != label) continue;
            if (!is_first) oss << ", ";
            is_first = false;
            oss << buf << "[" << get_size(buf) << "]";
        }
    }
    return oss.str();
}

std::string grf_usage_t::str() const {
    std::vector<std::string> headers = {"Label", "Regs"};
    ir_utils::table_t table("GRF usage (registers):", headers);
    int total = 0;
    for (auto label : all_grf_usage_labels()) {
        int regs = regs_.at(label);
        if (regs == 0) continue;
        table << " " + to_string(label) << regs << std::endl;
        total += regs;
    }
    table << " Total" << total << std::endl;
    std::ostringstream oss;
    oss << table << std::endl;
    oss << buf_usage_;
    return oss.str();
}

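// Helper class to estimate register consumption of memory accesses: how many
// payload and header registers are needed to access a given number of
// elements with a given memory layout (global memory or SLM).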
class access_grf_usage_helper_t {
public:
    access_grf_usage_helper_t(const layout_t &mem_layout, int elems,
            int reg_bytes, bool is_slm, bool use_2d_send)
        : mem_type_size_(mem_layout.type().size())
        , reg_bytes_(reg_bytes)
        , is_slm_(is_slm)
        , use_2d_send_(use_2d_send) {
        init_message_size(mem_layout);
        init_payload_size(elems);
        init_header_size();
    }

    // This setting is related to dpasw loads. dpasw reuses registers between
    // fused threads, so each of the fused threads needs to load only half of
    // the data it will access.
    void enable_fused_eus_sharing() { enabled_fused_eus_sharing_ = true; }

    int payload_regs() const {
        int ret = payload_size_ / reg_bytes_;
        if (enabled_fused_eus_sharing_) ret = utils::div_up(ret, 2);
        return ret;
    }

    int header_regs_per_msg() const {
        return header_size_per_msg_ / reg_bytes_;
    }

    int header_regs() const {
        int ret = nmsgs_ * header_regs_per_msg();
        if (enabled_fused_eus_sharing_) ret = utils::div_up(ret, 2);
        return ret;
    }

private:
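    // Determines the message kind (block vs scattered) and the message size
    // based on the innermost blocking of the memory layout, and how many
    // payload bytes each element occupies.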
    void init_message_size(const layout_t &mem_layout) {
        auto l = mem_layout.innermost_block_layout();
        int block_bytes = (is_slm_ ? oword_bytes_ : hword_bytes_);
        int max_block_bytes = (is_slm_ ? 16 * oword_bytes_ : 8 * hword_bytes_);
        block_t b0;
        int b0_size = mem_type_size_;
        auto &mem_blocks = mem_layout.blocks();
        if (!mem_blocks.empty()) {
            b0 = mem_blocks[0];
            b0_size = b0.block * mem_type_size_;
        }
        if (use_2d_send_) {
            is_block_ = true;
            // It's hard to determine the 2D block message decomposition at
            // this point, but in general 2D block messages are larger, so use
            // 2x of a regular block message (empirical estimate).
            msg_size_ = 2 * max_block_bytes;
            payload_bytes_per_elem_ = mem_type_size_;
        } else if (l.size() % block_bytes == 0) {
            is_block_ = true;
            msg_size_ = (l.size() % max_block_bytes == 0) ? max_block_bytes
                                                          : block_bytes;
            payload_bytes_per_elem_ = mem_type_size_;
        } else if (!b0.is_empty() && b0_size % block_bytes == 0) {
            is_block_ = true;
            msg_size_ = block_bytes;
            payload_bytes_per_elem_ = mem_type_size_;
        } else {
            ir_assert(!is_slm_) << "Unexpected scattered messages with SLM.";
            // Assume a scattered byte SIMD16 load as the worst case. Check
            // whether byte x {1,2,4} messages can be used.
            int slots = 16;
            int bytes_per_slot = 4;
            for (int x : {4, 2, 1}) {
                if (x < bytes_per_slot && mem_type_size_ != x) continue;
                if (b0_size % x == 0) {
                    msg_size_ = slots * x;
                    payload_bytes_per_elem_
                            = mem_type_size_ * (bytes_per_slot / x);
                    break;
                }
            }
            ir_assert(msg_size_ > 0);
        }
    }

    void init_payload_size(int elems) {
        int elems_per_msg = utils::div_up(msg_size_, mem_type_size_);
        int payload_per_msg = elems_per_msg * payload_bytes_per_elem_;
        int payload_per_msg_grf_aligned
                = utils::rnd_up(payload_per_msg, reg_bytes_);
        nmsgs_ = utils::div_up(elems * mem_type_size_, msg_size_);
        payload_size_ = nmsgs_ * payload_per_msg_grf_aligned;
    }

    void init_header_size() {
        if (is_block_) {
            // One register per header for block messages.
            header_size_per_msg_ = reg_bytes_;
        } else {
            // Assume SIMD16 with A64 address model.
            int slots = 16;
            int bytes_per_slot = sizeof(uint64_t);
            header_size_per_msg_
                    = utils::rnd_up(slots * bytes_per_slot, reg_bytes_);
        }
    }

    static const int oword_bytes_ = 16;
    static const int hword_bytes_ = 32;

    int mem_type_size_ = 0;
    int reg_bytes_ = 0;
    bool is_slm_ = false;
    bool use_2d_send_ = false;
    bool enabled_fused_eus_sharing_ = false;

    // Whether message is block or scattered.
    bool is_block_ = false;

    // Amount of memory that can be read by a single message from global memory.
    int msg_size_ = 0;

    // How many bytes are occupied by a single element in the message payload.
    int payload_bytes_per_elem_ = 0;

    // Size of GRF buffers for all messages to load data.
    int payload_size_ = 0;

    // Number of messages to load data.
    int nmsgs_ = 0;

    // Size of header buffer per message.
    int header_size_per_msg_ = 0;
};

// Helper class to provide GRF usage estimation.
class grf_usage_helper_t {
public:
    grf_usage_helper_t(const conv_config_t &cfg) : prb_(cfg.prb()), cfg_(cfg) {
        auto &tg_grid = cfg_.thread_group_grid();

        reg_bytes_ = cfg_.grf_size();
        tg_size_ = tg_grid.elems();

        bmnk_dim_helper_t h(cfg_);
        m_tg_dim_ = h.thread_group_dim('m');
        n_tg_dim_ = h.thread_group_dim('n');

        int b_iter_blk = h.iter_dim('b');
        int m_iter_blk = h.iter_dim('m');
        int n_iter_blk = h.iter_dim('n');
        int k_iter_blk = h.iter_dim('k');

        if (!cfg_.ow_kw_grf_cache()) {
            a_thr_elems_ = b_iter_blk * m_iter_blk * k_iter_blk;
        } else {
            ir_assert(!cfg_.slm().a());
            int a_m_blk = (prb_.sw * (m_iter_blk - 1)
                    + (prb_.kw - 1) * (1 + prb_.dw) + 1);
            int a_k_blk = utils::div_up(k_iter_blk, prb_.kw);
            a_thr_elems_ = b_iter_blk * a_m_blk * a_k_blk;
        }

        b_thr_elems_ = b_iter_blk * n_iter_blk * k_iter_blk;
        c_thr_elems_ = b_iter_blk * m_iter_blk * n_iter_blk;
        a_tg_elems_ = a_thr_elems_ * m_tg_dim_;
        b_tg_elems_ = b_thr_elems_ * n_tg_dim_;
        a_subtile_elems_ = utils::div_up(a_thr_elems_, cfg_.subtiles().a());
        b_subtile_elems_ = utils::div_up(b_thr_elems_, cfg_.subtiles().b());
        can_reliably_use_dpasw_ = can_reliably_use_dpasw(h);
    }

    grf_usage_t estimate() const {
        int max_reuse_header_regs = 0;
        int a_slm_store_payload_regs = 0;
        int b_slm_store_payload_regs = 0;

        int c_buf_usage = estimate_c_buf_usage();
        int gmem_load_usage = estimate_gmem_load_usage(max_reuse_header_regs);
        int slm_store_usage = estimate_slm_store_usage(a_slm_store_payload_regs,
                b_slm_store_payload_regs, max_reuse_header_regs);
        int slm_load_usage = estimate_slm_load_usage(max_reuse_header_regs);
        int reorder_usage = estimate_reorder_usage(
                a_slm_store_payload_regs, b_slm_store_payload_regs);
        int zp_usage = estimate_zero_point_usage();

        grf_usage_t info(cfg_.grf_size());
        info.add(grf_usage_label_t::out_buf, c_buf_usage);
        info.add(grf_usage_label_t::gmem_load, gmem_load_usage);
        info.add(grf_usage_label_t::slm_store, slm_store_usage);
        info.add(grf_usage_label_t::slm_load, slm_load_usage);
        info.add(grf_usage_label_t::reorder, reorder_usage);
        info.add(grf_usage_label_t::reused_headers, max_reuse_header_regs);
        info.add(grf_usage_label_t::reserved, constants::reserved_regs);
        info.add(grf_usage_label_t::zero_points, zp_usage);
        return info;
    }

private:
    int estimate_c_buf_usage() const {
        int c_bytes = c_thr_elems_ * prb_.acc_data_type_size;
        return utils::div_up(c_bytes, reg_bytes_);
    }

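    // Estimates registers consumed by global memory loads and prefetches:
    // payload buffers plus per-message headers. With header reuse enabled,
    // only the maximum per-message header size is tracked via
    // max_reuse_header_regs.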
    int estimate_gmem_load_usage(int &max_reuse_header_regs) const {
        int regs = 0;
        bool use_a_2d_send = can_use_a_2d_send(cfg_);
        bool use_b_2d_send = can_use_b_2d_send(cfg_);
        for (bool is_a : {true, false}) {
            bool use_slm = ab_use_slm(is_a);
            int per_thr_elems = utils::div_up(ab_tg_elems(is_a), tg_size_);
            int load_elems = (use_slm ? per_thr_elems : ab_subtile_elems(is_a));
            auto layout = get_gmem_layout(is_a);
            bool use_2d_send = (is_a ? use_a_2d_send : use_b_2d_send);
            access_grf_usage_helper_t load(layout, load_elems, reg_bytes_,
                    /*is_slm=*/false, use_2d_send);
            if (is_a && !use_slm && can_reliably_use_dpasw_)
                load.enable_fused_eus_sharing();
            int mult = (use_slm ? cfg_.slm().gmem_bufs() : 1);
            regs += mult * load.payload_regs();
            if (cfg_.pipeline().reuse_headers()) {
                max_reuse_header_regs = std::max(
                        max_reuse_header_regs, load.header_regs_per_msg());
            } else {
                int subtiles
                        = (is_a ? cfg_.subtiles().a() : cfg_.subtiles().b());
                int mult = (use_slm ? 1 : subtiles);
                regs += mult * load.header_regs();
                if (cfg_.prefetch()) {
                    access_grf_usage_helper_t prefetch(layout, per_thr_elems,
                            reg_bytes_, /*is_slm=*/false, use_2d_send);
                    regs += prefetch.header_regs();
                }
            }
        }
        return regs;
    }

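    // Estimates header registers consumed by A/B SLM stores. Payload register
    // counts are returned through a_payload_regs/b_payload_regs so that the
    // reorder estimation can reuse them.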
    int estimate_slm_store_usage(int &a_payload_regs, int &b_payload_regs,
            int &max_reuse_header_regs) const {
        int regs = 0;
        for (bool is_a : {true, false}) {
            if (!ab_use_slm(is_a)) continue;

            int per_thr_elems = utils::div_up(ab_tg_elems(is_a), tg_size_);
            int bytes = per_thr_elems * ab_type_size(is_a);
            auto slm_layout = dummy_slm_layout(bytes);
            access_grf_usage_helper_t store(slm_layout, bytes, reg_bytes_,
                    /*is_slm=*/true, /*use_2d_send=*/false);
            int &payload_regs = (is_a ? a_payload_regs : b_payload_regs);
            payload_regs = store.payload_regs();
            if (cfg_.pipeline().reuse_headers()) {
                max_reuse_header_regs = std::max(
                        max_reuse_header_regs, store.header_regs_per_msg());
            } else {
                regs += store.header_regs();
            }
        }
        return regs;
    }

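    // Estimates registers consumed by A/B SLM loads: payload buffers for the
    // per-thread subtiles plus headers (unless headers are reused).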
    int estimate_slm_load_usage(int &max_reuse_header_regs) const {
        int regs = 0;
        for (bool is_a : {true, false}) {
            if (!ab_use_slm(is_a)) continue;

            int bytes = ab_subtile_elems(is_a) * ab_type_size(is_a);
            auto slm_layout = dummy_slm_layout(bytes);
            access_grf_usage_helper_t load(slm_layout, bytes, reg_bytes_,
                    /*is_slm=*/true, /*use_2d_send=*/false);
            if (is_a && can_reliably_use_dpasw_)
                load.enable_fused_eus_sharing();
            regs += load.payload_regs();
            if (cfg_.pipeline().reuse_headers()) {
                max_reuse_header_regs = std::max(
                        max_reuse_header_regs, load.header_regs_per_msg());
            } else {
                regs += load.header_regs();
            }
        }

        return regs;
    }

    // Extra registers for GRF <-> GRF reorders.
    // Estimates an upper bound for A/B reorders to temporary buffers.
    int estimate_reorder_usage(int a_payload_regs, int b_payload_regs) const {
        if (!cfg_.allow_a_grf_reorder() && !cfg_.allow_b_grf_reorder())
            return 0;

        int regs = 0;
        if (prb_.is_bwd_w) {
            // Hardcode the size of the temporary reorder buffer for BWD_W to
            // avoid suboptimal performance.
            int bwd_w_reorder_regs = 16;
            regs += bwd_w_reorder_regs;
        }

        for (bool is_a : {true, false}) {
            bool allow_grf_reorder = (is_a ? cfg_.allow_a_grf_reorder()
                                           : cfg_.allow_b_grf_reorder());
            if (!allow_grf_reorder) continue;
            int reorder_regs = 0;
            if (ab_use_slm(is_a)) {
                int &payload_regs = (is_a ? a_payload_regs : b_payload_regs);
                reorder_regs = payload_regs;
            } else {
                int size = ab_subtile_elems(is_a) * ab_type_size(is_a);
                reorder_regs = utils::div_up(size, reg_bytes_);
            }
            regs += reorder_regs;
        }

        return regs;
    }

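    // Rough estimate of registers used for source zero-point compensation:
    // mask buffers sized by the spatial iteration dimension, per-subtile
    // compensation buffers and headers, and a few temporary variables, using
    // hardcoded buffer sizes.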
    int estimate_zero_point_usage() const {
        if (!prb_.zp_cfg.do_src_compensation) return 0;
        int sp_iter_dim = 1;
        for (auto *name : {"ow", "iw", "osp"}) {
            sp_iter_dim *= cfg_.iter_dim(name);
        }
        int subtiles = cfg_.subtiles().a() * cfg_.subtiles().b();
        int zp_mask0_regs = 2
                * utils::div_up(
                        sp_iter_dim * (int)sizeof(uint32_t), reg_bytes_);
        int zp_mask1_regs = subtiles
                * utils::div_up(
                        sp_iter_dim * (int)sizeof(uint16_t), reg_bytes_);
        int zp_buf_regs = subtiles * utils::div_up(128, reg_bytes_);
        int zp_header_regs = subtiles;
        int zp_let_regs = 4;
        return zp_mask0_regs + zp_mask1_regs + zp_buf_regs + zp_header_regs
                + zp_let_regs;
    }

    layout_t get_gmem_layout(bool is_a) const {
        auto layout = (is_a ? cfg_.a_layout() : cfg_.b_layout()).compute();
        bool is_src_dst = is_a || prb_.is_bwd_w;
        if (is_src_dst && prb_.is_dw) {
            auto &blocks = layout.blocks();
            if (!blocks.empty()) {
                auto &b0 = blocks[0];
                std::vector<block_t> new_blocks(
                        blocks.begin() + 1, blocks.end());
                // Remove the innermost block of channels for depthwise
                // convolution.
                if (b0.dim_idx == 2 && b0.block == 1) {
                    layout = layout_t(layout.type(), layout.ndims(),
                            layout.offset(), new_blocks,
                            /*do_normalize=*/false);
                }
            }
        }
        return layout;
    }

    int ab_type_size(bool is_a) const {
        auto ret = is_a ? prb_.a_data_type_size : prb_.b_data_type_size;
        if (prb_.is_s32_accumulator() && cfg_.fma_kind() == fma_kind_t::mad) {
            // s8/u8 is converted to dword-strided word for mad.
            ir_assert(ret == 1);
            ret = 4;
        }
        return ret;
    }

    int ab_tg_elems(bool is_a) const {
        return is_a ? a_tg_elems_ : b_tg_elems_;
    }

    int ab_thr_elems(bool is_a) const {
        return is_a ? a_thr_elems_ : b_thr_elems_;
    }

    int ab_subtile_elems(bool is_a) const {
        return is_a ? a_subtile_elems_ : b_subtile_elems_;
    }

    int ab_use_slm(bool is_a) const {
        return is_a ? cfg_.slm().a() : cfg_.slm().b();
    }

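    // Checks whether dpasw register sharing between fused threads can be
    // relied upon for A loads: requires the dpasw FMA kind, A staged through
    // SLM, a sufficiently large innermost A block, and an even number of
    // oword x16 messages per thread so the loads can be split evenly between
    // the fused threads.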
    bool can_reliably_use_dpasw(const bmnk_dim_helper_t &h) {
        if (cfg_.fma_kind() != fma_kind_t::dpasw) return false;
        if (!cfg_.slm().a()) return false;
        int m_tg_bytes = h.thread_group_dim('m') * h.iter_dim('m')
                * prb_.a_data_type_size;
        int m_thr_bytes
                = ir_utils::safe_divide(m_tg_bytes, h.thread_group_dim('m'));
        int owordx16_size = 256;
        if (cfg_.a_layout().compute().innermost_block_layout().size()
                < owordx16_size)
            return false;
        int k_iter_blk = h.iter_dim('k');
        if (m_thr_bytes * k_iter_blk % owordx16_size != 0) return false;
        int nmsgs = m_thr_bytes * k_iter_blk / owordx16_size;
        if (nmsgs % 2 != 0) return false;
        return true;
    }

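    // Builds a synthetic byte layout of the requested size with a 16-byte
    // innermost block to model SLM access as oword block messages.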
    layout_t dummy_slm_layout(int size) const {
        int inner_block = 16; // In bytes.
        int outer_block = utils::div_up(size, inner_block);
        std::vector<block_t> blocks;
        blocks.emplace_back(0, inner_block, 1);
        blocks.emplace_back(1, outer_block, inner_block);
        blocks.emplace_back(0, 1, size);
        blocks.emplace_back(1, 1, size);
        return layout_t(type_t::byte(), 2, 0, blocks, /*do_normalize=*/false);
    }

    const conv_problem_t &prb_;
    const conv_config_t &cfg_;

    int reg_bytes_;
    int tg_size_;
    int m_tg_dim_;
    int n_tg_dim_;
    int a_tg_elems_;
    int b_tg_elems_;
    int a_thr_elems_;
    int b_thr_elems_;
    int c_thr_elems_;
    int a_subtile_elems_;
    int b_subtile_elems_;
    bool can_reliably_use_dpasw_;
};

grf_usage_t estimate_grf_usage(const conv_config_t &cfg) {
    grf_usage_helper_t helper(cfg);
    return helper.estimate();
}

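// IR visitor that walks the kernel body, attributes GRF buffers and temporary
// variables to usage labels and reports the actual (IR-based) GRF usage to
// compare against the estimation above.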
class ir_usage_analyzer_t : public ir_visitor_t {
public:
    ir_usage_analyzer_t(int grf_size)
        : grf_size_(grf_size), buf_usage_(grf_size) {}

    void analyze(const stmt_t &stmt, bool allow_errors = false) {
        visit(stmt);
        if (!is_invalid_) {
            if (peak_headers_ <= 1) {
                std::vector<expr_t> header_bufs;
                for (auto &buf : buf_usage_.bufs()) {
                    if (is_header(buf)) header_bufs.push_back(buf);
                }
                for (auto &buf : header_bufs) {
                    buf_usage_.remove(buf);
                }
            }
        }
        if (!verify(allow_errors)) is_invalid_ = true;
    }

    void _visit(const alloc_t &obj) override {
        if (is_invalid_) return;
        int size = (obj.kind == alloc_kind_t::grf ? obj.size : 0);
        size = utils::rnd_up(size, grf_size_);
        mem_usage_guard_t alloc_guard(&alloc_usage_, &peak_alloc_usage_, size);
        mem_usage_guard_t guard(&grf_usage_, &peak_grf_usage_, size);
        if (size > 0) {
            buf_usage_.add(obj.buf, obj.size, grf_usage_label_t::unknown);
            mark_known_bufs(obj.buf);
        }
        mem_usage_guard_t header_guard;
        if (is_header(obj.buf))
            header_guard = mem_usage_guard_t(&headers_, &peak_headers_, 1);
        ir_visitor_t::_visit(obj);
    }

    void _visit(const func_call_t &obj) override {
        if (is_invalid_) return;
        auto &func = obj.func;
        if (auto *reorder = func.as_ptr<reorder_t>()) {
            auto &src = get_base(reorder_t::arg_src_buf(obj));
            auto &dst = get_base(reorder_t::arg_dst_buf(obj));
            mark_bufs(*reorder, src, dst);
        } else if (auto *send = func.as_ptr<send_t>()) {
            if (!send_t::arg_header_buf(obj).is<var_t>()) {
                is_invalid_ = true;
                return;
            }
            auto &buf = get_base(send_t::arg_reg_buf(obj));
            auto &header = get_base(send_t::arg_header_buf(obj));
            mark_bufs(*send, buf, header);
        } else if (is_func_call<dpas_t>(obj)) {
            auto &dst = get_base(dpas_t::arg_dst(obj));
            auto &src1 = get_base(dpas_t::arg_src1(obj));
            auto &src2 = get_base(dpas_t::arg_src2(obj));
            mark_fma_bufs(dst, src1, src2);
        } else if (is_func_call<mad_t>(obj)) {
            auto &dst = get_base(mad_t::arg_dst(obj));
            auto &src1 = get_base(mad_t::arg_src1(obj));
            auto &src2 = get_base(mad_t::arg_src2(obj));
            mark_fma_bufs(dst, src1, src2);
        }
    }

    void _visit(const let_t &obj) override {
        if (is_invalid_) return;
        int size = (obj.value.is_empty() ? 0 : obj.var.type().size());
        size = utils::rnd_up(size, reg_allocator_t::granularity);
        mem_usage_guard_t guard(&grf_usage_, &peak_grf_usage_, size);
        ir_visitor_t::_visit(obj);
    }

    void _visit(const stmt_group_t &obj) override {
        if (is_invalid_) return;
        if (obj.label == stmt_label_t::c_store()) {
            // Do not analyze C store consumption for simplicity. Assume there
            // is enough space for it once A/B and the other buffers are
            // released after the main loop.
            return;
        }
        ir_visitor_t::_visit(obj);
    }

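    // A store whose destination buffer is also read by a load inside the same
    // statement is treated as a temporary variable buffer.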
    void _visit(const store_t &obj) override {
        if (is_invalid_) return;
        auto loads = find_objects<load_t>(obj);
        for (auto &l : loads) {
            if (obj.buf.is_same(l.as<load_t>().buf)) {
                set_label(obj.buf, grf_usage_label_t::tmp_vars);
                break;
            }
        }
        ir_visitor_t::_visit(obj);
    }

    grf_usage_t get_grf_usage(int external_regs) const {
        if (is_invalid_) return grf_usage_t();
        grf_usage_t info(grf_size_);
        info.add(buf_usage_);
        info.add(grf_usage_label_t::reserved, external_regs);
        info.add(grf_usage_label_t::tmp_vars,
                utils::div_up(peak_grf_usage_ - peak_alloc_usage_, grf_size_));
        return info;
    }

private:
    bool verify(bool allow_errors) const {
        if (is_invalid_) {
            if (!allow_errors)
                ir_error_not_expected() << "Can't collect GRF usage.";
            return false;
        }
        for (auto &buf : buf_usage_.bufs()) {
            if (buf_usage_.get_label(buf) != grf_usage_label_t::unknown)
                continue;
            if (!allow_errors)
                ir_error_not_expected() << "Buffer doesn't have label: " << buf;
            return false;
        }
        return true;
    }

    bool is_buffer(const expr_t &buf) const { return buf_usage_.has(buf); }

    bool is_header(const expr_t &buf) const {
        if (!is_buffer(buf)) return false;
        auto &name = buf.as<var_t>().name;
        return name.find("h_") == 0;
    }

    bool should_skip_if_set(const expr_t &buf, grf_usage_label_t label) const {
        if (is_known_buf(buf)) return true;
        switch (label) {
            case grf_usage_label_t::tmp_vars:
            case grf_usage_label_t::slm_store: return true;
            default: return false;
        }
    }

    void set_label(const expr_t &buf, grf_usage_label_t label) {
        if (is_invalid_) return;
        bool skip_if_set = should_skip_if_set(buf, label);
        auto buf_label = buf_usage_.get_label(buf);
        if (utils::one_of(buf_label, grf_usage_label_t::unknown, label)) {
            buf_usage_.set_label(buf, label);
        } else {
            if (skip_if_set) return;
            ir_error_not_expected()
                    << "Label already set. Buffer: " << buf
                    << ", old label: " << buf_label << ", new label: " << label;
        }
    }

    void mark_known_bufs(const expr_t &buf) {
        if (is_invalid_) return;
        ir_assert(is_buffer(buf));
        auto &name = buf.as<var_t>().name;
        if (name == "b_reduced") {
            set_label(buf, grf_usage_label_t::out_buf);
        } else if (name.find("zp_") == 0) {
            set_label(buf, grf_usage_label_t::zero_points);
        }
    }

    bool is_known_buf(const expr_t &buf) const {
        ir_assert(is_buffer(buf));
        auto &name = buf.as<var_t>().name;
        if (name.find("zp_") == 0) return true;
        if (name == "b_reduced") return true;
        return false;
    }

    void mark_bufs(
            const reorder_t &reorder, const expr_t &src, const expr_t &dst) {
        if (is_invalid_) return;
        ir_assert(is_buffer(src));
        ir_assert(is_buffer(dst));
        set_label(dst, grf_usage_label_t::reorder);
    }

    void mark_bufs(
            const send_t &send, const expr_t &buf, const expr_t &header) {
        if (is_invalid_) return;
        if (!buf.is_empty()) ir_assert(is_buffer(buf));
        ir_assert(is_buffer(header));
        ir_assert(is_header(header));
        grf_usage_label_t label = grf_usage_label_t::unknown;
        if (buf.is_empty()) {
            label = grf_usage_label_t::gmem_load;
        } else if (buf_usage_.get_label(buf)
                == grf_usage_label_t::zero_points) {
            label = grf_usage_label_t::zero_points;
        } else if (send.is_slm()) {
            label = (send.is_load() ? grf_usage_label_t::slm_load
                                    : grf_usage_label_t::slm_store);
        } else {
            if (!send.is_load() && !send.is_load_2d()) {
                is_invalid_ = true;
                return;
            }
            label = grf_usage_label_t::gmem_load;
        }
        if (!buf.is_empty()) set_label(buf, label);
        set_label(header, label);
    }

    void mark_fma_bufs(
            const expr_t &dst, const expr_t &src1, const expr_t &src2) {
        if (is_invalid_) return;
        ir_assert(is_buffer(dst));
        ir_assert(is_buffer(src1));
        ir_assert(is_buffer(src2));
        set_label(dst, grf_usage_label_t::out_buf);
    }

    int grf_size_;
    bool is_invalid_ = false;
    grf_buf_usage_t buf_usage_;

    int grf_usage_ = 0;
    int alloc_usage_ = 0;

    int peak_grf_usage_ = 0;
    int peak_alloc_usage_ = 0;

    int headers_ = 0;
    int peak_headers_ = 0;
};

grf_usage_t get_grf_usage(const stmt_t &body, int grf_size) {
    ir_usage_analyzer_t analyzer(grf_size);
    analyzer.visit(body);
    return analyzer.get_grf_usage(0);
}

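// Prints a per-label comparison of estimated vs IR-based GRF usage, marking
// entries where the IR usage exceeds the estimate with FAIL.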
void compare(const grf_usage_t &est_usage, const grf_usage_t &ir_usage,
        const ir_usage_analyzer_t &analyzer) {
    std::vector<std::string> headers
            = {"Label", "Estimated regs", "IR regs", "Status"};
    ir_utils::table_t table("Compare GRF usage:", headers);
    int est_total = 0;
    int ir_total = 0;
    for (auto label : all_grf_usage_labels()) {
        int est_regs = est_usage.get(label);
        int ir_regs = ir_usage.get(label);
        table << " " + to_string(label) << est_regs << ir_regs;
        table << (ir_regs > est_regs ? "FAIL" : "");
        table << std::endl;
        est_total += est_regs;
        ir_total += ir_regs;
    }
    table << " Total" << est_total << ir_total;
    table << (ir_total > est_total ? "FAIL" : "");
    table << std::endl;
    ir_trace() << table << std::endl;
    ir_trace() << ir_usage.buf_usage() << std::endl;
}

void verify_grf_usage(
        const conv_config_t &cfg, const stmt_t &body, int external_usage) {
    ir_usage_analyzer_t analyzer(cfg.grf_size());
    analyzer.analyze(body);

    auto ir_info = analyzer.get_grf_usage(external_usage);
    auto est_info = estimate_grf_usage(cfg);
    compare(est_info, ir_info, analyzer);
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl