1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include <unordered_map> |
6 | #include <unordered_set> |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | // Compare if two IRNodes are equivalent. |
11 | class IRNodeComparator : public IRVisitor { |
12 | private: |
13 | IRNode *other_node_; |
14 | // map the id from this node to the other node |
15 | std::unordered_map<int, int> id_map_; |
16 | |
17 | bool recursively_check_; |
18 | |
19 | // Compare if two IRNodes definitely have the same value instead. |
20 | // When this is true, it's weaker in the sense that we don't require the |
21 | // activate field in the GlobalPtrStmt to be the same, but stronger in the |
22 | // sense that we require the value to be the same (especially stronger in |
23 | // GlobalLoadStmt, RandStmt, etc.). |
24 | bool check_same_value_; |
25 | |
26 | public: |
27 | bool same; |
28 | |
29 | explicit IRNodeComparator( |
30 | IRNode *other_node, |
31 | const std::optional<std::unordered_map<int, int>> &id_map, |
32 | bool check_same_value) |
33 | : other_node_(other_node) { |
34 | allow_undefined_visitor = true; |
35 | invoke_default_visitor = true; |
36 | same = true; |
37 | if (id_map.has_value()) { |
38 | recursively_check_ = true; |
39 | this->id_map_ = id_map.value(); |
40 | } else { |
41 | recursively_check_ = false; |
42 | } |
43 | check_same_value_ = check_same_value; |
44 | } |
45 | |
46 | void map_id(int this_id, int other_id) { |
47 | auto it = id_map_.find(this_id); |
48 | if (it == id_map_.end()) { |
49 | id_map_[this_id] = other_id; |
50 | } else if (it->second != other_id) { |
51 | same = false; |
52 | } |
53 | } |
54 | |
55 | void check_mapping(Stmt *this_stmt, Stmt *other_stmt) { |
56 | // get the corresponding id in the other node |
57 | // and check if it is other_stmt->id |
58 | auto it = id_map_.find(this_stmt->id); |
59 | if (it != id_map_.end()) { |
60 | if (it->second != other_stmt->id) { |
61 | same = false; |
62 | } |
63 | return; |
64 | } |
65 | if (!recursively_check_) { |
66 | // use identity mapping if not found |
67 | if (this_stmt->id != other_stmt->id) { |
68 | same = false; |
69 | } |
70 | id_map_[this_stmt->id] = other_stmt->id; |
71 | } else { |
72 | // recursively check them |
73 | IRNode *backup_other_node = other_node_; |
74 | other_node_ = other_stmt; |
75 | this_stmt->accept(this); |
76 | other_node_ = backup_other_node; |
77 | } |
78 | } |
79 | |
80 | void visit(Block *stmt_list) override { |
81 | if (!other_node_->is<Block>()) { |
82 | same = false; |
83 | return; |
84 | } |
85 | |
86 | auto other = other_node_->as<Block>(); |
87 | if (stmt_list->size() != other->size()) { |
88 | same = false; |
89 | return; |
90 | } |
91 | for (int i = 0; i < (int)stmt_list->size(); i++) { |
92 | other_node_ = other->statements[i].get(); |
93 | stmt_list->statements[i]->accept(this); |
94 | if (!same) |
95 | break; |
96 | } |
97 | other_node_ = other; |
98 | } |
99 | |
100 | void basic_check(Stmt *stmt) { |
101 | // type check |
102 | if (typeid(*other_node_) != typeid(*stmt)) { |
103 | same = false; |
104 | return; |
105 | } |
106 | auto other = other_node_->as<Stmt>(); |
107 | if (stmt == other) { |
108 | return; |
109 | } |
110 | |
111 | // If two identical statements can have different values, return false. |
112 | // TODO: actually the condition should be "can stmt be an operand of |
113 | // another statement?" |
114 | const bool stmt_has_value = !stmt->is_container_statement(); |
115 | // TODO: We want to know if two identical statements of the type same as |
116 | // stmt can have different values. In most cases, this property is the |
117 | // same as Stmt::common_statement_eliminable(). However, two identical |
118 | // GlobalPtrStmts cannot have different values, although |
119 | // GlobalPtrStmt::common_statement_eliminable() is false. |
120 | |
121 | // ArgLoadStmt can have different type : grad or not grad. |
122 | if (stmt->is<ArgLoadStmt>()) { |
123 | // ArgLoadStmt can have different type : grad or not grad. |
124 | if (stmt->as<ArgLoadStmt>()->is_grad != |
125 | other->as<ArgLoadStmt>()->is_grad) { |
126 | same = false; |
127 | return; |
128 | } |
129 | } |
130 | const bool identical_stmts_can_have_different_value = |
131 | stmt_has_value && !stmt->common_statement_eliminable() && |
132 | !stmt->is<GlobalPtrStmt>(); |
133 | // Note that we do not need to test !stmt2->common_statement_eliminable() |
134 | // because if this condition does not hold, |
135 | // same_value(stmt1, stmt2) returns false anyway. |
136 | if (check_same_value_ && identical_stmts_can_have_different_value) { |
137 | same = false; |
138 | return; |
139 | } |
140 | |
141 | bool field_checked = false; |
142 | if (check_same_value_) { |
143 | if (stmt->is<GlobalPtrStmt>()) { |
144 | // Special case: we do not care about the "activate" field when checking |
145 | // whether two global pointers share the same value. |
146 | // And we cannot use irpass::analysis::definitely_same_address() |
147 | // directly because that function does not support id_map. |
148 | |
149 | if (stmt->as<GlobalPtrStmt>()->snode->id != |
150 | other->as<GlobalPtrStmt>()->snode->id) { |
151 | same = false; |
152 | return; |
153 | } |
154 | field_checked = true; |
155 | } else if (stmt->is<LoopUniqueStmt>()) { |
156 | // Special case: we do not care the "covers" field when checking |
157 | // whether two LoopUniqueStmts share the same value. |
158 | field_checked = true; |
159 | } else if (stmt->is<RangeAssumptionStmt>()) { |
160 | // Special case: we do not care the "low, high" fields when checking |
161 | // whether two RangeAssumptionStmts share the same value. |
162 | field_checked = true; |
163 | } |
164 | } |
165 | if (!field_checked) { |
166 | // field check |
167 | if (!stmt->field_manager.equal(other->field_manager)) { |
168 | same = false; |
169 | return; |
170 | } |
171 | } |
172 | |
173 | bool operand_checked = false; |
174 | if (check_same_value_) { |
175 | if (stmt->is<RangeAssumptionStmt>()) { |
176 | // Special case: we do not care about the "base" operand when checking |
177 | // whether two RangeAssumptionStmts share the same value. |
178 | check_mapping(stmt->as<RangeAssumptionStmt>()->input, |
179 | other->as<RangeAssumptionStmt>()->input); |
180 | operand_checked = true; |
181 | } |
182 | } |
183 | if (!operand_checked) { |
184 | // operand check |
185 | if (stmt->num_operands() != other->num_operands()) { |
186 | same = false; |
187 | return; |
188 | } |
189 | for (int i = 0; i < stmt->num_operands(); i++) { |
190 | if ((stmt->operand(i) == nullptr) != (other->operand(i) == nullptr)) { |
191 | same = false; |
192 | return; |
193 | } |
194 | if (stmt->operand(i) == nullptr) |
195 | continue; |
196 | check_mapping(stmt->operand(i), other->operand(i)); |
197 | } |
198 | } |
199 | |
200 | map_id(stmt->id, other->id); |
201 | } |
202 | |
203 | void visit(Stmt *stmt) override { |
204 | basic_check(stmt); |
205 | } |
206 | |
207 | void visit(IfStmt *stmt) override { |
208 | basic_check(stmt); |
209 | if (!same) |
210 | return; |
211 | auto other = other_node_->as<IfStmt>(); |
212 | if (stmt->true_statements) { |
213 | if (!other->true_statements) { |
214 | same = false; |
215 | return; |
216 | } |
217 | other_node_ = other->true_statements.get(); |
218 | stmt->true_statements->accept(this); |
219 | other_node_ = other; |
220 | } |
221 | if (stmt->false_statements && same) { |
222 | if (!other->false_statements) { |
223 | same = false; |
224 | return; |
225 | } |
226 | other_node_ = other->false_statements.get(); |
227 | stmt->false_statements->accept(this); |
228 | other_node_ = other; |
229 | } |
230 | } |
231 | |
232 | void visit(WhileStmt *stmt) override { |
233 | basic_check(stmt); |
234 | if (!same) |
235 | return; |
236 | auto other = other_node_->as<WhileStmt>(); |
237 | other_node_ = other->body.get(); |
238 | stmt->body->accept(this); |
239 | other_node_ = other; |
240 | } |
241 | |
242 | void visit(RangeForStmt *stmt) override { |
243 | basic_check(stmt); |
244 | if (!same) |
245 | return; |
246 | auto other = other_node_->as<RangeForStmt>(); |
247 | other_node_ = other->body.get(); |
248 | stmt->body->accept(this); |
249 | other_node_ = other; |
250 | } |
251 | |
252 | void visit(StructForStmt *stmt) override { |
253 | basic_check(stmt); |
254 | if (!same) |
255 | return; |
256 | auto other = other_node_->as<StructForStmt>(); |
257 | other_node_ = other->body.get(); |
258 | stmt->body->accept(this); |
259 | other_node_ = other; |
260 | } |
261 | |
262 | void visit(OffloadedStmt *stmt) override { |
263 | basic_check(stmt); |
264 | if (!same) |
265 | return; |
266 | auto other = other_node_->as<OffloadedStmt>(); |
267 | if (stmt->has_body()) { |
268 | TI_ASSERT(stmt->body); |
269 | TI_ASSERT(other->body); |
270 | other_node_ = other->body.get(); |
271 | stmt->body->accept(this); |
272 | other_node_ = other; |
273 | } |
274 | } |
275 | |
276 | static bool run(IRNode *root1, |
277 | IRNode *root2, |
278 | const std::optional<std::unordered_map<int, int>> &id_map, |
279 | bool check_same_value) { |
280 | IRNodeComparator comparator(root2, id_map, check_same_value); |
281 | root1->accept(&comparator); |
282 | return comparator.same; |
283 | } |
284 | }; |
285 | |
286 | namespace irpass::analysis { |
287 | bool same_statements( |
288 | IRNode *root1, |
289 | IRNode *root2, |
290 | const std::optional<std::unordered_map<int, int>> &id_map) { |
291 | // When id_map is std::nullopt by default, this function tests if |
292 | // root1 and root2 are the same, i.e., have the same type, |
293 | // the same operands and the same fields. |
294 | // If root1 and root2 are container statements or statement blocks, |
295 | // this function traverses the contents correspondingly. |
296 | // Two operands are considered the same if they have the same id |
297 | // and do not belong to either root, or they belong to root1 and root2 |
298 | // at the same position in the roots. |
299 | // |
300 | // For example, same_statements(block1, block2, std::nullopt) is true: |
301 | // <i32> $1 = ... |
302 | // block1 : { |
303 | // <i32> $2 = const [1] |
304 | // <i32> $3 = add $1 $2 |
305 | // } |
306 | // block2 : { |
307 | // <i32> $4 = const [1] |
308 | // <i32> $5 = add $1 $4 |
309 | // } |
310 | // |
311 | // If id_map is not std::nullopt, this function also recursively |
312 | // check the operands until ids in the id_map are reached. |
313 | // id_map is an id map from root1 to root2. |
314 | // |
315 | // In the above example, same_statements($3, $5, std::nullopt) is false |
316 | // but same_statements($3, $5, (an empty map)) is true. |
317 | // |
318 | // In the following example, same_statements($3, $6, id_map) is true |
319 | // iff id_map[1] == 4 && id_map[2] == 5: |
320 | // <i32> $3 = add $1 $2 |
321 | // <i32> $6 = add $4 $5 |
322 | // |
323 | // Null pointers as IRNodes are defined to be NOT the same as any other |
324 | // IRNode, except for another nullptr IRNode. |
325 | if (root1 == root2) |
326 | return true; |
327 | if (!root1 || !root2) |
328 | return false; |
329 | return IRNodeComparator::run(root1, root2, id_map, |
330 | /*check_same_value=*/false); |
331 | } |
332 | |
333 | bool same_value(Stmt *stmt1, |
334 | Stmt *stmt2, |
335 | const std::optional<std::unordered_map<int, int>> &id_map) { |
336 | // Test if two statements definitely have the same value. |
337 | if (stmt1 == stmt2) |
338 | return true; |
339 | if (!stmt1 || !stmt2) |
340 | return false; |
341 | return IRNodeComparator::run(stmt1, stmt2, id_map, |
342 | /*check_same_value=*/true); |
343 | } |
344 | } // namespace irpass::analysis |
345 | |
346 | } // namespace taichi::lang |
347 | |