#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/visitors.h"
#include "taichi/system/profiler.h"
6
7namespace taichi::lang {
8
9// The EliminateImmutableLocalVars pass eliminates all immutable local vars
10// calculated from the GatherImmutableLocalVars pass. An immutable local var
11// can be eliminated by forwarding the value of its only store to all loads
12// after that store. See https://github.com/taichi-dev/taichi/pull/6926 for the
13// background of this optimization.
14class EliminateImmutableLocalVars : public BasicStmtVisitor {
15 private:
16 using BasicStmtVisitor::visit;
17
18 std::unordered_set<Stmt *> immutable_local_vars_;
19 std::unordered_map<Stmt *, Stmt *> immutable_local_var_to_value_;
20 ImmediateIRModifier immediate_modifier_;
21 DelayedIRModifier delayed_modifier_;
22
23 public:
24 explicit EliminateImmutableLocalVars(
25 const std::unordered_set<Stmt *> &immutable_local_vars,
26 IRNode *node)
27 : immutable_local_vars_(immutable_local_vars), immediate_modifier_(node) {
28 }
29
30 void visit(AllocaStmt *stmt) override {
31 if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) {
32 delayed_modifier_.erase(stmt);
33 }
34 }
35
36 void visit(LocalLoadStmt *stmt) override {
37 if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) {
38 immediate_modifier_.replace_usages_with(
39 stmt, immutable_local_var_to_value_[stmt->src]);
40 delayed_modifier_.erase(stmt);
41 }
42 }
43
44 void visit(LocalStoreStmt *stmt) override {
45 if (immutable_local_vars_.find(stmt->dest) != immutable_local_vars_.end()) {
46 TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) ==
47 immutable_local_var_to_value_.end());
48 immutable_local_var_to_value_[stmt->dest] = stmt->val;
49 delayed_modifier_.erase(stmt);
50 }
51 }
52
53 static void run(IRNode *node) {
54 EliminateImmutableLocalVars pass(
55 irpass::analysis::gather_immutable_local_vars(node), node);
56 node->accept(&pass);
57 pass.delayed_modifier_.modify_ir();
58 }
59};
60
61namespace irpass {
62
63void eliminate_immutable_local_vars(IRNode *root) {
64 TI_AUTO_PROF;
65 EliminateImmutableLocalVars::run(root);
66}
67
68} // namespace irpass
69
70} // namespace taichi::lang
71