1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/system/profiler.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | // Unconditionally eliminate ContinueStmt's at **ends** of loops |
10 | class UselessContinueEliminator : public IRVisitor { |
11 | public: |
12 | bool modified; |
13 | |
14 | UselessContinueEliminator() : modified(false) { |
15 | allow_undefined_visitor = true; |
16 | } |
17 | |
18 | void visit(ContinueStmt *stmt) override { |
19 | stmt->parent->erase(stmt); |
20 | modified = true; |
21 | } |
22 | |
23 | void visit(IfStmt *if_stmt) override { |
24 | if (if_stmt->true_statements && if_stmt->true_statements->size()) |
25 | if_stmt->true_statements->back()->accept(this); |
26 | if (if_stmt->false_statements && if_stmt->false_statements->size()) |
27 | if_stmt->false_statements->back()->accept(this); |
28 | } |
29 | }; |
30 | |
31 | // Eliminate useless ContinueStmt, the statements after ContinueStmt and |
32 | // unreachable if branches |
33 | class UnreachableCodeEliminator : public BasicStmtVisitor { |
34 | public: |
35 | using BasicStmtVisitor::visit; |
36 | bool modified; |
37 | UselessContinueEliminator useless_continue_eliminator; |
38 | DelayedIRModifier modifier; |
39 | |
40 | UnreachableCodeEliminator() : modified(false) { |
41 | allow_undefined_visitor = true; |
42 | } |
43 | |
44 | void visit(Block *stmt_list) override { |
45 | const int block_size = stmt_list->size(); |
46 | for (int i = 0; i < block_size - 1; i++) { |
47 | if (stmt_list->statements[i]->is<ContinueStmt>()) { |
48 | // Eliminate statements after ContinueStmt |
49 | for (int j = block_size - 1; j > i; j--) |
50 | stmt_list->erase(j); |
51 | modified = true; |
52 | break; |
53 | } |
54 | } |
55 | for (auto &stmt : stmt_list->statements) |
56 | stmt->accept(this); |
57 | } |
58 | |
59 | void visit_loop(Block *body) { |
60 | if (body->size()) |
61 | body->back()->accept(&useless_continue_eliminator); |
62 | body->accept(this); |
63 | } |
64 | |
65 | void visit(RangeForStmt *stmt) override { |
66 | visit_loop(stmt->body.get()); |
67 | } |
68 | |
69 | void visit(StructForStmt *stmt) override { |
70 | visit_loop(stmt->body.get()); |
71 | } |
72 | |
73 | void visit(MeshForStmt *stmt) override { |
74 | visit_loop(stmt->body.get()); |
75 | } |
76 | |
77 | void visit(WhileStmt *stmt) override { |
78 | visit_loop(stmt->body.get()); |
79 | } |
80 | |
81 | void visit(OffloadedStmt *stmt) override { |
82 | if (stmt->tls_prologue) |
83 | stmt->tls_prologue->accept(this); |
84 | |
85 | if (stmt->mesh_prologue) |
86 | stmt->mesh_prologue->accept(this); |
87 | |
88 | if (stmt->bls_prologue) |
89 | stmt->bls_prologue->accept(this); |
90 | |
91 | if (stmt->task_type == OffloadedStmt::TaskType::range_for || |
92 | stmt->task_type == OffloadedStmt::TaskType::mesh_for || |
93 | stmt->task_type == OffloadedStmt::TaskType::struct_for) |
94 | visit_loop(stmt->body.get()); |
95 | else if (stmt->body) |
96 | stmt->body->accept(this); |
97 | |
98 | if (stmt->bls_epilogue) |
99 | stmt->bls_epilogue->accept(this); |
100 | |
101 | if (stmt->tls_epilogue) |
102 | stmt->tls_epilogue->accept(this); |
103 | } |
104 | |
105 | void visit(IfStmt *if_stmt) override { |
106 | if (if_stmt->cond->is<ConstStmt>()) { |
107 | if (if_stmt->cond->as<ConstStmt>()->val.equal_value(0)) { |
108 | // if (0) |
109 | if (if_stmt->false_statements) { |
110 | modifier.insert_before( |
111 | if_stmt, |
112 | VecStatement(std::move(if_stmt->false_statements->statements))); |
113 | } |
114 | modifier.erase(if_stmt); |
115 | modified = true; |
116 | return; |
117 | } else { |
118 | // if (1) |
119 | if (if_stmt->true_statements) { |
120 | modifier.insert_before( |
121 | if_stmt, |
122 | VecStatement(std::move(if_stmt->true_statements->statements))); |
123 | } |
124 | modifier.erase(if_stmt); |
125 | modified = true; |
126 | return; |
127 | } |
128 | } |
129 | if (if_stmt->true_statements) |
130 | if_stmt->true_statements->accept(this); |
131 | if (if_stmt->false_statements) |
132 | if_stmt->false_statements->accept(this); |
133 | } |
134 | |
135 | static bool run(IRNode *node) { |
136 | bool modified = false; |
137 | while (true) { |
138 | UnreachableCodeEliminator eliminator; |
139 | node->accept(&eliminator); |
140 | eliminator.modifier.modify_ir(); |
141 | if (eliminator.modified || |
142 | eliminator.useless_continue_eliminator.modified) { |
143 | modified = true; |
144 | } else { |
145 | break; |
146 | } |
147 | } |
148 | return modified; |
149 | } |
150 | }; |
151 | |
152 | namespace irpass { |
153 | bool unreachable_code_elimination(IRNode *root) { |
154 | TI_AUTO_PROF; |
155 | return UnreachableCodeEliminator::run(root); |
156 | } |
157 | } // namespace irpass |
158 | |
159 | } // namespace taichi::lang |
160 |