/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/pass/cse.hpp"

#include <algorithm>
#include <iostream>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "gpu/jit/codegen/register_allocator.hpp"
#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/utils/trace.hpp"
#include "gpu/jit/utils/utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Common subexpression elimination support.

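// For illustration, a repeated subexpression is computed once and reused
// through a let-bound variable. A simplified sketch (pseudo-IR, not the
// actual IR syntax):
//
//     store(p, (x + y) * 2)            let t = x + y in {
//     store(q, (x + y) * 3)     ->         store(p, t * 2)
//                                          store(q, t * 3)
//                                      }
//
// Expressions of boolean type and expressions containing loads are never
// CSEed (see cse_context_t::register_expr() and cse_visitor_t::visit_expr()).
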
// Represents an expression-candidate to eliminate.
class cse_expr_t {
public:
    cse_expr_t(const expr_t &expr, const expr_t &orig_expr,
            const ir_path_t &path, int refs = 1, const expr_t &cse_var = {})
        : expr(expr)
        , orig_expr(orig_expr)
        , path(path)
        , refs(refs)
        , cse_var(cse_var) {
        ir_trace() << "cse_pass: add expression: " << expr << std::endl;
    }

    void add_usage(const ir_path_t &other_path, bool do_increment = true) {
        if (do_increment) refs++;
        path.merge(other_path);
        ir_trace() << "cse_pass: add usage: " << expr
                   << ", total refs: " << refs << std::endl;
    }

    // Expression to eliminate via let.
    expr_t expr;
    // Original expression to eliminate (doesn't contain any CSEed vars).
    expr_t orig_expr;
    // Path to the innermost IR node where the expression can be defined.
    ir_path_t path;
    // Number of references to the expression.
    int refs;
    // Variable assigned to the expression (if decided to eliminate).
    expr_t cse_var;
};

// Helper class for CSE variables to query the computational cost of an
// expression while tracking dependencies on other potential CSE variables.
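// For example (illustrative): for an entry with expr = (a + b) * c and
// refs = 4, expr_cost() returns 2 (one add and one mul over plain
// variables), so cost() = 2 * 4 = 8. If `a` were itself a not-yet-allocated
// CSE variable, its entry's cost() would be added at each of its
// occurrences; the back references keep such dependent costs up to date as
// variables get allocated.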
class cse_var_entry_t {
public:
    cse_var_entry_t(const cse_expr_t *cse_expr) : cse_expr_(cse_expr) {}

    const cse_expr_t *cse_expr() const { return cse_expr_; }

    bool allocated() const { return allocated_; }

    int size() const {
        return utils::rnd_up(
                cse_expr_->cse_var.type().size(), reg_allocator_t::granularity);
    }

    int cost() const { return cost_; }

    void set_var2entry(
            const object_map_t<expr_t, cse_var_entry_t *> &var2entry) {
        var2entry_ = &var2entry;
    }

    void add_back_ref(cse_var_entry_t *e) {
        ir_assert(e != this);
        back_refs_.insert(e);
    }

    void mark_as_allocated() {
        allocated_ = true;
        update_back_ref_cost();
    }

    void recompute_cost() {
        cost_ = expr_cost(cse_expr_->expr) * cse_expr_->refs;
    }

private:
    void update_back_ref_cost() {
        for (auto *e : back_refs_) {
            e->recompute_cost();
            e->update_back_ref_cost();
        }
    }

    int expr_cost(const expr_t &e) {
        if (is_var(e)) {
            auto it = var2entry_->find(e);
            if (it == var2entry_->end()) return 0;
            if (it->second->allocated_) return 0;
            // If variable is not allocated, its value
            // has to be recomputed every time.
            return it->second->cost();
        }
        if (is_const(e)) return 0;
        if (e.is<cast_t>()) return 0;
        if (auto *op = e.as_ptr<binary_op_t>()) {
            return expr_cost(op->a) + expr_cost(op->b) + 1;
        }
        if (auto *op = e.as_ptr<unary_op_t>()) { return expr_cost(op->a) + 1; }
        if (auto *s = e.as_ptr<shuffle_t>()) {
            if (s->is_broadcast()) return 0;
            return s->elems();
        }
        ir_error_not_expected() << "Unhandled expression: " << e;
        return 0;
    }

    const cse_expr_t *cse_expr_ = nullptr;
    int cost_ = 0;
    bool allocated_ = false;

    std::unordered_set<cse_var_entry_t *> back_refs_;
    const object_map_t<expr_t, cse_var_entry_t *> *var2entry_ = nullptr;
};

// Helper class for IR nodes where CSE variables may be generated. Entry
// stores the peak GRF usage and propagates additional memory usage up and
// down the IR tree.
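// For example (illustrative): if try_allocate(32, limit) succeeds at a node,
// that node and all of its descendants grow by 32 bytes (rounded up to the
// allocation granularity), and the new peak is propagated to every ancestor
// whose recorded usage is lower.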
class cse_stmt_entry_t {
public:
    bool visited() const { return visited_; }

    void set_usage(int usage) {
        usage_ = usage;
        visited_ = true;
    }

    void set_parent(cse_stmt_entry_t *parent) {
        parent_ = parent;
        parent_->childs_.push_back(this);
    }

    bool try_allocate(int size, int limit) {
        const auto alloc_size
                = utils::rnd_up(size, reg_allocator_t::granularity);
        if (usage_ + alloc_size > limit) return false;
        propagate_usage_down(alloc_size);
        if (parent_) parent_->propagate_usage_up(this);
        return true;
    }

    void propagate_usage_up() {
        for (auto *c : childs_) {
            propagate_usage_up(c);
        }
    }

private:
    void propagate_usage_up(const cse_stmt_entry_t *child) {
        if (child->usage_ <= usage_) return;
        usage_ = child->usage_;
        if (parent_) parent_->propagate_usage_up(this);
    }

    void propagate_usage_down(int size) {
        usage_ += size;
        for (auto *c : childs_)
            c->propagate_usage_down(size);
    }

    int usage_ = 0;
    bool visited_ = false;
    cse_stmt_entry_t *parent_ = nullptr;
    std::vector<cse_stmt_entry_t *> childs_;
};

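// Records the GRF usage at every statement where a CSE variable may be
// attached (the innermost node of its path), accounting for the enclosing
// GRF allocations and let-bound variables seen on the way down the IR tree.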
class cse_memory_usage_visitor_t : public ir_visitor_t {
public:
    cse_memory_usage_visitor_t(
            std::unordered_map<const object_impl_t *, cse_stmt_entry_t>
                    &entries,
            const object_eq_map_t<expr_t, cse_expr_t> &cse_exprs, int grf_size)
        : entries_(entries), grf_size_(grf_size) {
        for (auto &kv : cse_exprs) {
            auto &cse_expr = kv.second;
            if (cse_expr.cse_var.is_empty()) continue;
            auto *obj = cse_expr.path.back();
            entries_.emplace(obj, cse_stmt_entry_t());
        }
    }

    ~cse_memory_usage_visitor_t() override {
        for (auto &kv : entries_) {
            ir_assert(kv.second.visited()) << *kv.first;
        }
    }

#define HANDLE_IR_OBJECT(type) \
    void _visit(const type &obj) override { visit_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

private:
    mem_usage_guard_t grf_usage_guard(int size) {
        return mem_usage_guard_t(&cur_usage_, size);
    }

    template <typename T>
    void visit_stmt(const T &obj) {
        int obj_usage = 0;
        if (auto *alloc = obj.template as_ptr<alloc_t>()) {
            if (alloc->kind == alloc_kind_t::grf)
                obj_usage = utils::rnd_up(alloc->size, grf_size_);
        } else if (auto *let = obj.template as_ptr<let_t>()) {
            obj_usage = utils::rnd_up(
                    let->var.type().size(), reg_allocator_t::granularity);
        }

        auto guard = grf_usage_guard(obj_usage);
        cse_stmt_entry_t *entry = nullptr;
        auto it = entries_.find(&obj);
        if (it != entries_.end()) entry = &it->second;
        if (entry) {
            entry->set_usage(cur_usage_);
            if (!path_.empty()) entry->set_parent(path_.back());
            path_.push_back(entry);
        }
        ir_visitor_t::_visit(obj);
        if (entry) path_.pop_back();
    }

    std::unordered_map<const object_impl_t *, cse_stmt_entry_t> &entries_;
    int grf_size_;

    int cur_usage_ = 0;
    std::vector<cse_stmt_entry_t *> path_;
};

// Stores information about all expressions subject to CSEing.
class cse_context_t {
public:
    cse_context_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}

    ir_context_t &ir_ctx() { return ir_ctx_; }

    bool has(const expr_t &e) const { return cse_exprs_.count(e) != 0; }

    cse_expr_t &find_cse_expr(const expr_t &e) {
        ir_assert(has(e)) << e;
        return cse_exprs_.at(e);
    }

    const cse_expr_t &find_cse_expr(const expr_t &e) const {
        ir_assert(has(e)) << e;
        return cse_exprs_.at(e);
    }

    bool has_var(const expr_t &e) const {
        return !find_cse_expr(e).cse_var.is_empty();
    }

    int get_refs(const expr_t &e) const {
        if (!has(e)) return 0;
        return find_cse_expr(e).refs;
    }

    void register_expr(const expr_t &e, const ir_path_t &path) {
        if (e.type().is_bool()) return; // Ignore booleans.
        auto ret = cse_exprs_.insert({e, cse_expr_t(e, e, path)});
        ir_assert(ret.second) << e;
        MAYBE_UNUSED(ret);
    }

    void register_expr(const cse_expr_t &cse_expr) {
        auto ret = cse_exprs_.insert({cse_expr.expr, cse_expr});
        ir_assert(ret.second);
        MAYBE_UNUSED(ret);
    }

    expr_t get_or_assign_var(const expr_t &e) {
        auto &cse_expr = find_cse_expr(e);
        if (cse_expr.cse_var.is_empty()) {
            cse_expr.cse_var = ir_ctx_.create_tmp_var(e.type());
            ir_trace() << "cse_pass: assigning var: " << e << " -> "
                       << cse_expr.cse_var << std::endl;
        }
        return cse_expr.cse_var;
    }

    const expr_t &get_var(const expr_t &e) const {
        return find_cse_expr(e).cse_var;
    }

    const ir_path_t &get_path(const expr_t &e) const {
        return find_cse_expr(e).path;
    }

    void add_usage(
            const expr_t &e, const ir_path_t &path, bool do_increment = true) {
        if (e.type().is_bool()) return; // Ignore booleans.
        find_cse_expr(e).add_usage(path, do_increment);
    }

    void update_expr(const expr_t &old_expr, const expr_t &new_expr) {
        auto it = cse_exprs_.find(old_expr);
        ir_assert(it != cse_exprs_.end()) << old_expr;
        auto &old_cse_expr = it->second;
        auto new_cse_expr = cse_expr_t(new_expr, old_cse_expr.orig_expr,
                old_cse_expr.path, old_cse_expr.refs, old_cse_expr.cse_var);
        cse_exprs_.erase(it);
        auto ret = cse_exprs_.insert({new_expr, new_cse_expr});
        ir_assert(ret.second);
        MAYBE_UNUSED(ret);
    }

    template <typename F>
    void for_each(const F &f) const {
        for (auto &kv : cse_exprs_)
            f(kv.first);
    }

    bool should_assign_var(const expr_t &e) const {
        if (!has(e)) return false;
        auto &cse_expr = find_cse_expr(e);
        if (cse_expr.refs <= 1) return false;
        if (skip_exprs_.count(cse_expr.orig_expr) != 0) return false;
        return true;
    }

    void set_skip_exprs(const stmt_t &root, int limit, int grf_size) {
        // Initialize a variable entry for each potential CSE variable.
        std::vector<cse_var_entry_t> var_entries;
        for (auto &kv : cse_exprs_) {
            auto &cse_expr = kv.second;
            if (cse_expr.cse_var.is_empty()) continue;
            var_entries.emplace_back(&cse_expr);
        }
        // Create a mapping from CSE var to entry.
        object_map_t<expr_t, cse_var_entry_t *> var2entry;
        for (auto &e : var_entries) {
            var2entry.emplace(e.cse_expr()->cse_var, &e);
            e.set_var2entry(var2entry);
        }
        // Initialize back references.
        for (auto &e : var_entries) {
            auto vars = find_objects<var_t>(e.cse_expr()->expr);
            for (auto &v : vars) {
                auto it = var2entry.find(v);
                if (it == var2entry.end()) continue;
                it->second->add_back_ref(&e);
            }
        }
        // Initialize costs.
        for (auto &e : var_entries) {
            e.recompute_cost();
        }
        // Initialize a statement entry for each potential statement of CSE
        // variable attachment.
        std::unordered_map<const object_impl_t *, cse_stmt_entry_t>
                stmt_entries;
        cse_memory_usage_visitor_t mem_usage_visitor(
                stmt_entries, cse_exprs_, grf_size);
        mem_usage_visitor.visit(root);
        for (auto &kv : stmt_entries)
            kv.second.propagate_usage_up();

        // Greedily find the variable with the highest current cost that
        // won't exceed the usage limit, mark it as allocated and recompute
        // the cost of the dependent variables. Stop once there are no
        // such variables left.
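        // Example (illustrative): for entries with cost/size pairs
        // {A: 8/32, B: 6/32, C: 2/32} attached to the same statement and
        // limit = 64, A and B are allocated (64 bytes used); C then fails
        // try_allocate() and is added to skip_exprs_ below.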
        std::vector<cse_var_entry_t *> sorted_var_entries;
        for (auto &e : var_entries)
            sorted_var_entries.push_back(&e);

        for (auto it = sorted_var_entries.begin();
                it != sorted_var_entries.end();) {
            std::sort(it, sorted_var_entries.end(),
                    [&](const cse_var_entry_t *a, const cse_var_entry_t *b) {
                        return a->cost() > b->cost();
                    });
            while (it != sorted_var_entries.end()) {
                auto &e = **it;
                auto &stmt_entry = stmt_entries.at(e.cse_expr()->path.back());
                if (stmt_entry.try_allocate(e.size(), limit)) {
                    e.mark_as_allocated();
                    ++it;
                    break;
                }
                ++it;
            }
        }

        // Skip the variables that couldn't be allocated.
        for (auto &e : var_entries) {
            if (e.allocated()) continue;
            skip_exprs_.insert(e.cse_expr()->orig_expr);
        }
    }

    void reset_cse_exprs() { cse_exprs_.clear(); }

private:
    ir_context_t &ir_ctx_;
    object_eq_map_t<expr_t, cse_expr_t> cse_exprs_;
    object_eq_set_t<expr_t> skip_exprs_;
};

// Collects statistics about expressions for common subexpression elimination.
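// Only the outermost matching expression gets its reference count bumped:
// when (x + y) * 2 is seen again, refs is incremented for the whole
// expression, while the nested (x + y) only gets its path merged (see
// propagate_path_), so inner subexpressions are not over-counted.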
class cse_visitor_t : public ir_visitor_t {
public:
    cse_visitor_t(cse_context_t &ctx) : ctx_(ctx) {}

    void _visit(const binary_op_t &obj) override { visit_expr(obj); }
    void _visit(const shuffle_t &obj) override {
        if (is_const_broadcast(obj)) return;
        visit_expr(obj);
    }
    void _visit(const unary_op_t &obj) override { visit_expr(obj); }

#define HANDLE_IR_OBJECT(type) \
    void _visit(const type &obj) override { visit_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

private:
    template <typename T>
    void visit_expr(const T &obj) {
        // Exclude loads as they may have side effects.
        if (count_objects<load_t>(obj) > 0) {
            ir_visitor_t::_visit(obj);
            return;
        }

        if (std::is_same<T, shuffle_t>::value) {
            auto &shuffle = reinterpret_cast<const shuffle_t &>(obj);
            if (shuffle.is_broadcast()) {
                ir_visitor_t::_visit(obj);
                return;
            }
        }

        if (propagate_path_) {
            if (ctx_.has(obj))
                ctx_.add_usage(obj, root_path_, /*do_increment=*/false);
            ir_visitor_t::_visit(obj);
            return;
        }
        if (ctx_.has(obj)) {
            ctx_.add_usage(obj, root_path_);
            propagate_path_ = true;
            ir_visitor_t::_visit(obj);
            propagate_path_ = false;
            return;
        }
        ir_visitor_t::_visit(obj);
        ctx_.register_expr(obj, root_path_);
    }

    template <typename T>
    void visit_stmt(const T &obj) {
        if (std::is_same<T, for_t>::value) {
            visit_for((const object_impl_t &)obj);
            return;
        }
        if (std::is_same<T, let_t>::value) {
            visit_let((const object_impl_t &)obj);
            return;
        }
        root_path_.push(&obj);
        ir_visitor_t::_visit(obj);
        root_path_.pop();
    }

    void visit_for(const object_impl_t &_obj) {
        auto &obj = (const for_t &)_obj;

        visit(obj.var);
        visit(obj.init);
        visit(obj.bound);
        root_path_.push(&obj);
        visit(obj.body);
        root_path_.pop();
    }

    void visit_let(const object_impl_t &_obj) {
        auto &obj = (const let_t &)_obj;

        visit(obj.var);
        visit(obj.value);
        root_path_.push(&obj);
        visit(obj.body);
        root_path_.pop();
    }

    cse_context_t &ctx_;
    ir_path_t root_path_;

    bool propagate_path_ = false;
};

// Verifies all IR paths are correct (for debugging purposes).
class cse_verifier_t : public scope_visitor_t {
public:
    cse_verifier_t(cse_context_t &ctx) : ctx_(ctx) {}

    ~cse_verifier_t() override { ir_assert(to_check_.empty()); }

    void _visit(const binary_op_t &obj) override { visit_expr(obj); }
    void _visit(const shuffle_t &obj) override { visit_expr(obj); }
    void _visit(const unary_op_t &obj) override { visit_expr(obj); }

#define HANDLE_IR_OBJECT(type) \
    void _visit(const type &obj) override { visit_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

    void verify(const stmt_t &s) {
        // Phase 0: collect IR paths for expressions.
        phase_ = 0;
        visit(s);

        // Phase 1: verify all expressions are defined at their path.
        phase_ = 1;
        visit(s);
    }

private:
    template <typename T>
    void visit_expr(const T &obj) {
        // Expressions are not used during phase 1.
        if (phase_ == 1) return;
        if (ctx_.has(obj)) {
            auto &path = ctx_.get_path(obj);
            to_check_[path.back()].push_back(obj);
        }
        scope_visitor_t::_visit(obj);
    }

    template <typename T>
    void visit_stmt(const T &obj) {
        scope_visitor_t::_visit(obj);

        // Statements are not used during phase 0.
        if (phase_ == 0) return;

        // Phase 1: check that all attached expressions are defined at this
        // statement.
        auto it = to_check_.find(obj);
        if (it != to_check_.end()) {
            for (auto &e : it->second) {
                ir_assert(is_expr_defined(e))
                        << "Expression contains undefined variables: " << e;
                MAYBE_UNUSED(e);
            }
            to_check_.erase(it);
        }
    }

    cse_context_t &ctx_;

    int phase_ = 0;
    object_map_t<stmt_t, std::vector<expr_t>> to_check_;
};

// Generates let statements for expressions being eliminated.
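// Lets are emitted in dependency order. A simplified sketch (pseudo-IR, not
// the actual syntax): for CSE variables t0 = a + b and t1 = t0 * c,
// generate() produces
//
//     let t0 = a + b in
//     let t1 = t0 * c in
//     <stmt>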
class cse_let_generator_t : public ir_visitor_t {
public:
    cse_let_generator_t(const cse_context_t &ctx, const stmt_t &stmt)
        : ctx_(ctx), stmt_(stmt) {}

    void _visit(const binary_op_t &obj) override { visit_expr(obj); }
    void _visit(const shuffle_t &obj) override { visit_expr(obj); }
    void _visit(const unary_op_t &obj) override { visit_expr(obj); }
    void _visit(const var_t &obj) override {
        auto it = all_vars_.find(obj);
        if (it == all_vars_.end()) return;
        if (seen_vars_.count(obj) == 0) generate_for_expr(it->second);
    }

    stmt_t generate() {
        ctx_.for_each([&](const expr_t &e) {
            auto &cse_var = ctx_.get_var(e);
            auto ret = all_vars_.insert({cse_var, e});
            ir_assert(ret.second);
            MAYBE_UNUSED(ret);
        });
        ctx_.for_each([&](const expr_t &e) { generate_for_expr(e); });
        for (auto it = lets_.rbegin(); it != lets_.rend(); ++it) {
            auto &let = it->as<let_t>();
            stmt_ = let_t::make(let.var, let.value, stmt_);
        }
        return stmt_;
    }

private:
    void generate_for_expr(const expr_t &e) {
        auto &cse_var = ctx_.get_var(e);
        if (seen_vars_.count(cse_var) == 1) return;
        visit(e);
    }

    template <typename T>
    void visit_expr(const T &obj) {
        ir_visitor_t::_visit(obj);
        if (ctx_.has(obj) && ctx_.has_var(obj)) {
            auto &var = ctx_.get_var(obj);
            auto ret = seen_vars_.insert(var);
            if (ret.second) lets_.push_back(let_t::make(var, obj));
        }
    }

    const cse_context_t &ctx_;
    stmt_t stmt_;

    object_map_t<expr_t, expr_t> all_vars_; // Var -> expression.
    object_set_t<expr_t> seen_vars_;

    std::vector<stmt_t> lets_;
};

// Eliminates common subexpressions from the statement: replaces them with
// their CSE variables and attaches the corresponding let statements at the
// collected paths.
class cse_mutator_t : public ir_mutator_t {
public:
    cse_mutator_t(cse_context_t &ctx) : ctx_(ctx) {}

    object_t _mutate(const binary_op_t &obj) override {
        return mutate_expr(obj);
    }
    object_t _mutate(const shuffle_t &obj) override { return mutate_expr(obj); }
    object_t _mutate(const unary_op_t &obj) override {
        return mutate_expr(obj);
    }

#define HANDLE_IR_OBJECT(type) \
    object_t _mutate(const type &obj) override { return mutate_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

private:
    template <typename T>
    object_t mutate_expr(const T &obj) {
        auto new_obj = ir_mutator_t::_mutate(obj);
        if (ctx_.has(obj) && !new_obj.is_equal(obj)) {
            ctx_.update_expr(obj, new_obj);
        }
        if (ctx_.should_assign_var(new_obj)) {
            bool has_var = ctx_.has_var(new_obj);
            auto var = ctx_.get_or_assign_var(new_obj);
            auto &path = ctx_.get_path(new_obj);
            if (!has_var) to_update_[path.back()].push_back(new_obj);
            return std::move(var);
        }
        return new_obj;
    }

    template <typename T>
    object_t mutate_stmt(const T &obj) {
        auto new_obj = ir_mutator_t::_mutate(obj);
        auto it = to_update_.find(obj);
        if (it == to_update_.end()) return new_obj;

        cse_context_t local_ctx(ctx_.ir_ctx());
        for (auto &e : it->second) {
            local_ctx.register_expr(ctx_.find_cse_expr(e));
        }
        to_update_.erase(it);

        auto body = get_stmt_body(new_obj);
        cse_let_generator_t g(local_ctx, body);
        body = g.generate();
        new_obj = replace_stmt_body(new_obj, body);
        return new_obj;
    }

    cse_context_t &ctx_;
    object_map_t<stmt_t, std::vector<expr_t>> to_update_;
};

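// Performs a single CSE run over the statement. When run_idx == 0 and the
// transformed statement exceeds the memory usage limit, returns an empty
// statement and records the expressions to skip in the context so that the
// caller can retry.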
stmt_t eliminate_common_subexprs_impl(const stmt_t &_stmt, cse_context_t &ctx,
        int grf_size, int memory_usage_limit, int run_idx) {
    auto stmt = _stmt;

    // Collect statistics.
    cse_visitor_t visitor(ctx);
    visitor.visit(stmt);

#if !defined(NDEBUG) || defined(GEN_CONV_DEBUG)
    // Verify that the collected IR paths are correct (cse_expr_t objects are
    // defined at those paths).
    cse_verifier_t verifier(ctx);
    verifier.verify(stmt);
#endif

    // Eliminate subexpressions.
    cse_mutator_t mutator(ctx);
    stmt = mutator.mutate(stmt);

    // The second run is the last one.
    if (run_idx != 0) return stmt;

    // If memory usage exceeds the limit, exclude some expressions from CSE
    // and retry the whole process from scratch.
    int memory_usage = get_peak_grf_usage(stmt, grf_size) * grf_size;
    if (memory_usage > memory_usage_limit) {
        ir_trace() << "CSE exceeded GRF usage limit. Usage: " << memory_usage
                   << ", limit: " << memory_usage_limit
                   << ". Retry CSE and skip some expressions..." << std::endl;
        ctx.set_skip_exprs(_stmt, memory_usage_limit, grf_size);
        ctx.reset_cse_exprs();
        return stmt_t();
    }

    return stmt;
}

stmt_t eliminate_common_subexprs(
        const stmt_t &_stmt, ir_context_t &ir_ctx, int memory_usage_limit) {
    trace_start();
    stmt_t stmt;
    cse_context_t cse_ctx(ir_ctx);

    int grf_size = ir_ctx.hw_cfg().grf_size();
    stmt = eliminate_common_subexprs_impl(
            _stmt, cse_ctx, grf_size, memory_usage_limit, 0);
    // Retry if the first run returned an empty statement, relying on the
    // updated skip_exprs in the CSE context.
    if (stmt.is_empty()) {
        stmt = eliminate_common_subexprs_impl(
                _stmt, cse_ctx, grf_size, memory_usage_limit, 1);
    }
    trace_pass("eliminate_common_subexprs", stmt, ir_ctx);
    return stmt;
}

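// Collects the total size of the register buffers filled by g2s
// (global-to-SLM) loads. With multi-buffering, each extra copy of these
// buffers consumes additional registers, which the caller subtracts from
// the CSE memory limit.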
class g2s_buf_visitor_t : public ir_visitor_t {
public:
    int g2s_buf_size() const {
        int ret = 0;
        for (auto &kv : g2s_bufs_) {
            ir_assert(kv.second != 0);
            ret += kv.second;
        }
        return ret;
    }

    void _visit(const alloc_t &obj) override {
        ir_visitor_t::_visit(obj);
        auto it = g2s_bufs_.find(obj.buf);
        if (it != g2s_bufs_.end()) it->second = obj.size;
    }

    void _visit(const func_call_t &obj) override {
        if (!in_g2s_) {
            ir_visitor_t::_visit(obj);
            return;
        }
        if (auto *func = obj.func.as_ptr<send_t>()) {
            ir_assert(func->is_load()) << func;
            auto &buf = send_t::arg_reg_buf(obj);
            g2s_bufs_.emplace(get_base(buf), 0);
        }
        ir_visitor_t::_visit(obj);
    }

    void _visit(const stmt_group_t &obj) override {
        bool is_g2s = obj.label == stmt_label_t::g2s_load();
        if (is_g2s) in_g2s_ = true;
        ir_visitor_t::_visit(obj);
        if (is_g2s) in_g2s_ = false;
    }

private:
    object_map_t<expr_t, int> g2s_bufs_;
    bool in_g2s_ = false;
};

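// Computes the CSE memory usage limit from the number of available
// registers. With gmem_bufs > 1, the extra copies of the g2s register
// buffers are reserved up front. Illustrative numbers (not a real
// configuration): with 128 registers of 32 bytes each, 8 of them reserved,
// gmem_bufs = 2 and a 2048-byte g2s buffer, the limit is
// (128 - 8) * 32 - (2 - 1) * 2048 = 1792 bytes.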
stmt_t eliminate_common_subexprs(const stmt_t &stmt, ir_context_t &ir_ctx,
        int reserved_regs, int gmem_bufs) {
    int grf_size = ir_ctx.grf_size();
    int available_regs = ir_ctx.exec_cfg().regs() - reserved_regs;
    int memory_usage_limit = available_regs * grf_size;
    if (gmem_bufs > 1) {
        g2s_buf_visitor_t v;
        v.visit(stmt);
        memory_usage_limit -= (gmem_bufs - 1) * v.g2s_buf_size();
    }
    return eliminate_common_subexprs(stmt, ir_ctx, memory_usage_limit);
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl