1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/analysis.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/system/profiler.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | // The EliminateImmutableLocalVars pass eliminates all immutable local vars |
10 | // calculated from the GatherImmutableLocalVars pass. An immutable local var |
11 | // can be eliminated by forwarding the value of its only store to all loads |
12 | // after that store. See https://github.com/taichi-dev/taichi/pull/6926 for the |
13 | // background of this optimization. |
14 | class EliminateImmutableLocalVars : public BasicStmtVisitor { |
15 | private: |
16 | using BasicStmtVisitor::visit; |
17 | |
18 | std::unordered_set<Stmt *> immutable_local_vars_; |
19 | std::unordered_map<Stmt *, Stmt *> immutable_local_var_to_value_; |
20 | ImmediateIRModifier immediate_modifier_; |
21 | DelayedIRModifier delayed_modifier_; |
22 | |
23 | public: |
24 | explicit EliminateImmutableLocalVars( |
25 | const std::unordered_set<Stmt *> &immutable_local_vars, |
26 | IRNode *node) |
27 | : immutable_local_vars_(immutable_local_vars), immediate_modifier_(node) { |
28 | } |
29 | |
30 | void visit(AllocaStmt *stmt) override { |
31 | if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) { |
32 | delayed_modifier_.erase(stmt); |
33 | } |
34 | } |
35 | |
36 | void visit(LocalLoadStmt *stmt) override { |
37 | if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) { |
38 | immediate_modifier_.replace_usages_with( |
39 | stmt, immutable_local_var_to_value_[stmt->src]); |
40 | delayed_modifier_.erase(stmt); |
41 | } |
42 | } |
43 | |
44 | void visit(LocalStoreStmt *stmt) override { |
45 | if (immutable_local_vars_.find(stmt->dest) != immutable_local_vars_.end()) { |
46 | TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) == |
47 | immutable_local_var_to_value_.end()); |
48 | immutable_local_var_to_value_[stmt->dest] = stmt->val; |
49 | delayed_modifier_.erase(stmt); |
50 | } |
51 | } |
52 | |
53 | static void run(IRNode *node) { |
54 | EliminateImmutableLocalVars pass( |
55 | irpass::analysis::gather_immutable_local_vars(node), node); |
56 | node->accept(&pass); |
57 | pass.delayed_modifier_.modify_ir(); |
58 | } |
59 | }; |
60 | |
61 | namespace irpass { |
62 | |
63 | void eliminate_immutable_local_vars(IRNode *root) { |
64 | TI_AUTO_PROF; |
65 | EliminateImmutableLocalVars::run(root); |
66 | } |
67 | |
68 | } // namespace irpass |
69 | |
70 | } // namespace taichi::lang |
71 |