1#include "taichi/ir/ir.h"
2#include "taichi/ir/analysis.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/visitors.h"
5#include <unordered_map>
6#include <unordered_set>
7
8namespace taichi::lang {
9
10// Compare if two IRNodes are equivalent.
11class IRNodeComparator : public IRVisitor {
12 private:
13 IRNode *other_node_;
14 // map the id from this node to the other node
15 std::unordered_map<int, int> id_map_;
16
17 bool recursively_check_;
18
19 // Compare if two IRNodes definitely have the same value instead.
20 // When this is true, it's weaker in the sense that we don't require the
21 // activate field in the GlobalPtrStmt to be the same, but stronger in the
22 // sense that we require the value to be the same (especially stronger in
23 // GlobalLoadStmt, RandStmt, etc.).
24 bool check_same_value_;
25
26 public:
27 bool same;
28
29 explicit IRNodeComparator(
30 IRNode *other_node,
31 const std::optional<std::unordered_map<int, int>> &id_map,
32 bool check_same_value)
33 : other_node_(other_node) {
34 allow_undefined_visitor = true;
35 invoke_default_visitor = true;
36 same = true;
37 if (id_map.has_value()) {
38 recursively_check_ = true;
39 this->id_map_ = id_map.value();
40 } else {
41 recursively_check_ = false;
42 }
43 check_same_value_ = check_same_value;
44 }
45
46 void map_id(int this_id, int other_id) {
47 auto it = id_map_.find(this_id);
48 if (it == id_map_.end()) {
49 id_map_[this_id] = other_id;
50 } else if (it->second != other_id) {
51 same = false;
52 }
53 }
54
55 void check_mapping(Stmt *this_stmt, Stmt *other_stmt) {
56 // get the corresponding id in the other node
57 // and check if it is other_stmt->id
58 auto it = id_map_.find(this_stmt->id);
59 if (it != id_map_.end()) {
60 if (it->second != other_stmt->id) {
61 same = false;
62 }
63 return;
64 }
65 if (!recursively_check_) {
66 // use identity mapping if not found
67 if (this_stmt->id != other_stmt->id) {
68 same = false;
69 }
70 id_map_[this_stmt->id] = other_stmt->id;
71 } else {
72 // recursively check them
73 IRNode *backup_other_node = other_node_;
74 other_node_ = other_stmt;
75 this_stmt->accept(this);
76 other_node_ = backup_other_node;
77 }
78 }
79
80 void visit(Block *stmt_list) override {
81 if (!other_node_->is<Block>()) {
82 same = false;
83 return;
84 }
85
86 auto other = other_node_->as<Block>();
87 if (stmt_list->size() != other->size()) {
88 same = false;
89 return;
90 }
91 for (int i = 0; i < (int)stmt_list->size(); i++) {
92 other_node_ = other->statements[i].get();
93 stmt_list->statements[i]->accept(this);
94 if (!same)
95 break;
96 }
97 other_node_ = other;
98 }
99
100 void basic_check(Stmt *stmt) {
101 // type check
102 if (typeid(*other_node_) != typeid(*stmt)) {
103 same = false;
104 return;
105 }
106 auto other = other_node_->as<Stmt>();
107 if (stmt == other) {
108 return;
109 }
110
111 // If two identical statements can have different values, return false.
112 // TODO: actually the condition should be "can stmt be an operand of
113 // another statement?"
114 const bool stmt_has_value = !stmt->is_container_statement();
115 // TODO: We want to know if two identical statements of the type same as
116 // stmt can have different values. In most cases, this property is the
117 // same as Stmt::common_statement_eliminable(). However, two identical
118 // GlobalPtrStmts cannot have different values, although
119 // GlobalPtrStmt::common_statement_eliminable() is false.
120
121 // ArgLoadStmt can have different type : grad or not grad.
122 if (stmt->is<ArgLoadStmt>()) {
123 // ArgLoadStmt can have different type : grad or not grad.
124 if (stmt->as<ArgLoadStmt>()->is_grad !=
125 other->as<ArgLoadStmt>()->is_grad) {
126 same = false;
127 return;
128 }
129 }
130 const bool identical_stmts_can_have_different_value =
131 stmt_has_value && !stmt->common_statement_eliminable() &&
132 !stmt->is<GlobalPtrStmt>();
133 // Note that we do not need to test !stmt2->common_statement_eliminable()
134 // because if this condition does not hold,
135 // same_value(stmt1, stmt2) returns false anyway.
136 if (check_same_value_ && identical_stmts_can_have_different_value) {
137 same = false;
138 return;
139 }
140
141 bool field_checked = false;
142 if (check_same_value_) {
143 if (stmt->is<GlobalPtrStmt>()) {
144 // Special case: we do not care about the "activate" field when checking
145 // whether two global pointers share the same value.
146 // And we cannot use irpass::analysis::definitely_same_address()
147 // directly because that function does not support id_map.
148
149 if (stmt->as<GlobalPtrStmt>()->snode->id !=
150 other->as<GlobalPtrStmt>()->snode->id) {
151 same = false;
152 return;
153 }
154 field_checked = true;
155 } else if (stmt->is<LoopUniqueStmt>()) {
156 // Special case: we do not care the "covers" field when checking
157 // whether two LoopUniqueStmts share the same value.
158 field_checked = true;
159 } else if (stmt->is<RangeAssumptionStmt>()) {
160 // Special case: we do not care the "low, high" fields when checking
161 // whether two RangeAssumptionStmts share the same value.
162 field_checked = true;
163 }
164 }
165 if (!field_checked) {
166 // field check
167 if (!stmt->field_manager.equal(other->field_manager)) {
168 same = false;
169 return;
170 }
171 }
172
173 bool operand_checked = false;
174 if (check_same_value_) {
175 if (stmt->is<RangeAssumptionStmt>()) {
176 // Special case: we do not care about the "base" operand when checking
177 // whether two RangeAssumptionStmts share the same value.
178 check_mapping(stmt->as<RangeAssumptionStmt>()->input,
179 other->as<RangeAssumptionStmt>()->input);
180 operand_checked = true;
181 }
182 }
183 if (!operand_checked) {
184 // operand check
185 if (stmt->num_operands() != other->num_operands()) {
186 same = false;
187 return;
188 }
189 for (int i = 0; i < stmt->num_operands(); i++) {
190 if ((stmt->operand(i) == nullptr) != (other->operand(i) == nullptr)) {
191 same = false;
192 return;
193 }
194 if (stmt->operand(i) == nullptr)
195 continue;
196 check_mapping(stmt->operand(i), other->operand(i));
197 }
198 }
199
200 map_id(stmt->id, other->id);
201 }
202
203 void visit(Stmt *stmt) override {
204 basic_check(stmt);
205 }
206
207 void visit(IfStmt *stmt) override {
208 basic_check(stmt);
209 if (!same)
210 return;
211 auto other = other_node_->as<IfStmt>();
212 if (stmt->true_statements) {
213 if (!other->true_statements) {
214 same = false;
215 return;
216 }
217 other_node_ = other->true_statements.get();
218 stmt->true_statements->accept(this);
219 other_node_ = other;
220 }
221 if (stmt->false_statements && same) {
222 if (!other->false_statements) {
223 same = false;
224 return;
225 }
226 other_node_ = other->false_statements.get();
227 stmt->false_statements->accept(this);
228 other_node_ = other;
229 }
230 }
231
232 void visit(WhileStmt *stmt) override {
233 basic_check(stmt);
234 if (!same)
235 return;
236 auto other = other_node_->as<WhileStmt>();
237 other_node_ = other->body.get();
238 stmt->body->accept(this);
239 other_node_ = other;
240 }
241
242 void visit(RangeForStmt *stmt) override {
243 basic_check(stmt);
244 if (!same)
245 return;
246 auto other = other_node_->as<RangeForStmt>();
247 other_node_ = other->body.get();
248 stmt->body->accept(this);
249 other_node_ = other;
250 }
251
252 void visit(StructForStmt *stmt) override {
253 basic_check(stmt);
254 if (!same)
255 return;
256 auto other = other_node_->as<StructForStmt>();
257 other_node_ = other->body.get();
258 stmt->body->accept(this);
259 other_node_ = other;
260 }
261
262 void visit(OffloadedStmt *stmt) override {
263 basic_check(stmt);
264 if (!same)
265 return;
266 auto other = other_node_->as<OffloadedStmt>();
267 if (stmt->has_body()) {
268 TI_ASSERT(stmt->body);
269 TI_ASSERT(other->body);
270 other_node_ = other->body.get();
271 stmt->body->accept(this);
272 other_node_ = other;
273 }
274 }
275
276 static bool run(IRNode *root1,
277 IRNode *root2,
278 const std::optional<std::unordered_map<int, int>> &id_map,
279 bool check_same_value) {
280 IRNodeComparator comparator(root2, id_map, check_same_value);
281 root1->accept(&comparator);
282 return comparator.same;
283 }
284};
285
286namespace irpass::analysis {
287bool same_statements(
288 IRNode *root1,
289 IRNode *root2,
290 const std::optional<std::unordered_map<int, int>> &id_map) {
291 // When id_map is std::nullopt by default, this function tests if
292 // root1 and root2 are the same, i.e., have the same type,
293 // the same operands and the same fields.
294 // If root1 and root2 are container statements or statement blocks,
295 // this function traverses the contents correspondingly.
296 // Two operands are considered the same if they have the same id
297 // and do not belong to either root, or they belong to root1 and root2
298 // at the same position in the roots.
299 //
300 // For example, same_statements(block1, block2, std::nullopt) is true:
301 // <i32> $1 = ...
302 // block1 : {
303 // <i32> $2 = const [1]
304 // <i32> $3 = add $1 $2
305 // }
306 // block2 : {
307 // <i32> $4 = const [1]
308 // <i32> $5 = add $1 $4
309 // }
310 //
311 // If id_map is not std::nullopt, this function also recursively
312 // check the operands until ids in the id_map are reached.
313 // id_map is an id map from root1 to root2.
314 //
315 // In the above example, same_statements($3, $5, std::nullopt) is false
316 // but same_statements($3, $5, (an empty map)) is true.
317 //
318 // In the following example, same_statements($3, $6, id_map) is true
319 // iff id_map[1] == 4 && id_map[2] == 5:
320 // <i32> $3 = add $1 $2
321 // <i32> $6 = add $4 $5
322 //
323 // Null pointers as IRNodes are defined to be NOT the same as any other
324 // IRNode, except for another nullptr IRNode.
325 if (root1 == root2)
326 return true;
327 if (!root1 || !root2)
328 return false;
329 return IRNodeComparator::run(root1, root2, id_map,
330 /*check_same_value=*/false);
331}
332
333bool same_value(Stmt *stmt1,
334 Stmt *stmt2,
335 const std::optional<std::unordered_map<int, int>> &id_map) {
336 // Test if two statements definitely have the same value.
337 if (stmt1 == stmt2)
338 return true;
339 if (!stmt1 || !stmt2)
340 return false;
341 return IRNodeComparator::run(stmt1, stmt2, id_map,
342 /*check_same_value=*/true);
343}
344} // namespace irpass::analysis
345
346} // namespace taichi::lang
347