1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/system/profiler.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | // Replace all usages statement A with a new statement B. |
10 | // Note that the original statement A is NOT replaced. |
11 | class StatementUsageReplace : public IRVisitor { |
12 | // The reason why we don't use BasicStmtVisitor is we don't want to go into |
13 | // FrontendForStmt. |
14 | public: |
15 | Stmt *old_stmt, *new_stmt; |
16 | |
17 | StatementUsageReplace(Stmt *old_stmt, Stmt *new_stmt) |
18 | : old_stmt(old_stmt), new_stmt(new_stmt) { |
19 | allow_undefined_visitor = true; |
20 | invoke_default_visitor = true; |
21 | } |
22 | |
23 | void visit(Stmt *stmt) override { |
24 | stmt->replace_operand_with(old_stmt, new_stmt); |
25 | } |
26 | |
27 | void visit(WhileStmt *stmt) override { |
28 | stmt->replace_operand_with(old_stmt, new_stmt); |
29 | stmt->body->accept(this); |
30 | } |
31 | |
32 | void visit(IfStmt *if_stmt) override { |
33 | if_stmt->replace_operand_with(old_stmt, new_stmt); |
34 | if (if_stmt->true_statements) |
35 | if_stmt->true_statements->accept(this); |
36 | if (if_stmt->false_statements) { |
37 | if_stmt->false_statements->accept(this); |
38 | } |
39 | } |
40 | |
41 | void visit(Block *stmt_list) override { |
42 | for (auto &stmt : stmt_list->statements) { |
43 | stmt->accept(this); |
44 | } |
45 | } |
46 | |
47 | void visit(RangeForStmt *stmt) override { |
48 | stmt->replace_operand_with(old_stmt, new_stmt); |
49 | stmt->body->accept(this); |
50 | } |
51 | |
52 | void visit(StructForStmt *stmt) override { |
53 | stmt->body->accept(this); |
54 | } |
55 | |
56 | void visit(MeshForStmt *stmt) override { |
57 | stmt->body->accept(this); |
58 | } |
59 | |
60 | void visit(OffloadedStmt *stmt) override { |
61 | stmt->all_blocks_accept(this); |
62 | } |
63 | |
64 | static void run(IRNode *root, Stmt *old_stmt, Stmt *new_stmt) { |
65 | StatementUsageReplace replacer(old_stmt, new_stmt); |
66 | if (root != nullptr) { |
67 | // If root is specified, simply traverse the root. |
68 | root->accept(&replacer); |
69 | return; |
70 | } |
71 | |
72 | // statements inside old_stmt->parent |
73 | TI_ASSERT(old_stmt->parent != nullptr); |
74 | old_stmt->parent->accept(&replacer); |
75 | auto current_block = old_stmt->parent->parent_block(); |
76 | |
77 | // statements outside old_stmt->parent: bottom-up |
78 | while (current_block != nullptr) { |
79 | for (auto &stmt : current_block->statements) { |
80 | stmt->replace_operand_with(old_stmt, new_stmt); |
81 | } |
82 | current_block = current_block->parent_block(); |
83 | } |
84 | } |
85 | }; |
86 | |
87 | namespace irpass { |
88 | |
89 | void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt) { |
90 | TI_AUTO_PROF; |
91 | StatementUsageReplace::run(root, old_stmt, new_stmt); |
92 | } |
93 | |
94 | } // namespace irpass |
95 | |
96 | } // namespace taichi::lang |
97 | |