1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5#include "taichi/system/profiler.h"
6
7namespace taichi::lang {
8
9// Unconditionally eliminate ContinueStmt's at **ends** of loops
10class UselessContinueEliminator : public IRVisitor {
11 public:
12 bool modified;
13
14 UselessContinueEliminator() : modified(false) {
15 allow_undefined_visitor = true;
16 }
17
18 void visit(ContinueStmt *stmt) override {
19 stmt->parent->erase(stmt);
20 modified = true;
21 }
22
23 void visit(IfStmt *if_stmt) override {
24 if (if_stmt->true_statements && if_stmt->true_statements->size())
25 if_stmt->true_statements->back()->accept(this);
26 if (if_stmt->false_statements && if_stmt->false_statements->size())
27 if_stmt->false_statements->back()->accept(this);
28 }
29};
30
31// Eliminate useless ContinueStmt, the statements after ContinueStmt and
32// unreachable if branches
33class UnreachableCodeEliminator : public BasicStmtVisitor {
34 public:
35 using BasicStmtVisitor::visit;
36 bool modified;
37 UselessContinueEliminator useless_continue_eliminator;
38 DelayedIRModifier modifier;
39
40 UnreachableCodeEliminator() : modified(false) {
41 allow_undefined_visitor = true;
42 }
43
44 void visit(Block *stmt_list) override {
45 const int block_size = stmt_list->size();
46 for (int i = 0; i < block_size - 1; i++) {
47 if (stmt_list->statements[i]->is<ContinueStmt>()) {
48 // Eliminate statements after ContinueStmt
49 for (int j = block_size - 1; j > i; j--)
50 stmt_list->erase(j);
51 modified = true;
52 break;
53 }
54 }
55 for (auto &stmt : stmt_list->statements)
56 stmt->accept(this);
57 }
58
59 void visit_loop(Block *body) {
60 if (body->size())
61 body->back()->accept(&useless_continue_eliminator);
62 body->accept(this);
63 }
64
65 void visit(RangeForStmt *stmt) override {
66 visit_loop(stmt->body.get());
67 }
68
69 void visit(StructForStmt *stmt) override {
70 visit_loop(stmt->body.get());
71 }
72
73 void visit(MeshForStmt *stmt) override {
74 visit_loop(stmt->body.get());
75 }
76
77 void visit(WhileStmt *stmt) override {
78 visit_loop(stmt->body.get());
79 }
80
81 void visit(OffloadedStmt *stmt) override {
82 if (stmt->tls_prologue)
83 stmt->tls_prologue->accept(this);
84
85 if (stmt->mesh_prologue)
86 stmt->mesh_prologue->accept(this);
87
88 if (stmt->bls_prologue)
89 stmt->bls_prologue->accept(this);
90
91 if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
92 stmt->task_type == OffloadedStmt::TaskType::mesh_for ||
93 stmt->task_type == OffloadedStmt::TaskType::struct_for)
94 visit_loop(stmt->body.get());
95 else if (stmt->body)
96 stmt->body->accept(this);
97
98 if (stmt->bls_epilogue)
99 stmt->bls_epilogue->accept(this);
100
101 if (stmt->tls_epilogue)
102 stmt->tls_epilogue->accept(this);
103 }
104
105 void visit(IfStmt *if_stmt) override {
106 if (if_stmt->cond->is<ConstStmt>()) {
107 if (if_stmt->cond->as<ConstStmt>()->val.equal_value(0)) {
108 // if (0)
109 if (if_stmt->false_statements) {
110 modifier.insert_before(
111 if_stmt,
112 VecStatement(std::move(if_stmt->false_statements->statements)));
113 }
114 modifier.erase(if_stmt);
115 modified = true;
116 return;
117 } else {
118 // if (1)
119 if (if_stmt->true_statements) {
120 modifier.insert_before(
121 if_stmt,
122 VecStatement(std::move(if_stmt->true_statements->statements)));
123 }
124 modifier.erase(if_stmt);
125 modified = true;
126 return;
127 }
128 }
129 if (if_stmt->true_statements)
130 if_stmt->true_statements->accept(this);
131 if (if_stmt->false_statements)
132 if_stmt->false_statements->accept(this);
133 }
134
135 static bool run(IRNode *node) {
136 bool modified = false;
137 while (true) {
138 UnreachableCodeEliminator eliminator;
139 node->accept(&eliminator);
140 eliminator.modifier.modify_ir();
141 if (eliminator.modified ||
142 eliminator.useless_continue_eliminator.modified) {
143 modified = true;
144 } else {
145 break;
146 }
147 }
148 return modified;
149 }
150};
151
152namespace irpass {
153bool unreachable_code_elimination(IRNode *root) {
154 TI_AUTO_PROF;
155 return UnreachableCodeEliminator::run(root);
156}
157} // namespace irpass
158
159} // namespace taichi::lang
160