1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/visitors.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | // Find the **last** store, or return invalid if there is an AtomicOpStmt |
9 | // after the last store. |
10 | class LocalStoreForwarder : public BasicStmtVisitor { |
11 | private: |
12 | Stmt *var_; |
13 | bool is_valid_; |
14 | Stmt *result_; |
15 | |
16 | public: |
17 | using BasicStmtVisitor::visit; |
18 | |
19 | explicit LocalStoreForwarder(Stmt *var) |
20 | : var_(var), is_valid_(true), result_(nullptr) { |
21 | TI_ASSERT(var->is<AllocaStmt>()); |
22 | allow_undefined_visitor = true; |
23 | invoke_default_visitor = true; |
24 | } |
25 | |
26 | void visit(LocalStoreStmt *stmt) override { |
27 | if (stmt->dest == var_) { |
28 | is_valid_ = true; |
29 | result_ = stmt; |
30 | } |
31 | } |
32 | |
33 | void visit(AllocaStmt *stmt) override { |
34 | if (stmt == var_) { |
35 | is_valid_ = true; |
36 | result_ = stmt; |
37 | } |
38 | } |
39 | |
40 | void visit(AtomicOpStmt *stmt) override { |
41 | if (stmt->dest == var_) { |
42 | is_valid_ = false; |
43 | } |
44 | } |
45 | |
46 | // Only if **both** branches finally store the variable with exactly the same |
47 | // data, can we forward it to the local load statement. |
48 | void visit(IfStmt *if_stmt) override { |
49 | // the default return value: valid, no stores |
50 | std::pair<bool, Stmt *> true_branch(true, nullptr); |
51 | if (if_stmt->true_statements) { |
52 | // create a new LocalStoreForwarder instance |
53 | true_branch = run(if_stmt->true_statements.get(), var_); |
54 | } |
55 | std::pair<bool, Stmt *> false_branch(true, nullptr); |
56 | if (if_stmt->false_statements) { |
57 | false_branch = run(if_stmt->false_statements.get(), var_); |
58 | } |
59 | auto true_stmt = true_branch.second; |
60 | auto false_stmt = false_branch.second; |
61 | if (!true_branch.first || !false_branch.first) { |
62 | // at least one branch finally modifies the variable without storing |
63 | is_valid_ = false; |
64 | } else if (true_stmt == nullptr && false_stmt == nullptr) { |
65 | // both branches don't modify the variable |
66 | return; |
67 | } else if (true_stmt == nullptr || false_stmt == nullptr) { |
68 | // only one branch modifies the variable |
69 | is_valid_ = false; |
70 | } else { |
71 | TI_ASSERT(true_stmt->is<LocalStoreStmt>()); |
72 | TI_ASSERT(false_stmt->is<LocalStoreStmt>()); |
73 | if (true_stmt->as<LocalStoreStmt>()->val != |
74 | false_stmt->as<LocalStoreStmt>()->val) { |
75 | // two branches finally store the variable differently |
76 | is_valid_ = false; |
77 | } else { |
78 | is_valid_ = true; |
79 | result_ = true_stmt; // same as false_stmt |
80 | } |
81 | } |
82 | } |
83 | |
84 | // We don't know if a loop's body will be executed, so we cannot forward |
85 | // the "last" store inside a loop to the local load statement. |
86 | // What we can do is just check if the loop doesn't modify the variable. |
87 | void visit(WhileStmt *stmt) override { |
88 | if (irpass::analysis::has_store_or_atomic(stmt, {var_})) { |
89 | is_valid_ = false; |
90 | } |
91 | } |
92 | |
93 | void visit(RangeForStmt *stmt) override { |
94 | if (irpass::analysis::has_store_or_atomic(stmt, {var_})) { |
95 | is_valid_ = false; |
96 | } |
97 | } |
98 | |
99 | void visit(StructForStmt *stmt) override { |
100 | if (irpass::analysis::has_store_or_atomic(stmt, {var_})) { |
101 | is_valid_ = false; |
102 | } |
103 | } |
104 | |
105 | void visit(OffloadedStmt *stmt) override { |
106 | if (irpass::analysis::has_store_or_atomic(stmt, {var_})) { |
107 | is_valid_ = false; |
108 | } |
109 | } |
110 | |
111 | static std::pair<bool, Stmt *> run(IRNode *root, Stmt *var) { |
112 | LocalStoreForwarder searcher(var); |
113 | root->accept(&searcher); |
114 | return std::make_pair(searcher.is_valid_, searcher.result_); |
115 | } |
116 | }; |
117 | |
118 | namespace irpass::analysis { |
119 | std::pair<bool, Stmt *> last_store_or_atomic(IRNode *root, Stmt *var) { |
120 | return LocalStoreForwarder::run(root, var); |
121 | } |
122 | } // namespace irpass::analysis |
123 | |
124 | } // namespace taichi::lang |
125 | |