1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5#include "taichi/system/profiler.h"
6
7namespace taichi::lang {
8
9// Replace all usages statement A with a new statement B.
10// Note that the original statement A is NOT replaced.
11class StatementUsageReplace : public IRVisitor {
12 // The reason why we don't use BasicStmtVisitor is we don't want to go into
13 // FrontendForStmt.
14 public:
15 Stmt *old_stmt, *new_stmt;
16
17 StatementUsageReplace(Stmt *old_stmt, Stmt *new_stmt)
18 : old_stmt(old_stmt), new_stmt(new_stmt) {
19 allow_undefined_visitor = true;
20 invoke_default_visitor = true;
21 }
22
23 void visit(Stmt *stmt) override {
24 stmt->replace_operand_with(old_stmt, new_stmt);
25 }
26
27 void visit(WhileStmt *stmt) override {
28 stmt->replace_operand_with(old_stmt, new_stmt);
29 stmt->body->accept(this);
30 }
31
32 void visit(IfStmt *if_stmt) override {
33 if_stmt->replace_operand_with(old_stmt, new_stmt);
34 if (if_stmt->true_statements)
35 if_stmt->true_statements->accept(this);
36 if (if_stmt->false_statements) {
37 if_stmt->false_statements->accept(this);
38 }
39 }
40
41 void visit(Block *stmt_list) override {
42 for (auto &stmt : stmt_list->statements) {
43 stmt->accept(this);
44 }
45 }
46
47 void visit(RangeForStmt *stmt) override {
48 stmt->replace_operand_with(old_stmt, new_stmt);
49 stmt->body->accept(this);
50 }
51
52 void visit(StructForStmt *stmt) override {
53 stmt->body->accept(this);
54 }
55
56 void visit(MeshForStmt *stmt) override {
57 stmt->body->accept(this);
58 }
59
60 void visit(OffloadedStmt *stmt) override {
61 stmt->all_blocks_accept(this);
62 }
63
64 static void run(IRNode *root, Stmt *old_stmt, Stmt *new_stmt) {
65 StatementUsageReplace replacer(old_stmt, new_stmt);
66 if (root != nullptr) {
67 // If root is specified, simply traverse the root.
68 root->accept(&replacer);
69 return;
70 }
71
72 // statements inside old_stmt->parent
73 TI_ASSERT(old_stmt->parent != nullptr);
74 old_stmt->parent->accept(&replacer);
75 auto current_block = old_stmt->parent->parent_block();
76
77 // statements outside old_stmt->parent: bottom-up
78 while (current_block != nullptr) {
79 for (auto &stmt : current_block->statements) {
80 stmt->replace_operand_with(old_stmt, new_stmt);
81 }
82 current_block = current_block->parent_block();
83 }
84 }
85};
86
87namespace irpass {
88
89void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt) {
90 TI_AUTO_PROF;
91 StatementUsageReplace::run(root, old_stmt, new_stmt);
92}
93
94} // namespace irpass
95
96} // namespace taichi::lang
97