1#include "taichi/ir/ir.h"
2#include "taichi/ir/analysis.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/visitors.h"
5
6namespace taichi::lang {
7
8// Find the **last** store, or return invalid if there is an AtomicOpStmt
9// after the last store.
10class LocalStoreForwarder : public BasicStmtVisitor {
11 private:
12 Stmt *var_;
13 bool is_valid_;
14 Stmt *result_;
15
16 public:
17 using BasicStmtVisitor::visit;
18
19 explicit LocalStoreForwarder(Stmt *var)
20 : var_(var), is_valid_(true), result_(nullptr) {
21 TI_ASSERT(var->is<AllocaStmt>());
22 allow_undefined_visitor = true;
23 invoke_default_visitor = true;
24 }
25
26 void visit(LocalStoreStmt *stmt) override {
27 if (stmt->dest == var_) {
28 is_valid_ = true;
29 result_ = stmt;
30 }
31 }
32
33 void visit(AllocaStmt *stmt) override {
34 if (stmt == var_) {
35 is_valid_ = true;
36 result_ = stmt;
37 }
38 }
39
40 void visit(AtomicOpStmt *stmt) override {
41 if (stmt->dest == var_) {
42 is_valid_ = false;
43 }
44 }
45
46 // Only if **both** branches finally store the variable with exactly the same
47 // data, can we forward it to the local load statement.
48 void visit(IfStmt *if_stmt) override {
49 // the default return value: valid, no stores
50 std::pair<bool, Stmt *> true_branch(true, nullptr);
51 if (if_stmt->true_statements) {
52 // create a new LocalStoreForwarder instance
53 true_branch = run(if_stmt->true_statements.get(), var_);
54 }
55 std::pair<bool, Stmt *> false_branch(true, nullptr);
56 if (if_stmt->false_statements) {
57 false_branch = run(if_stmt->false_statements.get(), var_);
58 }
59 auto true_stmt = true_branch.second;
60 auto false_stmt = false_branch.second;
61 if (!true_branch.first || !false_branch.first) {
62 // at least one branch finally modifies the variable without storing
63 is_valid_ = false;
64 } else if (true_stmt == nullptr && false_stmt == nullptr) {
65 // both branches don't modify the variable
66 return;
67 } else if (true_stmt == nullptr || false_stmt == nullptr) {
68 // only one branch modifies the variable
69 is_valid_ = false;
70 } else {
71 TI_ASSERT(true_stmt->is<LocalStoreStmt>());
72 TI_ASSERT(false_stmt->is<LocalStoreStmt>());
73 if (true_stmt->as<LocalStoreStmt>()->val !=
74 false_stmt->as<LocalStoreStmt>()->val) {
75 // two branches finally store the variable differently
76 is_valid_ = false;
77 } else {
78 is_valid_ = true;
79 result_ = true_stmt; // same as false_stmt
80 }
81 }
82 }
83
84 // We don't know if a loop's body will be executed, so we cannot forward
85 // the "last" store inside a loop to the local load statement.
86 // What we can do is just check if the loop doesn't modify the variable.
87 void visit(WhileStmt *stmt) override {
88 if (irpass::analysis::has_store_or_atomic(stmt, {var_})) {
89 is_valid_ = false;
90 }
91 }
92
93 void visit(RangeForStmt *stmt) override {
94 if (irpass::analysis::has_store_or_atomic(stmt, {var_})) {
95 is_valid_ = false;
96 }
97 }
98
99 void visit(StructForStmt *stmt) override {
100 if (irpass::analysis::has_store_or_atomic(stmt, {var_})) {
101 is_valid_ = false;
102 }
103 }
104
105 void visit(OffloadedStmt *stmt) override {
106 if (irpass::analysis::has_store_or_atomic(stmt, {var_})) {
107 is_valid_ = false;
108 }
109 }
110
111 static std::pair<bool, Stmt *> run(IRNode *root, Stmt *var) {
112 LocalStoreForwarder searcher(var);
113 root->accept(&searcher);
114 return std::make_pair(searcher.is_valid_, searcher.result_);
115 }
116};
117
118namespace irpass::analysis {
119std::pair<bool, Stmt *> last_store_or_atomic(IRNode *root, Stmt *var) {
120 return LocalStoreForwarder::run(root, var);
121}
122} // namespace irpass::analysis
123
124} // namespace taichi::lang
125