1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/conv/grf_usage.hpp" |
18 | |
19 | #include <sstream> |
20 | |
21 | #include "gpu/jit/codegen/register_allocator.hpp" |
22 | #include "gpu/jit/ir/message.hpp" |
23 | #include "gpu/jit/ir/reorder.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace gpu { |
28 | namespace jit { |
29 | |
30 | std::string to_string(grf_usage_label_t label) { |
31 | switch (label) { |
32 | #define CASE(l) \ |
33 | case grf_usage_label_t::l: return #l; |
34 | CASE(unknown) |
35 | CASE(gmem_load) |
36 | CASE(out_buf) |
37 | CASE(reorder) |
38 | CASE(reserved) |
39 | CASE(reused_headers) |
40 | CASE(slm_load) |
41 | CASE(slm_store) |
42 | CASE(tmp_vars) |
43 | CASE(zero_points) |
44 | #undef CASE |
45 | default: ir_error_not_expected(); |
46 | } |
47 | return "" ; |
48 | } |
49 | |
50 | std::ostream &operator<<(std::ostream &out, grf_usage_label_t label) { |
51 | out << to_string(label); |
52 | return out; |
53 | } |
54 | |
55 | std::string grf_buf_usage_t::str() const { |
56 | std::ostringstream oss; |
57 | oss << "Buffers:" ; |
58 | for (auto label : all_grf_usage_labels()) { |
59 | int regs = total_regs(label); |
60 | if (regs == 0) continue; |
61 | oss << std::endl << " " << label << " (" << regs << "): " ; |
62 | bool is_first = true; |
63 | for (auto &buf : sorted_bufs()) { |
64 | if (get_label(buf) != label) continue; |
65 | if (!is_first) oss << ", " ; |
66 | is_first = false; |
67 | oss << buf << "[" << get_size(buf) << "]" ; |
68 | } |
69 | } |
70 | return oss.str(); |
71 | } |
72 | |
73 | std::string grf_usage_t::str() const { |
74 | std::vector<std::string> headers = {"Label" , "Regs" }; |
75 | ir_utils::table_t table("GRF usage (registers):" , headers); |
76 | int total = 0; |
77 | for (auto label : all_grf_usage_labels()) { |
78 | int regs = regs_.at(label); |
79 | if (regs == 0) continue; |
80 | table << " " + to_string(label) << regs << std::endl; |
81 | total += regs; |
82 | } |
83 | table << " Total" << total << std::endl; |
84 | std::ostringstream oss; |
85 | oss << table << std::endl; |
86 | oss << buf_usage_; |
87 | return oss.str(); |
88 | } |
89 | |
90 | class access_grf_usage_helper_t { |
91 | public: |
92 | access_grf_usage_helper_t(const layout_t &mem_layout, int elems, |
93 | int reg_bytes, bool is_slm, bool use_2d_send) |
94 | : mem_type_size_(mem_layout.type().size()) |
95 | , reg_bytes_(reg_bytes) |
96 | , is_slm_(is_slm) |
97 | , use_2d_send_(use_2d_send) { |
98 | init_message_size(mem_layout); |
99 | init_payload_size(elems); |
100 | init_header_size(); |
101 | } |
102 | |
103 | // This setting is related to dpasw loads. dpasw reuses registers between |
104 | // fused threads so each of the fused threads need to load only half of the |
105 | // data it will access. |
106 | void enable_fused_eus_sharing() { enabled_fused_eus_sharing_ = true; } |
107 | |
108 | int payload_regs() const { |
109 | int ret = payload_size_ / reg_bytes_; |
110 | if (enabled_fused_eus_sharing_) ret = utils::div_up(ret, 2); |
111 | return ret; |
112 | } |
113 | |
114 | int () const { |
115 | return header_size_per_msg_ / reg_bytes_; |
116 | } |
117 | |
118 | int () const { |
119 | int ret = nmsgs_ * header_regs_per_msg(); |
120 | if (enabled_fused_eus_sharing_) ret = utils::div_up(ret, 2); |
121 | return ret; |
122 | } |
123 | |
124 | private: |
125 | void init_message_size(const layout_t &mem_layout) { |
126 | auto l = mem_layout.innermost_block_layout(); |
127 | int block_bytes = (is_slm_ ? oword_bytes_ : hword_bytes_); |
128 | int max_block_bytes = (is_slm_ ? 16 * oword_bytes_ : 8 * hword_bytes_); |
129 | block_t b0; |
130 | int b0_size = mem_type_size_; |
131 | auto &mem_blocks = mem_layout.blocks(); |
132 | if (!mem_blocks.empty()) { |
133 | b0 = mem_blocks[0]; |
134 | b0_size = b0.block * mem_type_size_; |
135 | } |
136 | if (use_2d_send_) { |
137 | is_block_ = true; |
138 | // It's hard to determine 2D block message decomposition at this |
139 | // point but in general 2D block messages are larger so use 2x of a |
140 | // regular block message (empirical estimate). |
141 | msg_size_ = 2 * max_block_bytes; |
142 | payload_bytes_per_elem_ = mem_type_size_; |
143 | } else if (l.size() % block_bytes == 0) { |
144 | is_block_ = true; |
145 | msg_size_ = (l.size() % max_block_bytes == 0) ? max_block_bytes |
146 | : block_bytes; |
147 | payload_bytes_per_elem_ = mem_type_size_; |
148 | } else if (!b0.is_empty() && b0_size % block_bytes == 0) { |
149 | is_block_ = true; |
150 | msg_size_ = block_bytes; |
151 | payload_bytes_per_elem_ = mem_type_size_; |
152 | } else { |
153 | ir_assert(!is_slm_) << "Unexpected scattered messages with SLM." ; |
154 | // Assume scattered byte SIMD16 load as the worst case. Check if |
155 | // we can use byte x {1,2,4} messages. |
156 | int slots = 16; |
157 | int bytes_per_slot = 4; |
158 | for (int x : {4, 2, 1}) { |
159 | if (x < bytes_per_slot && mem_type_size_ != x) continue; |
160 | if (b0_size % x == 0) { |
161 | msg_size_ = slots * x; |
162 | payload_bytes_per_elem_ |
163 | = mem_type_size_ * (bytes_per_slot / x); |
164 | break; |
165 | } |
166 | } |
167 | ir_assert(msg_size_ > 0); |
168 | } |
169 | } |
170 | |
171 | void init_payload_size(int elems) { |
172 | int elems_per_msg = utils::div_up(msg_size_, mem_type_size_); |
173 | int payload_per_msg = elems_per_msg * payload_bytes_per_elem_; |
174 | int payload_per_msg_grf_aligned |
175 | = utils::rnd_up(payload_per_msg, reg_bytes_); |
176 | nmsgs_ = utils::div_up(elems * mem_type_size_, msg_size_); |
177 | payload_size_ = nmsgs_ * payload_per_msg_grf_aligned; |
178 | } |
179 | |
180 | void () { |
181 | if (is_block_) { |
182 | // One register per header for block messages. |
183 | header_size_per_msg_ = reg_bytes_; |
184 | } else { |
185 | // Assume SIMD16 with A64 address model. |
186 | int slots = 16; |
187 | int bytes_per_slot = sizeof(uint64_t); |
188 | header_size_per_msg_ |
189 | = utils::rnd_up(slots * bytes_per_slot, reg_bytes_); |
190 | } |
191 | } |
192 | |
193 | static const int oword_bytes_ = 16; |
194 | static const int hword_bytes_ = 32; |
195 | |
196 | int mem_type_size_ = 0; |
197 | int reg_bytes_ = 0; |
198 | bool is_slm_ = false; |
199 | bool use_2d_send_ = false; |
200 | bool enabled_fused_eus_sharing_ = false; |
201 | |
202 | // Whether message is block or scattered. |
203 | bool is_block_ = false; |
204 | |
205 | // Amount of memory that can be read by a single message from global memory. |
206 | int msg_size_ = 0; |
207 | |
208 | // How many bytes are occupied by a single element in the message payload. |
209 | int payload_bytes_per_elem_ = 0; |
210 | |
211 | // Size of GRF buffers for all messages to load data. |
212 | int payload_size_ = 0; |
213 | |
214 | // Number of messages to load data. |
215 | int nmsgs_ = 0; |
216 | |
217 | // Size of header buffer per message. |
218 | int = 0; |
219 | }; |
220 | |
221 | // Helper class to provide GRF usage estimation. |
222 | class grf_usage_helper_t { |
223 | public: |
224 | grf_usage_helper_t(const conv_config_t &cfg) : prb_(cfg.prb()), cfg_(cfg) { |
225 | auto &tg_grid = cfg_.thread_group_grid(); |
226 | |
227 | reg_bytes_ = cfg_.grf_size(); |
228 | tg_size_ = tg_grid.elems(); |
229 | |
230 | bmnk_dim_helper_t h(cfg_); |
231 | m_tg_dim_ = h.thread_group_dim('m'); |
232 | n_tg_dim_ = h.thread_group_dim('n'); |
233 | |
234 | int b_iter_blk = h.iter_dim('b'); |
235 | int m_iter_blk = h.iter_dim('m'); |
236 | int n_iter_blk = h.iter_dim('n'); |
237 | int k_iter_blk = h.iter_dim('k'); |
238 | |
239 | if (!cfg_.ow_kw_grf_cache()) { |
240 | a_thr_elems_ = b_iter_blk * m_iter_blk * k_iter_blk; |
241 | } else { |
242 | ir_assert(!cfg_.slm().a()); |
243 | int a_m_blk = (prb_.sw * (m_iter_blk - 1) |
244 | + (prb_.kw - 1) * (1 + prb_.dw) + 1); |
245 | int a_k_blk = utils::div_up(k_iter_blk, prb_.kw); |
246 | a_thr_elems_ = b_iter_blk * a_m_blk * a_k_blk; |
247 | } |
248 | |
249 | b_thr_elems_ = b_iter_blk * n_iter_blk * k_iter_blk; |
250 | c_thr_elems_ = b_iter_blk * m_iter_blk * n_iter_blk; |
251 | a_tg_elems_ = a_thr_elems_ * m_tg_dim_; |
252 | b_tg_elems_ = b_thr_elems_ * n_tg_dim_; |
253 | a_subtile_elems_ = utils::div_up(a_thr_elems_, cfg_.subtiles().a()); |
254 | b_subtile_elems_ = utils::div_up(b_thr_elems_, cfg_.subtiles().b()); |
255 | can_reliably_use_dpasw_ = can_reliably_use_dpasw(h); |
256 | } |
257 | |
258 | grf_usage_t estimate() const { |
259 | int = 0; |
260 | int a_slm_store_payload_regs = 0; |
261 | int b_slm_store_payload_regs = 0; |
262 | |
263 | int c_buf_usage = estimate_c_buf_usage(); |
264 | int gmem_load_usage = estimate_gmem_load_usage(max_reuse_header_regs); |
265 | int slm_store_usage = estimate_slm_store_usage(a_slm_store_payload_regs, |
266 | b_slm_store_payload_regs, max_reuse_header_regs); |
267 | int slm_load_usage = estimate_slm_load_usage(max_reuse_header_regs); |
268 | int reorder_usage = estimate_reorder_usage( |
269 | a_slm_store_payload_regs, b_slm_store_payload_regs); |
270 | int zp_usage = estimate_zero_point_usage(); |
271 | |
272 | grf_usage_t info(cfg_.grf_size()); |
273 | info.add(grf_usage_label_t::out_buf, c_buf_usage); |
274 | info.add(grf_usage_label_t::gmem_load, gmem_load_usage); |
275 | info.add(grf_usage_label_t::slm_store, slm_store_usage); |
276 | info.add(grf_usage_label_t::slm_load, slm_load_usage); |
277 | info.add(grf_usage_label_t::reorder, reorder_usage); |
278 | info.add(grf_usage_label_t::reused_headers, max_reuse_header_regs); |
279 | info.add(grf_usage_label_t::reserved, constants::reserved_regs); |
280 | info.add(grf_usage_label_t::zero_points, zp_usage); |
281 | return info; |
282 | } |
283 | |
284 | private: |
285 | int estimate_c_buf_usage() const { |
286 | int c_bytes = c_thr_elems_ * prb_.acc_data_type_size; |
287 | return utils::div_up(c_bytes, reg_bytes_); |
288 | } |
289 | |
290 | int estimate_gmem_load_usage(int &) const { |
291 | int regs = 0; |
292 | bool use_a_2d_send = can_use_a_2d_send(cfg_); |
293 | bool use_b_2d_send = can_use_b_2d_send(cfg_); |
294 | for (bool is_a : {true, false}) { |
295 | bool use_slm = ab_use_slm(is_a); |
296 | int per_thr_elems = utils::div_up(ab_tg_elems(is_a), tg_size_); |
297 | int load_elems = (use_slm ? per_thr_elems : ab_subtile_elems(is_a)); |
298 | auto layout = get_gmem_layout(is_a); |
299 | bool use_2d_send = (is_a ? use_a_2d_send : use_b_2d_send); |
300 | access_grf_usage_helper_t load(layout, load_elems, reg_bytes_, |
301 | /*is_slm=*/false, use_2d_send); |
302 | if (is_a && !use_slm && can_reliably_use_dpasw_) |
303 | load.enable_fused_eus_sharing(); |
304 | int mult = (use_slm ? cfg_.slm().gmem_bufs() : 1); |
305 | regs += mult * load.payload_regs(); |
306 | if (cfg_.pipeline().reuse_headers()) { |
307 | max_reuse_header_regs = std::max( |
308 | max_reuse_header_regs, load.header_regs_per_msg()); |
309 | } else { |
310 | int subtiles |
311 | = (is_a ? cfg_.subtiles().a() : cfg_.subtiles().b()); |
312 | int mult = (use_slm ? 1 : subtiles); |
313 | regs += mult * load.header_regs(); |
314 | if (cfg_.prefetch()) { |
315 | access_grf_usage_helper_t prefetch(layout, per_thr_elems, |
316 | reg_bytes_, /*is_slm=*/false, use_2d_send); |
317 | regs += prefetch.header_regs(); |
318 | } |
319 | } |
320 | } |
321 | return regs; |
322 | } |
323 | |
324 | int estimate_slm_store_usage(int &a_payload_regs, int &b_payload_regs, |
325 | int &) const { |
326 | int regs = 0; |
327 | for (bool is_a : {true, false}) { |
328 | if (!ab_use_slm(is_a)) continue; |
329 | |
330 | int per_thr_elems = utils::div_up(ab_tg_elems(is_a), tg_size_); |
331 | int bytes = per_thr_elems * ab_type_size(is_a); |
332 | auto slm_layout = dummy_slm_layout(bytes); |
333 | access_grf_usage_helper_t store(slm_layout, bytes, reg_bytes_, |
334 | /*is_slm=*/true, /*use_2d_send=*/false); |
335 | int &payload_regs = (is_a ? a_payload_regs : b_payload_regs); |
336 | payload_regs = store.payload_regs(); |
337 | if (cfg_.pipeline().reuse_headers()) { |
338 | max_reuse_header_regs = std::max( |
339 | max_reuse_header_regs, store.header_regs_per_msg()); |
340 | } else { |
341 | regs += store.header_regs(); |
342 | } |
343 | } |
344 | return regs; |
345 | } |
346 | |
347 | int estimate_slm_load_usage(int &) const { |
348 | int regs = 0; |
349 | for (bool is_a : {true, false}) { |
350 | if (!ab_use_slm(is_a)) continue; |
351 | |
352 | int bytes = ab_subtile_elems(is_a) * ab_type_size(is_a); |
353 | auto slm_layout = dummy_slm_layout(bytes); |
354 | access_grf_usage_helper_t load(slm_layout, bytes, reg_bytes_, |
355 | /*is_slm=*/true, /*use_2d_send=*/false); |
356 | if (is_a && can_reliably_use_dpasw_) |
357 | load.enable_fused_eus_sharing(); |
358 | regs += load.payload_regs(); |
359 | if (cfg_.pipeline().reuse_headers()) { |
360 | max_reuse_header_regs = std::max( |
361 | max_reuse_header_regs, load.header_regs_per_msg()); |
362 | } else { |
363 | regs += load.header_regs(); |
364 | } |
365 | } |
366 | |
367 | return regs; |
368 | } |
369 | |
370 | // Extra registers for GRF <-> GRF reorders. |
371 | // Estimates upper bound for A/B reorders to temporary buffers. |
372 | int estimate_reorder_usage(int a_payload_regs, int b_payload_regs) const { |
373 | if (!cfg_.allow_a_grf_reorder() && !cfg_.allow_b_grf_reorder()) |
374 | return 0; |
375 | |
376 | int regs = 0; |
377 | if (prb_.is_bwd_w) { |
378 | // Hardcode the size of the temporary reorder buffer for BWD_W to |
379 | // avoid suboptimal performance. |
380 | int bwd_w_reorder_regs = 16; |
381 | regs += bwd_w_reorder_regs; |
382 | } |
383 | |
384 | for (bool is_a : {true, false}) { |
385 | bool allow_grf_reorder = (is_a ? cfg_.allow_a_grf_reorder() |
386 | : cfg_.allow_b_grf_reorder()); |
387 | if (!allow_grf_reorder) continue; |
388 | int reorder_regs = 0; |
389 | if (ab_use_slm(is_a)) { |
390 | int &payload_regs = (is_a ? a_payload_regs : b_payload_regs); |
391 | reorder_regs = payload_regs; |
392 | } else { |
393 | int size = ab_subtile_elems(is_a) * ab_type_size(is_a); |
394 | reorder_regs = utils::div_up(size, reg_bytes_); |
395 | } |
396 | regs += reorder_regs; |
397 | } |
398 | |
399 | return regs; |
400 | } |
401 | |
402 | int estimate_zero_point_usage() const { |
403 | if (!prb_.zp_cfg.do_src_compensation) return 0; |
404 | int sp_iter_dim = 1; |
405 | for (auto *name : {"ow" , "iw" , "osp" }) { |
406 | sp_iter_dim *= cfg_.iter_dim(name); |
407 | } |
408 | int subtiles = cfg_.subtiles().a() * cfg_.subtiles().b(); |
409 | int zp_mask0_regs = 2 |
410 | * utils::div_up( |
411 | sp_iter_dim * (int)sizeof(uint32_t), reg_bytes_); |
412 | int zp_mask1_regs = subtiles |
413 | * utils::div_up( |
414 | sp_iter_dim * (int)sizeof(uint16_t), reg_bytes_); |
415 | int zp_buf_regs = subtiles * utils::div_up(128, reg_bytes_); |
416 | int = subtiles; |
417 | int zp_let_regs = 4; |
418 | return zp_mask0_regs + zp_mask1_regs + zp_buf_regs + zp_header_regs |
419 | + zp_let_regs; |
420 | } |
421 | |
422 | layout_t get_gmem_layout(bool is_a) const { |
423 | auto layout = (is_a ? cfg_.a_layout() : cfg_.b_layout()).compute(); |
424 | bool is_src_dst = is_a || prb_.is_bwd_w; |
425 | if (is_src_dst && prb_.is_dw) { |
426 | auto &blocks = layout.blocks(); |
427 | if (!blocks.empty()) { |
428 | auto &b0 = blocks[0]; |
429 | std::vector<block_t> new_blocks( |
430 | blocks.begin() + 1, blocks.end()); |
431 | // Remove the innermost block of channels for depthwise |
432 | // convolution. |
433 | if (b0.dim_idx == 2 && b0.block == 1) { |
434 | layout = layout_t(layout.type(), layout.ndims(), |
435 | layout.offset(), new_blocks, |
436 | /*do_normalize=*/false); |
437 | } |
438 | } |
439 | } |
440 | return layout; |
441 | } |
442 | |
443 | int ab_type_size(bool is_a) const { |
444 | auto ret = is_a ? prb_.a_data_type_size : prb_.b_data_type_size; |
445 | if (prb_.is_s32_accumulator() && cfg_.fma_kind() == fma_kind_t::mad) { |
446 | // s8/u8 is converted to dword-strided word for mad. |
447 | ir_assert(ret == 1); |
448 | ret = 4; |
449 | } |
450 | return ret; |
451 | } |
452 | |
453 | int ab_tg_elems(bool is_a) const { |
454 | return is_a ? a_tg_elems_ : b_tg_elems_; |
455 | } |
456 | |
457 | int ab_thr_elems(bool is_a) const { |
458 | return is_a ? a_thr_elems_ : b_thr_elems_; |
459 | } |
460 | |
461 | int ab_subtile_elems(bool is_a) const { |
462 | return is_a ? a_subtile_elems_ : b_subtile_elems_; |
463 | } |
464 | |
465 | int ab_use_slm(bool is_a) const { |
466 | return is_a ? cfg_.slm().a() : cfg_.slm().b(); |
467 | } |
468 | |
469 | bool can_reliably_use_dpasw(const bmnk_dim_helper_t &h) { |
470 | if (cfg_.fma_kind() != fma_kind_t::dpasw) return false; |
471 | if (!cfg_.slm().a()) return false; |
472 | int m_tg_bytes = h.thread_group_dim('m') * h.iter_dim('m') |
473 | * prb_.a_data_type_size; |
474 | int m_thr_bytes |
475 | = ir_utils::safe_divide(m_tg_bytes, h.thread_group_dim('m')); |
476 | int owordx16_size = 256; |
477 | if (cfg_.a_layout().compute().innermost_block_layout().size() |
478 | < owordx16_size) |
479 | return false; |
480 | int k_iter_blk = h.iter_dim('k'); |
481 | if (m_thr_bytes * k_iter_blk % owordx16_size != 0) return false; |
482 | int nmsgs = m_thr_bytes * k_iter_blk / owordx16_size; |
483 | if (nmsgs % 2 != 0) return false; |
484 | return true; |
485 | } |
486 | |
487 | layout_t dummy_slm_layout(int size) const { |
488 | int inner_block = 16; // In bytes. |
489 | int outer_block = utils::div_up(size, inner_block); |
490 | std::vector<block_t> blocks; |
491 | blocks.emplace_back(0, inner_block, 1); |
492 | blocks.emplace_back(1, outer_block, inner_block); |
493 | blocks.emplace_back(0, 1, size); |
494 | blocks.emplace_back(1, 1, size); |
495 | return layout_t(type_t::byte(), 2, 0, blocks, /*do_normalize=*/false); |
496 | } |
497 | |
498 | const conv_problem_t &prb_; |
499 | const conv_config_t &cfg_; |
500 | |
501 | int reg_bytes_; |
502 | int tg_size_; |
503 | int m_tg_dim_; |
504 | int n_tg_dim_; |
505 | int a_tg_elems_; |
506 | int b_tg_elems_; |
507 | int a_thr_elems_; |
508 | int b_thr_elems_; |
509 | int c_thr_elems_; |
510 | int a_subtile_elems_; |
511 | int b_subtile_elems_; |
512 | bool can_reliably_use_dpasw_; |
513 | }; |
514 | |
515 | grf_usage_t estimate_grf_usage(const conv_config_t &cfg) { |
516 | grf_usage_helper_t helper(cfg); |
517 | return helper.estimate(); |
518 | } |
519 | |
520 | class ir_usage_analyzer_t : public ir_visitor_t { |
521 | public: |
522 | ir_usage_analyzer_t(int grf_size) |
523 | : grf_size_(grf_size), buf_usage_(grf_size) {} |
524 | |
525 | void analyze(const stmt_t &stmt, bool allow_errors = false) { |
526 | visit(stmt); |
527 | if (!is_invalid_) { |
528 | if (peak_headers_ <= 1) { |
529 | std::vector<expr_t> header_bufs; |
530 | for (auto &buf : buf_usage_.bufs()) { |
531 | if (is_header(buf)) header_bufs.push_back(buf); |
532 | } |
533 | for (auto &buf : header_bufs) { |
534 | buf_usage_.remove(buf); |
535 | } |
536 | } |
537 | } |
538 | if (!verify(allow_errors)) is_invalid_ = true; |
539 | } |
540 | |
541 | void _visit(const alloc_t &obj) override { |
542 | if (is_invalid_) return; |
543 | int size = (obj.kind == alloc_kind_t::grf ? obj.size : 0); |
544 | size = utils::rnd_up(size, grf_size_); |
545 | mem_usage_guard_t alloc_guard(&alloc_usage_, &peak_alloc_usage_, size); |
546 | mem_usage_guard_t guard(&grf_usage_, &peak_grf_usage_, size); |
547 | if (size > 0) { |
548 | buf_usage_.add(obj.buf, obj.size, grf_usage_label_t::unknown); |
549 | mark_known_bufs(obj.buf); |
550 | } |
551 | mem_usage_guard_t ; |
552 | if (is_header(obj.buf)) |
553 | header_guard = mem_usage_guard_t(&headers_, &peak_headers_, 1); |
554 | ir_visitor_t::_visit(obj); |
555 | } |
556 | |
557 | void _visit(const func_call_t &obj) override { |
558 | if (is_invalid_) return; |
559 | auto &func = obj.func; |
560 | if (auto *reorder = func.as_ptr<reorder_t>()) { |
561 | auto &src = get_base(reorder_t::arg_src_buf(obj)); |
562 | auto &dst = get_base(reorder_t::arg_dst_buf(obj)); |
563 | mark_bufs(*reorder, src, dst); |
564 | } else if (auto *send = func.as_ptr<send_t>()) { |
565 | if (!send_t::arg_header_buf(obj).is<var_t>()) { |
566 | is_invalid_ = true; |
567 | return; |
568 | } |
569 | auto &buf = get_base(send_t::arg_reg_buf(obj)); |
570 | auto & = get_base(send_t::arg_header_buf(obj)); |
571 | mark_bufs(*send, buf, header); |
572 | } else if (is_func_call<dpas_t>(obj)) { |
573 | auto &dst = get_base(dpas_t::arg_dst(obj)); |
574 | auto &src1 = get_base(dpas_t::arg_src1(obj)); |
575 | auto &src2 = get_base(dpas_t::arg_src2(obj)); |
576 | mark_fma_bufs(dst, src1, src2); |
577 | } else if (is_func_call<mad_t>(obj)) { |
578 | auto &dst = get_base(mad_t::arg_dst(obj)); |
579 | auto &src1 = get_base(mad_t::arg_src1(obj)); |
580 | auto &src2 = get_base(mad_t::arg_src2(obj)); |
581 | mark_fma_bufs(dst, src1, src2); |
582 | } |
583 | } |
584 | |
585 | void _visit(const let_t &obj) override { |
586 | if (is_invalid_) return; |
587 | int size = (obj.value.is_empty() ? 0 : obj.var.type().size()); |
588 | size = utils::rnd_up(size, reg_allocator_t::granularity); |
589 | mem_usage_guard_t guard(&grf_usage_, &peak_grf_usage_, size); |
590 | ir_visitor_t::_visit(obj); |
591 | } |
592 | |
593 | void _visit(const stmt_group_t &obj) override { |
594 | if (is_invalid_) return; |
595 | if (obj.label == stmt_label_t::c_store()) { |
596 | // Do not analyze C store consumption for simplicity. Assume there |
597 | // is enough space after releasing A/B and other buffers after the |
598 | // main loop. |
599 | return; |
600 | } |
601 | ir_visitor_t::_visit(obj); |
602 | } |
603 | |
604 | void _visit(const store_t &obj) override { |
605 | if (is_invalid_) return; |
606 | auto loads = find_objects<load_t>(obj); |
607 | for (auto &l : loads) { |
608 | if (obj.buf.is_same(l.as<load_t>().buf)) { |
609 | set_label(obj.buf, grf_usage_label_t::tmp_vars); |
610 | break; |
611 | } |
612 | } |
613 | ir_visitor_t::_visit(obj); |
614 | } |
615 | |
616 | grf_usage_t get_grf_usage(int external_regs) const { |
617 | if (is_invalid_) return grf_usage_t(); |
618 | grf_usage_t info(grf_size_); |
619 | info.add(buf_usage_); |
620 | info.add(grf_usage_label_t::reserved, external_regs); |
621 | info.add(grf_usage_label_t::tmp_vars, |
622 | utils::div_up(peak_grf_usage_ - peak_alloc_usage_, grf_size_)); |
623 | return info; |
624 | } |
625 | |
626 | private: |
627 | bool verify(bool allow_errors) const { |
628 | if (is_invalid_) { |
629 | if (!allow_errors) |
630 | ir_error_not_expected() << "Can't collect GRF usage." ; |
631 | return false; |
632 | } |
633 | for (auto &buf : buf_usage_.bufs()) { |
634 | if (buf_usage_.get_label(buf) != grf_usage_label_t::unknown) |
635 | continue; |
636 | if (!allow_errors) |
637 | ir_error_not_expected() << "Buffer doesn't have label: " << buf; |
638 | return false; |
639 | } |
640 | return true; |
641 | } |
642 | |
643 | bool is_buffer(const expr_t &buf) const { return buf_usage_.has(buf); } |
644 | |
645 | bool (const expr_t &buf) const { |
646 | if (!is_buffer(buf)) return false; |
647 | auto &name = buf.as<var_t>().name; |
648 | return name.find("h_" ) == 0; |
649 | } |
650 | |
651 | bool should_skip_if_set(const expr_t &buf, grf_usage_label_t label) const { |
652 | if (is_known_buf(buf)) return true; |
653 | switch (label) { |
654 | case grf_usage_label_t::tmp_vars: |
655 | case grf_usage_label_t::slm_store: return true; |
656 | default: return false; |
657 | } |
658 | } |
659 | |
660 | void set_label(const expr_t &buf, grf_usage_label_t label) { |
661 | if (is_invalid_) return; |
662 | bool skip_if_set = should_skip_if_set(buf, label); |
663 | auto buf_label = buf_usage_.get_label(buf); |
664 | if (utils::one_of(buf_label, grf_usage_label_t::unknown, label)) { |
665 | buf_usage_.set_label(buf, label); |
666 | } else { |
667 | if (skip_if_set) return; |
668 | ir_error_not_expected() |
669 | << "Label already set. Buffer: " << buf |
670 | << ", old label: " << buf_label << ", new label: " << label; |
671 | } |
672 | } |
673 | |
674 | void mark_known_bufs(const expr_t &buf) { |
675 | if (is_invalid_) return; |
676 | ir_assert(is_buffer(buf)); |
677 | auto &name = buf.as<var_t>().name; |
678 | if (name == "b_reduced" ) { |
679 | set_label(buf, grf_usage_label_t::out_buf); |
680 | } else if (name.find("zp_" ) == 0) { |
681 | set_label(buf, grf_usage_label_t::zero_points); |
682 | } |
683 | } |
684 | |
685 | bool is_known_buf(const expr_t &buf) const { |
686 | ir_assert(is_buffer(buf)); |
687 | auto &name = buf.as<var_t>().name; |
688 | if (name.find("zp_" ) == 0) return true; |
689 | if (name == "b_reduced" ) return true; |
690 | return false; |
691 | } |
692 | |
693 | void mark_bufs( |
694 | const reorder_t &reorder, const expr_t &src, const expr_t &dst) { |
695 | if (is_invalid_) return; |
696 | ir_assert(is_buffer(src)); |
697 | ir_assert(is_buffer(dst)); |
698 | set_label(dst, grf_usage_label_t::reorder); |
699 | } |
700 | |
701 | void mark_bufs( |
702 | const send_t &send, const expr_t &buf, const expr_t &) { |
703 | if (is_invalid_) return; |
704 | if (!buf.is_empty()) ir_assert(is_buffer(buf)); |
705 | ir_assert(is_buffer(header)); |
706 | ir_assert(is_header(header)); |
707 | grf_usage_label_t label = grf_usage_label_t::unknown; |
708 | if (buf.is_empty()) { |
709 | label = grf_usage_label_t::gmem_load; |
710 | } else if (buf_usage_.get_label(buf) |
711 | == grf_usage_label_t::zero_points) { |
712 | label = grf_usage_label_t::zero_points; |
713 | } else if (send.is_slm()) { |
714 | label = (send.is_load() ? grf_usage_label_t::slm_load |
715 | : grf_usage_label_t::slm_store); |
716 | } else { |
717 | if (!send.is_load() && !send.is_load_2d()) { |
718 | is_invalid_ = true; |
719 | return; |
720 | } |
721 | label = grf_usage_label_t::gmem_load; |
722 | } |
723 | if (!buf.is_empty()) set_label(buf, label); |
724 | set_label(header, label); |
725 | } |
726 | |
727 | void mark_fma_bufs( |
728 | const expr_t &dst, const expr_t &src1, const expr_t &src2) { |
729 | if (is_invalid_) return; |
730 | ir_assert(is_buffer(dst)); |
731 | ir_assert(is_buffer(src1)); |
732 | ir_assert(is_buffer(src2)); |
733 | set_label(dst, grf_usage_label_t::out_buf); |
734 | } |
735 | |
736 | int grf_size_; |
737 | bool is_invalid_ = false; |
738 | grf_buf_usage_t buf_usage_; |
739 | |
740 | int grf_usage_ = 0; |
741 | int alloc_usage_ = 0; |
742 | |
743 | int peak_grf_usage_ = 0; |
744 | int peak_alloc_usage_ = 0; |
745 | |
746 | int = 0; |
747 | int = 0; |
748 | }; |
749 | |
750 | grf_usage_t get_grf_usage(const stmt_t &body, int grf_size) { |
751 | ir_usage_analyzer_t analyzer(grf_size); |
752 | analyzer.visit(body); |
753 | return analyzer.get_grf_usage(0); |
754 | } |
755 | |
756 | void compare(const grf_usage_t &est_usage, const grf_usage_t &ir_usage, |
757 | const ir_usage_analyzer_t &analyzer) { |
758 | std::vector<std::string> headers |
759 | = {"Label" , "Estimated regs" , "IR regs" , "Status" }; |
760 | ir_utils::table_t table("Compare GRF usage:" , headers); |
761 | int est_total = 0; |
762 | int ir_total = 0; |
763 | for (auto label : all_grf_usage_labels()) { |
764 | int est_regs = est_usage.get(label); |
765 | int ir_regs = ir_usage.get(label); |
766 | table << " " + to_string(label) << est_regs << ir_regs; |
767 | table << (ir_regs > est_regs ? "FAIL" : "" ); |
768 | table << std::endl; |
769 | est_total += est_regs; |
770 | ir_total += ir_regs; |
771 | } |
772 | table << " Total" << est_total << ir_total; |
773 | table << (ir_total > est_total ? "FAIL" : "" ); |
774 | table << std::endl; |
775 | ir_trace() << table << std::endl; |
776 | ir_trace() << ir_usage.buf_usage() << std::endl; |
777 | } |
778 | |
779 | void verify_grf_usage( |
780 | const conv_config_t &cfg, const stmt_t &body, int external_usage) { |
781 | ir_usage_analyzer_t analyzer(cfg.grf_size()); |
782 | analyzer.analyze(body); |
783 | |
784 | auto ir_info = analyzer.get_grf_usage(external_usage); |
785 | auto est_info = estimate_grf_usage(cfg); |
786 | compare(est_info, ir_info, analyzer); |
787 | } |
788 | |
789 | } // namespace jit |
790 | } // namespace gpu |
791 | } // namespace impl |
792 | } // namespace dnnl |
793 | |