1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5
6namespace taichi::lang {
7
8// The GatherImmutableLocalVars pass gathers all immutable local vars as input
9// to the EliminateImmutableLocalVars pass. An immutable local var is an alloca
10// which is stored only once (in the same block) and only loaded after that
11// store.
12class GatherImmutableLocalVars : public BasicStmtVisitor {
13 private:
14 using BasicStmtVisitor::visit;
15
16 enum class AllocaStatus { kCreated = 0, kStoredOnce = 1, kInvalid = 2 };
17 std::unordered_map<Stmt *, AllocaStatus> alloca_status_;
18
19 public:
20 explicit GatherImmutableLocalVars() {
21 invoke_default_visitor = true;
22 }
23
24 void visit(AllocaStmt *stmt) override {
25 TI_ASSERT(alloca_status_.find(stmt) == alloca_status_.end());
26 alloca_status_[stmt] = AllocaStatus::kCreated;
27 }
28
29 void visit(LocalLoadStmt *stmt) override {
30 if (stmt->src->is<AllocaStmt>()) {
31 auto status_iter = alloca_status_.find(stmt->src);
32 TI_ASSERT(status_iter != alloca_status_.end());
33 if (status_iter->second == AllocaStatus::kCreated) {
34 status_iter->second = AllocaStatus::kInvalid;
35 }
36 }
37 }
38
39 void visit(LocalStoreStmt *stmt) override {
40 if (stmt->dest->is<AllocaStmt>()) {
41 auto status_iter = alloca_status_.find(stmt->dest);
42 TI_ASSERT(status_iter != alloca_status_.end());
43 if (stmt->parent != stmt->dest->parent ||
44 status_iter->second == AllocaStatus::kStoredOnce ||
45 stmt->val->ret_type != stmt->dest->ret_type.ptr_removed()) {
46 status_iter->second = AllocaStatus::kInvalid;
47 } else if (status_iter->second == AllocaStatus::kCreated) {
48 status_iter->second = AllocaStatus::kStoredOnce;
49 }
50 }
51 }
52
53 void default_visit(Stmt *stmt) {
54 for (auto &op : stmt->get_operands()) {
55 if (op != nullptr && op->is<AllocaStmt>()) {
56 auto status_iter = alloca_status_.find(op);
57 TI_ASSERT(status_iter != alloca_status_.end());
58 status_iter->second = AllocaStatus::kInvalid;
59 }
60 }
61 }
62
63 void visit(Stmt *stmt) override {
64 default_visit(stmt);
65 }
66
67 void preprocess_container_stmt(Stmt *stmt) override {
68 default_visit(stmt);
69 }
70
71 static std::unordered_set<Stmt *> run(IRNode *node) {
72 GatherImmutableLocalVars pass;
73 node->accept(&pass);
74 std::unordered_set<Stmt *> result;
75 for (auto &[k, v] : pass.alloca_status_) {
76 if (v == AllocaStatus::kStoredOnce) {
77 result.insert(k);
78 }
79 }
80 return result;
81 }
82};
83
84namespace irpass::analysis {
85
86std::unordered_set<Stmt *> gather_immutable_local_vars(IRNode *root) {
87 return GatherImmutableLocalVars::run(root);
88}
89
90} // namespace irpass::analysis
91
92} // namespace taichi::lang
93