1#include "taichi/ir/ir.h"
2#include "taichi/ir/analysis.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/transforms.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/system/profiler.h"
7
8#include <typeindex>
9
10namespace taichi::lang {
11
12// A helper class to maintain WholeKernelCSE::visited
13class MarkUndone : public BasicStmtVisitor {
14 private:
15 std::unordered_set<int> *const visited_;
16 Stmt *const modified_operand_;
17
18 public:
19 using BasicStmtVisitor::visit;
20
21 MarkUndone(std::unordered_set<int> *visited, Stmt *modified_operand)
22 : visited_(visited), modified_operand_(modified_operand) {
23 allow_undefined_visitor = true;
24 invoke_default_visitor = true;
25 }
26
27 void visit(Stmt *stmt) override {
28 if (stmt->has_operand(modified_operand_)) {
29 visited_->erase(stmt->instance_id);
30 }
31 }
32
33 void preprocess_container_stmt(Stmt *stmt) override {
34 if (stmt->has_operand(modified_operand_)) {
35 visited_->erase(stmt->instance_id);
36 }
37 }
38
39 static void run(std::unordered_set<int> *visited, Stmt *modified_operand) {
40 MarkUndone marker(visited, modified_operand);
41 modified_operand->get_ir_root()->accept(&marker);
42 }
43};
44
45// Whole Kernel Common Subexpression Elimination
46class WholeKernelCSE : public BasicStmtVisitor {
47 private:
48 std::unordered_set<int> visited_;
49 // each scope corresponds to an unordered_set
50 std::vector<std::unordered_map<std::size_t, std::unordered_set<Stmt *> > >
51 visible_stmts_;
52 DelayedIRModifier modifier_;
53
54 public:
55 using BasicStmtVisitor::visit;
56
57 WholeKernelCSE() {
58 allow_undefined_visitor = true;
59 invoke_default_visitor = true;
60 }
61
62 bool is_done(Stmt *stmt) {
63 return visited_.find(stmt->instance_id) != visited_.end();
64 }
65
66 void set_done(Stmt *stmt) {
67 visited_.insert(stmt->instance_id);
68 }
69
70 static std::size_t operand_hash(const Stmt *stmt) {
71 std::size_t hash_code{0};
72 auto hash_type =
73 std::hash<std::type_index>{}(std::type_index(typeid(stmt)));
74 if (stmt->is<GlobalPtrStmt>() || stmt->is<LoopUniqueStmt>()) {
75 // special cases in common_statement_eliminable()
76 return hash_type;
77 }
78 auto op = stmt->get_operands();
79 for (auto &x : op) {
80 if (x == nullptr)
81 continue;
82 // Hash the addresses of the operand pointers.
83 hash_code =
84 (hash_code * 33) ^
85 (std::hash<unsigned long>{}(reinterpret_cast<unsigned long>(x)));
86 }
87 return hash_type ^ hash_code;
88 }
89
90 static bool common_statement_eliminable(Stmt *this_stmt, Stmt *prev_stmt) {
91 // Is this_stmt eliminable given that prev_stmt appears before it and has
92 // the same type with it?
93 if (this_stmt->type() != prev_stmt->type())
94 return false;
95 if (this_stmt->is<GlobalPtrStmt>()) {
96 auto this_ptr = this_stmt->as<GlobalPtrStmt>();
97 auto prev_ptr = prev_stmt->as<GlobalPtrStmt>();
98 return irpass::analysis::definitely_same_address(this_ptr, prev_ptr) &&
99 (this_ptr->activate == prev_ptr->activate || prev_ptr->activate);
100 }
101 if (this_stmt->is<ExternalPtrStmt>()) {
102 auto this_ptr = this_stmt->as<ExternalPtrStmt>();
103 auto prev_ptr = prev_stmt->as<ExternalPtrStmt>();
104 return irpass::analysis::definitely_same_address(this_ptr, prev_ptr);
105 }
106 if (this_stmt->is<LoopUniqueStmt>()) {
107 auto this_loop_unique = this_stmt->as<LoopUniqueStmt>();
108 auto prev_loop_unique = prev_stmt->as<LoopUniqueStmt>();
109 if (irpass::analysis::same_value(this_loop_unique->input,
110 prev_loop_unique->input)) {
111 // Merge the "covers" information into prev_loop_unique.
112 // Notice that this_loop_unique->covers is corrupted here.
113 prev_loop_unique->covers.insert(this_loop_unique->covers.begin(),
114 this_loop_unique->covers.end());
115 return true;
116 }
117 return false;
118 }
119 return irpass::analysis::same_statements(this_stmt, prev_stmt);
120 }
121
122 void visit(Stmt *stmt) override {
123 if (!stmt->common_statement_eliminable())
124 return;
125 // container_statement does not need to be CSE-ed
126 if (stmt->is_container_statement())
127 return;
128 // Generic visitor for all CSE-able statements.
129 std::size_t hash_value = operand_hash(stmt);
130 if (is_done(stmt)) {
131 visible_stmts_.back()[hash_value].insert(stmt);
132 return;
133 }
134 for (auto &scope : visible_stmts_) {
135 for (auto &prev_stmt : scope[hash_value]) {
136 if (common_statement_eliminable(stmt, prev_stmt)) {
137 MarkUndone::run(&visited_, stmt);
138 stmt->replace_usages_with(prev_stmt);
139 modifier_.erase(stmt);
140 return;
141 }
142 }
143 }
144 visible_stmts_.back()[hash_value].insert(stmt);
145 set_done(stmt);
146 }
147
148 void visit(Block *stmt_list) override {
149 visible_stmts_.emplace_back();
150 for (auto &stmt : stmt_list->statements) {
151 stmt->accept(this);
152 }
153 visible_stmts_.pop_back();
154 }
155
156 void visit(IfStmt *if_stmt) override {
157 if (if_stmt->true_statements) {
158 if (if_stmt->true_statements->statements.empty()) {
159 if_stmt->set_true_statements(nullptr);
160 }
161 }
162
163 if (if_stmt->false_statements) {
164 if (if_stmt->false_statements->statements.empty()) {
165 if_stmt->set_false_statements(nullptr);
166 }
167 }
168
169 // Move common statements at the beginning or the end of both branches
170 // outside.
171 if (if_stmt->true_statements && if_stmt->false_statements) {
172 auto &true_clause = if_stmt->true_statements;
173 auto &false_clause = if_stmt->false_statements;
174 if (irpass::analysis::same_statements(
175 true_clause->statements[0].get(),
176 false_clause->statements[0].get())) {
177 // Directly modify this because it won't invalidate any iterators.
178 auto common_stmt = true_clause->extract(0);
179 irpass::replace_all_usages_with(false_clause.get(),
180 false_clause->statements[0].get(),
181 common_stmt.get());
182 modifier_.insert_before(if_stmt, std::move(common_stmt));
183 false_clause->erase(0);
184 }
185 if (!true_clause->statements.empty() &&
186 !false_clause->statements.empty() &&
187 irpass::analysis::same_statements(
188 true_clause->statements.back().get(),
189 false_clause->statements.back().get())) {
190 // Directly modify this because it won't invalidate any iterators.
191 auto common_stmt = true_clause->extract((int)true_clause->size() - 1);
192 irpass::replace_all_usages_with(false_clause.get(),
193 false_clause->statements.back().get(),
194 common_stmt.get());
195 modifier_.insert_after(if_stmt, std::move(common_stmt));
196 false_clause->erase((int)false_clause->size() - 1);
197 }
198 }
199
200 if (if_stmt->true_statements)
201 if_stmt->true_statements->accept(this);
202 if (if_stmt->false_statements)
203 if_stmt->false_statements->accept(this);
204 }
205
206 static bool run(IRNode *node) {
207 WholeKernelCSE eliminator;
208 bool modified = false;
209 while (true) {
210 node->accept(&eliminator);
211 if (eliminator.modifier_.modify_ir())
212 modified = true;
213 else
214 break;
215 }
216 return modified;
217 }
218};
219
220namespace irpass {
221bool whole_kernel_cse(IRNode *root) {
222 TI_AUTO_PROF;
223 return WholeKernelCSE::run(root);
224}
225} // namespace irpass
226
227} // namespace taichi::lang
228