1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/transforms.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/system/profiler.h" |
7 | |
8 | #include <typeindex> |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | // A helper class to maintain WholeKernelCSE::visited |
13 | class MarkUndone : public BasicStmtVisitor { |
14 | private: |
15 | std::unordered_set<int> *const visited_; |
16 | Stmt *const modified_operand_; |
17 | |
18 | public: |
19 | using BasicStmtVisitor::visit; |
20 | |
21 | MarkUndone(std::unordered_set<int> *visited, Stmt *modified_operand) |
22 | : visited_(visited), modified_operand_(modified_operand) { |
23 | allow_undefined_visitor = true; |
24 | invoke_default_visitor = true; |
25 | } |
26 | |
27 | void visit(Stmt *stmt) override { |
28 | if (stmt->has_operand(modified_operand_)) { |
29 | visited_->erase(stmt->instance_id); |
30 | } |
31 | } |
32 | |
33 | void preprocess_container_stmt(Stmt *stmt) override { |
34 | if (stmt->has_operand(modified_operand_)) { |
35 | visited_->erase(stmt->instance_id); |
36 | } |
37 | } |
38 | |
39 | static void run(std::unordered_set<int> *visited, Stmt *modified_operand) { |
40 | MarkUndone marker(visited, modified_operand); |
41 | modified_operand->get_ir_root()->accept(&marker); |
42 | } |
43 | }; |
44 | |
45 | // Whole Kernel Common Subexpression Elimination |
46 | class WholeKernelCSE : public BasicStmtVisitor { |
47 | private: |
48 | std::unordered_set<int> visited_; |
49 | // each scope corresponds to an unordered_set |
50 | std::vector<std::unordered_map<std::size_t, std::unordered_set<Stmt *> > > |
51 | visible_stmts_; |
52 | DelayedIRModifier modifier_; |
53 | |
54 | public: |
55 | using BasicStmtVisitor::visit; |
56 | |
57 | WholeKernelCSE() { |
58 | allow_undefined_visitor = true; |
59 | invoke_default_visitor = true; |
60 | } |
61 | |
62 | bool is_done(Stmt *stmt) { |
63 | return visited_.find(stmt->instance_id) != visited_.end(); |
64 | } |
65 | |
66 | void set_done(Stmt *stmt) { |
67 | visited_.insert(stmt->instance_id); |
68 | } |
69 | |
70 | static std::size_t operand_hash(const Stmt *stmt) { |
71 | std::size_t hash_code{0}; |
72 | auto hash_type = |
73 | std::hash<std::type_index>{}(std::type_index(typeid(stmt))); |
74 | if (stmt->is<GlobalPtrStmt>() || stmt->is<LoopUniqueStmt>()) { |
75 | // special cases in common_statement_eliminable() |
76 | return hash_type; |
77 | } |
78 | auto op = stmt->get_operands(); |
79 | for (auto &x : op) { |
80 | if (x == nullptr) |
81 | continue; |
82 | // Hash the addresses of the operand pointers. |
83 | hash_code = |
84 | (hash_code * 33) ^ |
85 | (std::hash<unsigned long>{}(reinterpret_cast<unsigned long>(x))); |
86 | } |
87 | return hash_type ^ hash_code; |
88 | } |
89 | |
90 | static bool common_statement_eliminable(Stmt *this_stmt, Stmt *prev_stmt) { |
91 | // Is this_stmt eliminable given that prev_stmt appears before it and has |
92 | // the same type with it? |
93 | if (this_stmt->type() != prev_stmt->type()) |
94 | return false; |
95 | if (this_stmt->is<GlobalPtrStmt>()) { |
96 | auto this_ptr = this_stmt->as<GlobalPtrStmt>(); |
97 | auto prev_ptr = prev_stmt->as<GlobalPtrStmt>(); |
98 | return irpass::analysis::definitely_same_address(this_ptr, prev_ptr) && |
99 | (this_ptr->activate == prev_ptr->activate || prev_ptr->activate); |
100 | } |
101 | if (this_stmt->is<ExternalPtrStmt>()) { |
102 | auto this_ptr = this_stmt->as<ExternalPtrStmt>(); |
103 | auto prev_ptr = prev_stmt->as<ExternalPtrStmt>(); |
104 | return irpass::analysis::definitely_same_address(this_ptr, prev_ptr); |
105 | } |
106 | if (this_stmt->is<LoopUniqueStmt>()) { |
107 | auto this_loop_unique = this_stmt->as<LoopUniqueStmt>(); |
108 | auto prev_loop_unique = prev_stmt->as<LoopUniqueStmt>(); |
109 | if (irpass::analysis::same_value(this_loop_unique->input, |
110 | prev_loop_unique->input)) { |
111 | // Merge the "covers" information into prev_loop_unique. |
112 | // Notice that this_loop_unique->covers is corrupted here. |
113 | prev_loop_unique->covers.insert(this_loop_unique->covers.begin(), |
114 | this_loop_unique->covers.end()); |
115 | return true; |
116 | } |
117 | return false; |
118 | } |
119 | return irpass::analysis::same_statements(this_stmt, prev_stmt); |
120 | } |
121 | |
122 | void visit(Stmt *stmt) override { |
123 | if (!stmt->common_statement_eliminable()) |
124 | return; |
125 | // container_statement does not need to be CSE-ed |
126 | if (stmt->is_container_statement()) |
127 | return; |
128 | // Generic visitor for all CSE-able statements. |
129 | std::size_t hash_value = operand_hash(stmt); |
130 | if (is_done(stmt)) { |
131 | visible_stmts_.back()[hash_value].insert(stmt); |
132 | return; |
133 | } |
134 | for (auto &scope : visible_stmts_) { |
135 | for (auto &prev_stmt : scope[hash_value]) { |
136 | if (common_statement_eliminable(stmt, prev_stmt)) { |
137 | MarkUndone::run(&visited_, stmt); |
138 | stmt->replace_usages_with(prev_stmt); |
139 | modifier_.erase(stmt); |
140 | return; |
141 | } |
142 | } |
143 | } |
144 | visible_stmts_.back()[hash_value].insert(stmt); |
145 | set_done(stmt); |
146 | } |
147 | |
148 | void visit(Block *stmt_list) override { |
149 | visible_stmts_.emplace_back(); |
150 | for (auto &stmt : stmt_list->statements) { |
151 | stmt->accept(this); |
152 | } |
153 | visible_stmts_.pop_back(); |
154 | } |
155 | |
156 | void visit(IfStmt *if_stmt) override { |
157 | if (if_stmt->true_statements) { |
158 | if (if_stmt->true_statements->statements.empty()) { |
159 | if_stmt->set_true_statements(nullptr); |
160 | } |
161 | } |
162 | |
163 | if (if_stmt->false_statements) { |
164 | if (if_stmt->false_statements->statements.empty()) { |
165 | if_stmt->set_false_statements(nullptr); |
166 | } |
167 | } |
168 | |
169 | // Move common statements at the beginning or the end of both branches |
170 | // outside. |
171 | if (if_stmt->true_statements && if_stmt->false_statements) { |
172 | auto &true_clause = if_stmt->true_statements; |
173 | auto &false_clause = if_stmt->false_statements; |
174 | if (irpass::analysis::same_statements( |
175 | true_clause->statements[0].get(), |
176 | false_clause->statements[0].get())) { |
177 | // Directly modify this because it won't invalidate any iterators. |
178 | auto common_stmt = true_clause->extract(0); |
179 | irpass::replace_all_usages_with(false_clause.get(), |
180 | false_clause->statements[0].get(), |
181 | common_stmt.get()); |
182 | modifier_.insert_before(if_stmt, std::move(common_stmt)); |
183 | false_clause->erase(0); |
184 | } |
185 | if (!true_clause->statements.empty() && |
186 | !false_clause->statements.empty() && |
187 | irpass::analysis::same_statements( |
188 | true_clause->statements.back().get(), |
189 | false_clause->statements.back().get())) { |
190 | // Directly modify this because it won't invalidate any iterators. |
191 | auto common_stmt = true_clause->extract((int)true_clause->size() - 1); |
192 | irpass::replace_all_usages_with(false_clause.get(), |
193 | false_clause->statements.back().get(), |
194 | common_stmt.get()); |
195 | modifier_.insert_after(if_stmt, std::move(common_stmt)); |
196 | false_clause->erase((int)false_clause->size() - 1); |
197 | } |
198 | } |
199 | |
200 | if (if_stmt->true_statements) |
201 | if_stmt->true_statements->accept(this); |
202 | if (if_stmt->false_statements) |
203 | if_stmt->false_statements->accept(this); |
204 | } |
205 | |
206 | static bool run(IRNode *node) { |
207 | WholeKernelCSE eliminator; |
208 | bool modified = false; |
209 | while (true) { |
210 | node->accept(&eliminator); |
211 | if (eliminator.modifier_.modify_ir()) |
212 | modified = true; |
213 | else |
214 | break; |
215 | } |
216 | return modified; |
217 | } |
218 | }; |
219 | |
220 | namespace irpass { |
221 | bool whole_kernel_cse(IRNode *root) { |
222 | TI_AUTO_PROF; |
223 | return WholeKernelCSE::run(root); |
224 | } |
225 | } // namespace irpass |
226 | |
227 | } // namespace taichi::lang |
228 | |