1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/cse.hpp" |
18 | |
#include <algorithm>
#include <iostream>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
25 | |
26 | #include "gpu/jit/codegen/register_allocator.hpp" |
27 | #include "gpu/jit/ir/message.hpp" |
28 | #include "gpu/jit/utils/trace.hpp" |
29 | #include "gpu/jit/utils/utils.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace jit { |
35 | |
36 | // Common subexpression elimination support. |
37 | |
38 | // Represents an expression-candidate to eliminate. |
// Represents an expression-candidate to eliminate.
class cse_expr_t {
public:
    // `expr` is the current (possibly already CSE-rewritten) form of the
    // expression, `orig_expr` is its original form with no CSE variables in
    // it, `path` is the IR location where the expression was first seen.
    cse_expr_t(const expr_t &expr, const expr_t &orig_expr,
            const ir_path_t &path, int refs = 1, const expr_t &cse_var = {})
        : expr(expr)
        , orig_expr(orig_expr)
        , path(path)
        , refs(refs)
        , cse_var(cse_var) {
        ir_trace() << "cse_pass: add expression: " << expr << std::endl;
    }

    // Records one more usage of the expression and merges `other_path` into
    // the stored definition path. With do_increment = false only the path is
    // merged (the reference count stays unchanged).
    void add_usage(const ir_path_t &other_path, bool do_increment = true) {
        if (do_increment) refs++;
        path.merge(other_path);
        ir_trace() << "cse_pass: add usage: " << expr
                   << ", total refs: " << refs << std::endl;
    }

    // Expression to eliminate via let.
    expr_t expr;
    // Original expression to eliminate (doesn't contain any CSEed vars).
    expr_t orig_expr;
    // Path to the innermost IR node where the expression can be defined.
    ir_path_t path;
    // Number of references to the expression.
    int refs;
    // Variable assigned to the expression (if decided to eliminate).
    expr_t cse_var;
};
69 | |
70 | // Helper class for CSE variables to query computational cost |
71 | // while tracking dependencies to other potential CSE |
72 | // variables. |
// Helper class for CSE variables to query computational cost
// while tracking dependencies to other potential CSE
// variables.
class cse_var_entry_t {
public:
    cse_var_entry_t(const cse_expr_t *cse_expr) : cse_expr_(cse_expr) {}

    const cse_expr_t *cse_expr() const { return cse_expr_; }

    // Whether this variable was selected to be kept in registers (otherwise
    // its value is assumed to be recomputed at every use).
    bool allocated() const { return allocated_; }

    // Register footprint of the CSE variable in bytes, rounded up to the
    // register allocation granularity.
    int size() const {
        return utils::rnd_up(
                cse_expr_->cse_var.type().size(), reg_allocator_t::granularity);
    }

    // Cached total cost; see recompute_cost().
    int cost() const { return cost_; }

    void set_var2entry(
            const object_map_t<expr_t, cse_var_entry_t *> &var2entry) {
        var2entry_ = &var2entry;
    }

    // Registers an entry whose expression references this entry's variable.
    void add_back_ref(cse_var_entry_t *e) {
        ir_assert(e != this);
        back_refs_.insert(e);
    }

    // Marks this variable as allocated and refreshes the costs of all
    // entries that (transitively) reference it, since references to an
    // allocated variable become free.
    void mark_as_allocated() {
        allocated_ = true;
        update_back_ref_cost();
    }

    // Total cost = cost of computing the expression once, times the number
    // of references to it.
    void recompute_cost() {
        cost_ = expr_cost(cse_expr_->expr) * cse_expr_->refs;
    }

private:
    // Recursively recomputes the cost of every dependent entry.
    void update_back_ref_cost() {
        for (auto *e : back_refs_) {
            e->recompute_cost();
            e->update_back_ref_cost();
        }
    }

    // Rough operation count for evaluating `e` from scratch: +1 per
    // unary/binary operation, one per element for non-broadcast shuffles,
    // zero for constants, casts and allocated variables.
    int expr_cost(const expr_t &e) {
        if (is_var(e)) {
            auto it = var2entry_->find(e);
            if (it == var2entry_->end()) return 0;
            if (it->second->allocated_) return 0;
            // If variable is not allocated, its value
            // has to be recomputed every time.
            return it->second->cost();
        }
        if (is_const(e)) return 0;
        if (e.is<cast_t>()) return 0;
        if (auto *op = e.as_ptr<binary_op_t>()) {
            return expr_cost(op->a) + expr_cost(op->b) + 1;
        }
        if (auto *op = e.as_ptr<unary_op_t>()) { return expr_cost(op->a) + 1; }
        if (auto *s = e.as_ptr<shuffle_t>()) {
            if (s->is_broadcast()) return 0;
            return s->elems();
        }
        ir_error_not_expected() << "Unhandled expression: " << e;
        return 0;
    }

    const cse_expr_t *cse_expr_ = nullptr;
    int cost_ = 0;
    bool allocated_ = false;

    std::unordered_set<cse_var_entry_t *> back_refs_;
    const object_map_t<expr_t, cse_var_entry_t *> *var2entry_ = nullptr;
};
145 | |
146 | // Helper class for IR nodes where CSE variables may be |
147 | // generated. Entry stores the peak GRF usage and propagates |
148 | // additional memory usage up and down the IR tree. |
149 | class cse_stmt_entry_t { |
150 | public: |
151 | bool visited() const { return visited_; } |
152 | |
153 | void set_usage(int usage) { |
154 | usage_ = usage; |
155 | visited_ = true; |
156 | }; |
157 | |
158 | void set_parent(cse_stmt_entry_t *parent) { |
159 | parent_ = parent; |
160 | parent_->childs_.push_back(this); |
161 | } |
162 | |
163 | bool try_allocate(int size, int limit) { |
164 | const auto alloc_size |
165 | = utils::rnd_up(size, reg_allocator_t::granularity); |
166 | if (usage_ + alloc_size > limit) return false; |
167 | propagate_usage_down(alloc_size); |
168 | if (parent_) parent_->propagate_usage_up(this); |
169 | return true; |
170 | } |
171 | |
172 | void propagate_usage_up() { |
173 | for (auto *c : childs_) { |
174 | propagate_usage_up(c); |
175 | } |
176 | } |
177 | |
178 | private: |
179 | void propagate_usage_up(const cse_stmt_entry_t *child) { |
180 | if (child->usage_ <= usage_) return; |
181 | usage_ = child->usage_; |
182 | if (parent_) parent_->propagate_usage_up(this); |
183 | } |
184 | |
185 | void propagate_usage_down(int size) { |
186 | usage_ += size; |
187 | for (auto *c : childs_) |
188 | c->propagate_usage_down(size); |
189 | } |
190 | |
191 | int usage_ = 0; |
192 | bool visited_ = false; |
193 | cse_stmt_entry_t *parent_ = nullptr; |
194 | std::vector<cse_stmt_entry_t *> childs_; |
195 | }; |
196 | |
// Walks the IR tree and records GRF memory usage at every statement where a
// CSE variable may be attached (the back of a cse_expr_t path), filling the
// statement-entry map used by cse_context_t::set_skip_exprs().
class cse_memory_usage_visitor_t : public ir_visitor_t {
public:
    cse_memory_usage_visitor_t(
            std::unordered_map<const object_impl_t *, cse_stmt_entry_t>
                    &entries,
            const object_eq_map_t<expr_t, cse_expr_t> &cse_exprs, int grf_size)
        : entries_(entries), grf_size_(grf_size) {
        // Pre-create an entry for every statement where a CSE variable may
        // be defined.
        for (auto &kv : cse_exprs) {
            auto &cse_expr = kv.second;
            if (cse_expr.cse_var.is_empty()) continue;
            auto *obj = cse_expr.path.back();
            entries_.emplace(obj, cse_stmt_entry_t());
        }
    }

    ~cse_memory_usage_visitor_t() override {
        // Every pre-created entry must have been reached during the visit.
        for (auto &kv : entries_) {
            ir_assert(kv.second.visited()) << *kv.first;
        }
    }

#define HANDLE_IR_OBJECT(type) \
    void _visit(const type &obj) override { visit_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

private:
    // RAII guard adding `size` to cur_usage_ for the duration of a scope.
    mem_usage_guard_t grf_usage_guard(int size) {
        return mem_usage_guard_t(&cur_usage_, size);
    }

    template <typename T>
    void visit_stmt(const T &obj) {
        // Account for GRF consumed by this statement itself: explicit GRF
        // allocations and let-bound variables.
        int obj_usage = 0;
        if (auto *alloc = obj.template as_ptr<alloc_t>()) {
            if (alloc->kind == alloc_kind_t::grf)
                obj_usage = utils::rnd_up(alloc->size, grf_size_);
        } else if (auto *let = obj.template as_ptr<let_t>()) {
            obj_usage = utils::rnd_up(
                    let->var.type().size(), reg_allocator_t::granularity);
        }

        auto guard = grf_usage_guard(obj_usage);
        // If this statement has a pre-created entry, record the current
        // usage and link it into the entry tree via the visit path.
        cse_stmt_entry_t *entry = nullptr;
        auto it = entries_.find(&obj);
        if (it != entries_.end()) entry = &it->second;
        if (entry) {
            entry->set_usage(cur_usage_);
            if (!path_.empty()) entry->set_parent(path_.back());
            path_.push_back(entry);
        }
        ir_visitor_t::_visit(obj);
        if (entry) path_.pop_back();
    }

    std::unordered_map<const object_impl_t *, cse_stmt_entry_t> &entries_;
    int grf_size_;

    // Running GRF usage along the current visit path.
    int cur_usage_ = 0;
    // Stack of entries enclosing the currently visited statement.
    std::vector<cse_stmt_entry_t *> path_;
};
260 | |
261 | // Stores information about all expressions subject to CSEing. |
262 | class cse_context_t { |
263 | public: |
264 | cse_context_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {} |
265 | |
266 | ir_context_t &ir_ctx() { return ir_ctx_; } |
267 | |
268 | bool has(const expr_t &e) const { return cse_exprs_.count(e) != 0; } |
269 | |
270 | cse_expr_t &find_cse_expr(const expr_t &e) { |
271 | ir_assert(has(e)) << e; |
272 | return cse_exprs_.at(e); |
273 | } |
274 | |
275 | const cse_expr_t &find_cse_expr(const expr_t &e) const { |
276 | ir_assert(has(e)) << e; |
277 | return cse_exprs_.at(e); |
278 | } |
279 | |
280 | bool has_var(const expr_t &e) const { |
281 | return !find_cse_expr(e).cse_var.is_empty(); |
282 | } |
283 | |
284 | int get_refs(const expr_t &e) const { |
285 | if (!has(e)) return 0; |
286 | return find_cse_expr(e).refs; |
287 | } |
288 | |
289 | void register_expr(const expr_t &e, const ir_path_t &path) { |
290 | if (e.type().is_bool()) return; // Ignore booleans. |
291 | auto ret = cse_exprs_.insert({e, cse_expr_t(e, e, path)}); |
292 | ir_assert(ret.second) << e; |
293 | MAYBE_UNUSED(ret); |
294 | } |
295 | |
296 | void register_expr(const cse_expr_t &cse_expr) { |
297 | auto ret = cse_exprs_.insert({cse_expr.expr, cse_expr}); |
298 | ir_assert(ret.second); |
299 | MAYBE_UNUSED(ret); |
300 | } |
301 | |
302 | expr_t get_or_assign_var(const expr_t &e) { |
303 | auto &cse_expr = find_cse_expr(e); |
304 | if (cse_expr.cse_var.is_empty()) { |
305 | cse_expr.cse_var = ir_ctx_.create_tmp_var(e.type()); |
306 | ir_trace() << "cse_pass: assigning var: " << e << " -> " |
307 | << cse_expr.cse_var << std::endl; |
308 | } |
309 | return cse_expr.cse_var; |
310 | } |
311 | |
312 | const expr_t &get_var(const expr_t &e) const { |
313 | return find_cse_expr(e).cse_var; |
314 | } |
315 | |
316 | const ir_path_t &get_path(const expr_t &e) const { |
317 | return find_cse_expr(e).path; |
318 | } |
319 | |
320 | void add_usage( |
321 | const expr_t &e, const ir_path_t &path, bool do_increment = true) { |
322 | if (e.type().is_bool()) return; // Ignore booleans. |
323 | return find_cse_expr(e).add_usage(path, do_increment); |
324 | } |
325 | |
326 | void update_expr(const expr_t &old_expr, const expr_t &new_expr) { |
327 | auto it = cse_exprs_.find(old_expr); |
328 | ir_assert(it != cse_exprs_.end()) << old_expr; |
329 | auto &old_cse_expr = it->second; |
330 | auto new_cse_expr = cse_expr_t(new_expr, old_cse_expr.orig_expr, |
331 | old_cse_expr.path, old_cse_expr.refs, old_cse_expr.cse_var); |
332 | cse_exprs_.erase(it); |
333 | auto ret = cse_exprs_.insert({new_expr, new_cse_expr}); |
334 | ir_assert(ret.second); |
335 | MAYBE_UNUSED(ret); |
336 | } |
337 | |
338 | template <typename F> |
339 | void for_each(const F &f) const { |
340 | for (auto &kv : cse_exprs_) |
341 | f(kv.first); |
342 | } |
343 | |
344 | bool should_assign_var(const expr_t &e) const { |
345 | if (!has(e)) return false; |
346 | auto &cse_expr = find_cse_expr(e); |
347 | if (cse_expr.refs <= 1) return false; |
348 | if (skip_exprs_.count(cse_expr.orig_expr) != 0) return false; |
349 | return true; |
350 | } |
351 | |
352 | void set_skip_exprs(const stmt_t &root, int limit, int grf_size) { |
353 | // Initialize variable-entry for each potential CSE variable. |
354 | std::vector<cse_var_entry_t> var_entries; |
355 | for (auto &kv : cse_exprs_) { |
356 | auto &cse_expr = kv.second; |
357 | if (cse_expr.cse_var.is_empty()) continue; |
358 | var_entries.emplace_back(&cse_expr); |
359 | } |
360 | // Create mapping from CSE var to entry. |
361 | object_map_t<expr_t, cse_var_entry_t *> var2entry; |
362 | for (auto &e : var_entries) { |
363 | var2entry.emplace(e.cse_expr()->cse_var, &e); |
364 | e.set_var2entry(var2entry); |
365 | } |
366 | // Initialize back references. |
367 | for (auto &e : var_entries) { |
368 | auto vars = find_objects<var_t>(e.cse_expr()->expr); |
369 | for (auto &v : vars) { |
370 | auto it = var2entry.find(v); |
371 | if (it == var2entry.end()) continue; |
372 | it->second->add_back_ref(&e); |
373 | } |
374 | var2entry.emplace(e.cse_expr()->cse_var, &e); |
375 | } |
376 | // Initialize cost. |
377 | for (auto &e : var_entries) { |
378 | e.recompute_cost(); |
379 | } |
380 | // Initialize statement-entry for each potential statement of CSE |
381 | // variable attachement. |
382 | std::unordered_map<const object_impl_t *, cse_stmt_entry_t> |
383 | stmt_entries; |
384 | cse_memory_usage_visitor_t mem_usage_visitor( |
385 | stmt_entries, cse_exprs_, grf_size); |
386 | mem_usage_visitor.visit(root); |
387 | for (auto &kv : stmt_entries) |
388 | kv.second.propagate_usage_up(); |
389 | |
390 | // Greedily find the variable with the highest current complexity that |
391 | // won't exceed the usage limit, mark it as allocated and recompute |
392 | // complexity for other dependent vars. Stop once there are no |
393 | // such variables. |
394 | std::vector<cse_var_entry_t *> sorted_var_entries; |
395 | for (auto &e : var_entries) |
396 | sorted_var_entries.push_back(&e); |
397 | |
398 | for (auto it = sorted_var_entries.begin(); |
399 | it != sorted_var_entries.end();) { |
400 | std::sort(it, sorted_var_entries.end(), |
401 | [&](const cse_var_entry_t *a, const cse_var_entry_t *b) { |
402 | return a->cost() > b->cost(); |
403 | }); |
404 | while (it != sorted_var_entries.end()) { |
405 | auto &e = **it; |
406 | auto &stmt_entry = stmt_entries.at(e.cse_expr()->path.back()); |
407 | if (stmt_entry.try_allocate(e.size(), limit)) { |
408 | e.mark_as_allocated(); |
409 | ++it; |
410 | break; |
411 | } |
412 | ++it; |
413 | } |
414 | } |
415 | |
416 | // Skip not allocated variables. |
417 | for (auto &e : var_entries) { |
418 | if (e.allocated()) continue; |
419 | skip_exprs_.insert(e.cse_expr()->orig_expr); |
420 | } |
421 | } |
422 | |
423 | void reset_cse_exprs() { cse_exprs_.clear(); } |
424 | |
425 | private: |
426 | ir_context_t &ir_ctx_; |
427 | object_eq_map_t<expr_t, cse_expr_t> cse_exprs_; |
428 | object_eq_set_t<expr_t> skip_exprs_; |
429 | }; |
430 | |
431 | // Collects statistics about expressions for common subexpression elimination. |
// Collects statistics about expressions for common subexpression elimination.
class cse_visitor_t : public ir_visitor_t {
public:
    cse_visitor_t(cse_context_t &ctx) : ctx_(ctx) {}

    void _visit(const binary_op_t &obj) override { visit_expr(obj); }
    void _visit(const shuffle_t &obj) override {
        // Broadcasts of constants are trivial; not worth eliminating.
        if (is_const_broadcast(obj)) return;
        visit_expr(obj);
    }
    void _visit(const unary_op_t &obj) override { visit_expr(obj); }

#define HANDLE_IR_OBJECT(type) \
    void _visit(const type &obj) override { visit_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

private:
    template <typename T>
    void visit_expr(const T &obj) {
        // Exclude loads as they may have side effects.
        if (count_objects<load_t>(obj) > 0) {
            ir_visitor_t::_visit(obj);
            return;
        }

        // Broadcast shuffles are not CSE candidates themselves; only visit
        // their operands.
        if (std::is_same<T, shuffle_t>::value) {
            auto &shuffle = reinterpret_cast<const shuffle_t &>(obj);
            if (shuffle.is_broadcast()) {
                ir_visitor_t::_visit(obj);
                return;
            }
        }

        // Inside an already-counted outer expression: merge paths of nested
        // candidates without bumping their reference counts (one usage of
        // the outer expression is one usage of each nested one).
        if (propagate_path_) {
            if (ctx_.has(obj))
                ctx_.add_usage(obj, root_path_, /*do_increment=*/false);
            ir_visitor_t::_visit(obj);
            return;
        }
        // Seen before: count this usage, then propagate the path (without
        // counting) through nested sub-expressions.
        if (ctx_.has(obj)) {
            ctx_.add_usage(obj, root_path_);
            propagate_path_ = true;
            ir_visitor_t::_visit(obj);
            propagate_path_ = false;
            return;
        }
        // First occurrence: visit children first, then register.
        ir_visitor_t::_visit(obj);
        ctx_.register_expr(obj, root_path_);
    }

    template <typename T>
    void visit_stmt(const T &obj) {
        // for/let need special handling: their header expressions must not
        // be attributed to the scope they themselves introduce.
        if (std::is_same<T, for_t>::value) {
            visit_for((const object_impl_t &)obj);
            return;
        }
        if (std::is_same<T, let_t>::value) {
            visit_let((const object_impl_t &)obj);
            return;
        }
        root_path_.push(&obj);
        ir_visitor_t::_visit(obj);
        root_path_.pop();
    }

    void visit_for(const object_impl_t &_obj) {
        auto &obj = (const for_t &)_obj;

        // var/init/bound are evaluated outside the loop scope.
        visit(obj.var);
        visit(obj.init);
        visit(obj.bound);
        root_path_.push(&obj);
        visit(obj.body);
        root_path_.pop();
    }

    void visit_let(const object_impl_t &_obj) {
        auto &obj = (const let_t &)_obj;

        // var/value are evaluated outside the let scope.
        visit(obj.var);
        visit(obj.value);
        root_path_.push(&obj);
        visit(obj.body);
        root_path_.pop();
    }

    cse_context_t &ctx_;
    // Path of enclosing statements for the current visit position.
    ir_path_t root_path_;

    // True while visiting the inside of an already-registered expression.
    bool propagate_path_ = false;
};
525 | |
526 | // Verifies all IR paths are correct (for debugging purposes). |
// Verifies all IR paths are correct (for debugging purposes).
class cse_verifier_t : public scope_visitor_t {
public:
    cse_verifier_t(cse_context_t &ctx) : ctx_(ctx) {}

    // All collected expressions must have been checked by destruction time.
    ~cse_verifier_t() override { ir_assert(to_check_.empty()); }

    void _visit(const binary_op_t &obj) override { visit_expr(obj); }
    void _visit(const shuffle_t &obj) override { return visit_expr(obj); }
    void _visit(const unary_op_t &obj) override { visit_expr(obj); }

#define HANDLE_IR_OBJECT(type) \
    void _visit(const type &obj) override { visit_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

    void verify(const stmt_t &s) {
        // Phase 0: collect IR paths for expressions.
        phase_ = 0;
        visit(s);

        // Phase 1: verify all expressions are defined at their path.
        phase_ = 1;
        visit(s);
    }

private:
    template <typename T>
    void visit_expr(const T &obj) {
        // Expressions are not used during phase 1.
        if (phase_ == 1) return;
        // Phase 0: attach each registered expression to the statement at
        // the back of its recorded path.
        if (ctx_.has(obj)) {
            auto &path = ctx_.get_path(obj);
            to_check_[path.back()].push_back(obj);
        }
        scope_visitor_t::_visit(obj);
    }

    template <typename T>
    void visit_stmt(const T &obj) {
        scope_visitor_t::_visit(obj);

        // Statements are not used during phase 0.
        if (phase_ == 0) return;

        // Phase 1: check that all attached expressions are defined at this
        // statement.
        auto it = to_check_.find(obj);
        if (it != to_check_.end()) {
            for (auto &e : it->second) {
                ir_assert(is_expr_defined(e))
                        << "Expression contains undefined variables: " << e;
                MAYBE_UNUSED(e);
            }
            to_check_.erase(it);
        }
    }

    cse_context_t &ctx_;

    // Current phase: 0 - collection, 1 - verification.
    int phase_ = 0;
    // Statement -> expressions whose path ends at that statement.
    object_map_t<stmt_t, std::vector<expr_t>> to_check_;
};
591 | |
592 | // Generates let statements for expressions being eliminated. |
// Generates let statements for expressions being eliminated.
class cse_let_generator_t : public ir_visitor_t {
public:
    cse_let_generator_t(const cse_context_t &ctx, const stmt_t &stmt)
        : ctx_(ctx), stmt_(stmt) {}

    void _visit(const binary_op_t &obj) override { visit_expr(obj); }
    void _visit(const shuffle_t &obj) override { visit_expr(obj); }
    void _visit(const unary_op_t &obj) override { visit_expr(obj); }
    void _visit(const var_t &obj) override {
        // When an expression references another CSE variable, generate the
        // let for that variable first so dependencies are defined before use.
        auto it = all_vars_.find(obj);
        if (it == all_vars_.end()) return;
        if (seen_vars_.count(obj) == 0) generate_for_expr(it->second);
    }

    // Wraps stmt_ with let statements for all CSE variables in the context,
    // ordered so that each variable is defined before it is referenced.
    stmt_t generate() {
        // Build the CSE var -> expression map.
        ctx_.for_each([&](const expr_t &e) {
            auto &cse_var = ctx_.get_var(e);
            auto ret = all_vars_.insert({cse_var, e});
            ir_assert(ret.second);
            MAYBE_UNUSED(ret);
        });
        // Collect lets in dependency order.
        ctx_.for_each([&](const expr_t &e) { generate_for_expr(e); });
        // Nest lets around the body, innermost first (reverse order keeps
        // earlier lets outermost).
        for (auto it = lets_.rbegin(); it != lets_.rend(); ++it) {
            auto &let = it->as<let_t>();
            stmt_ = let_t::make(let.var, let.value, stmt_);
        }
        return stmt_;
    }

private:
    void generate_for_expr(const expr_t &e) {
        auto &cse_var = ctx_.get_var(e);
        if (seen_vars_.count(cse_var) == 1) return;
        visit(e);
    }

    template <typename T>
    void visit_expr(const T &obj) {
        // Visit children first so nested CSE variables are emitted before
        // the expression that uses them.
        ir_visitor_t::_visit(obj);
        if (ctx_.has(obj) && ctx_.has_var(obj)) {
            auto &var = ctx_.get_var(obj);
            auto ret = seen_vars_.insert(var);
            if (ret.second) lets_.push_back(let_t::make(var, obj));
        }
    }

    const cse_context_t &ctx_;
    stmt_t stmt_;

    object_map_t<expr_t, expr_t> all_vars_; // Var -> expression.
    // CSE variables whose let was already generated.
    object_set_t<expr_t> seen_vars_;

    // Generated lets, in definition order.
    std::vector<stmt_t> lets_;
};
647 | |
648 | // Eliminates expressions from the statement. |
// Eliminates expressions from the statement.
class cse_mutator_t : public ir_mutator_t {
public:
    cse_mutator_t(cse_context_t &ctx) : ctx_(ctx) {}

    object_t _mutate(const binary_op_t &obj) override {
        return mutate_expr(obj);
    }
    object_t _mutate(const shuffle_t &obj) override { return mutate_expr(obj); }
    object_t _mutate(const unary_op_t &obj) override {
        return mutate_expr(obj);
    }

#define HANDLE_IR_OBJECT(type) \
    object_t _mutate(const type &obj) override { return mutate_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

private:
    template <typename T>
    object_t mutate_expr(const T &obj) {
        auto new_obj = ir_mutator_t::_mutate(obj);
        // Nested mutation may have rewritten sub-expressions; keep the
        // context keyed by the up-to-date form.
        if (ctx_.has(obj) && !new_obj.is_equal(obj)) {
            ctx_.update_expr(obj, new_obj);
        }
        if (ctx_.should_assign_var(new_obj)) {
            bool has_var = ctx_.has_var(new_obj);
            auto var = ctx_.get_or_assign_var(new_obj);
            auto &path = ctx_.get_path(new_obj);
            // On the first substitution, remember to generate the let at
            // the statement where the expression's path ends.
            if (!has_var) to_update_[path.back()].push_back(new_obj);
            return std::move(var);
        }
        return new_obj;
    }

    template <typename T>
    object_t mutate_stmt(const T &obj) {
        auto new_obj = ir_mutator_t::_mutate(obj);
        auto it = to_update_.find(obj);
        if (it == to_update_.end()) return new_obj;

        // This statement is the attachment point for some CSE variables:
        // build a local context with just those expressions and wrap the
        // statement body with the corresponding let statements.
        cse_context_t local_ctx(ctx_.ir_ctx());
        for (auto &e : it->second) {
            local_ctx.register_expr(ctx_.find_cse_expr(e));
        }
        to_update_.erase(it);

        auto body = get_stmt_body(new_obj);
        cse_let_generator_t g(local_ctx, body);
        body = g.generate();
        new_obj = replace_stmt_body(new_obj, body);
        return new_obj;
    }

    cse_context_t &ctx_;
    // Statement -> expressions whose lets must be generated at it.
    object_map_t<stmt_t, std::vector<expr_t>> to_update_;
};
707 | |
708 | stmt_t eliminate_common_subexprs_impl(const stmt_t &_stmt, cse_context_t &ctx, |
709 | int grf_size, int memory_usage_limit, int run_idx) { |
710 | auto stmt = _stmt; |
711 | |
712 | // Collect statistics. |
713 | cse_visitor_t visitor(ctx); |
714 | visitor.visit(stmt); |
715 | |
716 | #if !defined(NDEBUG) || defined(GEN_CONV_DEBUG) |
717 | // Verify that collected IR paths are correct (cse_expr_t objects are |
718 | // defined at those paths). |
719 | cse_verifier_t verifier(ctx); |
720 | verifier.verify(stmt); |
721 | #endif |
722 | |
723 | // Eliminate subexpressions. |
724 | cse_mutator_t mutator(ctx); |
725 | stmt = mutator.mutate(stmt); |
726 | |
727 | // The second run is the last run. |
728 | if (run_idx != 0) return stmt; |
729 | |
730 | // If memory usage exceeds the limit, exclude some |
731 | // expressions from CSE and retry the whole process from |
732 | // scratch. |
733 | int memory_usage = get_peak_grf_usage(stmt, grf_size) * grf_size; |
734 | if (memory_usage > memory_usage_limit) { |
735 | ir_trace() << "CSE exceeded GRF usage limit. Usage: " << memory_usage |
736 | << ", limit: " << memory_usage_limit |
737 | << ". Retry CSE and skip some expressions..." << std::endl; |
738 | ctx.set_skip_exprs(_stmt, memory_usage_limit, grf_size); |
739 | ctx.reset_cse_exprs(); |
740 | return stmt_t(); |
741 | } |
742 | |
743 | return stmt; |
744 | } |
745 | |
746 | stmt_t eliminate_common_subexprs( |
747 | const stmt_t &_stmt, ir_context_t &ir_ctx, int memory_usage_limit) { |
748 | trace_start(); |
749 | stmt_t stmt; |
750 | cse_context_t cse_ctx(ir_ctx); |
751 | |
752 | int grf_size = ir_ctx.hw_cfg().grf_size(); |
753 | stmt = eliminate_common_subexprs_impl( |
754 | _stmt, cse_ctx, grf_size, memory_usage_limit, 0); |
755 | // Retry if statement is empty, rely on the updated |
756 | // skip_exprs from the CSE context. |
757 | if (stmt.is_empty()) { |
758 | stmt = eliminate_common_subexprs_impl( |
759 | _stmt, cse_ctx, grf_size, memory_usage_limit, 1); |
760 | } |
761 | trace_pass("eliminate_common_subexprs" , stmt, ir_ctx); |
762 | return stmt; |
763 | } |
764 | |
765 | class g2s_buf_visitor_t : public ir_visitor_t { |
766 | public: |
767 | int g2s_buf_size() const { |
768 | int ret = 0; |
769 | for (auto &kv : g2s_bufs_) { |
770 | ir_assert(kv.second != 0); |
771 | ret += kv.second; |
772 | } |
773 | return ret; |
774 | } |
775 | |
776 | void _visit(const alloc_t &obj) override { |
777 | ir_visitor_t::_visit(obj); |
778 | auto it = g2s_bufs_.find(obj.buf); |
779 | if (it != g2s_bufs_.end()) it->second = obj.size; |
780 | } |
781 | |
782 | void _visit(const func_call_t &obj) override { |
783 | if (!in_g2s_) { |
784 | ir_visitor_t::_visit(obj); |
785 | return; |
786 | } |
787 | if (auto *func = obj.func.as_ptr<send_t>()) { |
788 | ir_assert(func->is_load()) << func; |
789 | auto &buf = send_t::arg_reg_buf(obj); |
790 | g2s_bufs_.emplace(get_base(buf), 0); |
791 | } |
792 | ir_visitor_t::_visit(obj); |
793 | } |
794 | |
795 | void _visit(const stmt_group_t &obj) override { |
796 | bool is_g2s = obj.label == stmt_label_t::g2s_load(); |
797 | if (is_g2s) in_g2s_ = true; |
798 | ir_visitor_t::_visit(obj); |
799 | if (is_g2s) in_g2s_ = false; |
800 | } |
801 | |
802 | private: |
803 | object_map_t<expr_t, int> g2s_bufs_; |
804 | bool in_g2s_ = false; |
805 | }; |
806 | |
807 | stmt_t eliminate_common_subexprs(const stmt_t &stmt, ir_context_t &ir_ctx, |
808 | int reserved_regs, int gmem_bufs) { |
809 | int grf_size = ir_ctx.grf_size(); |
810 | int available_regs = ir_ctx.exec_cfg().regs() - reserved_regs; |
811 | int memory_usage_limit = available_regs * grf_size; |
812 | if (gmem_bufs > 1) { |
813 | g2s_buf_visitor_t v; |
814 | v.visit(stmt); |
815 | memory_usage_limit -= (gmem_bufs - 1) * v.g2s_buf_size(); |
816 | } |
817 | return eliminate_common_subexprs(stmt, ir_ctx, memory_usage_limit); |
818 | } |
819 | |
820 | } // namespace jit |
821 | } // namespace gpu |
822 | } // namespace impl |
823 | } // namespace dnnl |
824 | |