1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | // The GatherImmutableLocalVars pass gathers all immutable local vars as input |
9 | // to the EliminateImmutableLocalVars pass. An immutable local var is an alloca |
10 | // which is stored only once (in the same block) and only loaded after that |
11 | // store. |
12 | class GatherImmutableLocalVars : public BasicStmtVisitor { |
13 | private: |
14 | using BasicStmtVisitor::visit; |
15 | |
16 | enum class AllocaStatus { kCreated = 0, kStoredOnce = 1, kInvalid = 2 }; |
17 | std::unordered_map<Stmt *, AllocaStatus> alloca_status_; |
18 | |
19 | public: |
20 | explicit GatherImmutableLocalVars() { |
21 | invoke_default_visitor = true; |
22 | } |
23 | |
24 | void visit(AllocaStmt *stmt) override { |
25 | TI_ASSERT(alloca_status_.find(stmt) == alloca_status_.end()); |
26 | alloca_status_[stmt] = AllocaStatus::kCreated; |
27 | } |
28 | |
29 | void visit(LocalLoadStmt *stmt) override { |
30 | if (stmt->src->is<AllocaStmt>()) { |
31 | auto status_iter = alloca_status_.find(stmt->src); |
32 | TI_ASSERT(status_iter != alloca_status_.end()); |
33 | if (status_iter->second == AllocaStatus::kCreated) { |
34 | status_iter->second = AllocaStatus::kInvalid; |
35 | } |
36 | } |
37 | } |
38 | |
39 | void visit(LocalStoreStmt *stmt) override { |
40 | if (stmt->dest->is<AllocaStmt>()) { |
41 | auto status_iter = alloca_status_.find(stmt->dest); |
42 | TI_ASSERT(status_iter != alloca_status_.end()); |
43 | if (stmt->parent != stmt->dest->parent || |
44 | status_iter->second == AllocaStatus::kStoredOnce || |
45 | stmt->val->ret_type != stmt->dest->ret_type.ptr_removed()) { |
46 | status_iter->second = AllocaStatus::kInvalid; |
47 | } else if (status_iter->second == AllocaStatus::kCreated) { |
48 | status_iter->second = AllocaStatus::kStoredOnce; |
49 | } |
50 | } |
51 | } |
52 | |
53 | void default_visit(Stmt *stmt) { |
54 | for (auto &op : stmt->get_operands()) { |
55 | if (op != nullptr && op->is<AllocaStmt>()) { |
56 | auto status_iter = alloca_status_.find(op); |
57 | TI_ASSERT(status_iter != alloca_status_.end()); |
58 | status_iter->second = AllocaStatus::kInvalid; |
59 | } |
60 | } |
61 | } |
62 | |
63 | void visit(Stmt *stmt) override { |
64 | default_visit(stmt); |
65 | } |
66 | |
67 | void preprocess_container_stmt(Stmt *stmt) override { |
68 | default_visit(stmt); |
69 | } |
70 | |
71 | static std::unordered_set<Stmt *> run(IRNode *node) { |
72 | GatherImmutableLocalVars pass; |
73 | node->accept(&pass); |
74 | std::unordered_set<Stmt *> result; |
75 | for (auto &[k, v] : pass.alloca_status_) { |
76 | if (v == AllocaStatus::kStoredOnce) { |
77 | result.insert(k); |
78 | } |
79 | } |
80 | return result; |
81 | } |
82 | }; |
83 | |
84 | namespace irpass::analysis { |
85 | |
86 | std::unordered_set<Stmt *> gather_immutable_local_vars(IRNode *root) { |
87 | return GatherImmutableLocalVars::run(root); |
88 | } |
89 | |
90 | } // namespace irpass::analysis |
91 | |
92 | } // namespace taichi::lang |
93 |