1 | // Dead Instruction Elimination |
---|---|
2 | |
3 | #include "taichi/ir/ir.h" |
4 | #include "taichi/ir/statements.h" |
5 | #include "taichi/ir/transforms.h" |
6 | #include "taichi/ir/visitors.h" |
7 | #include "taichi/system/profiler.h" |
8 | |
9 | #include <unordered_set> |
10 | |
11 | namespace taichi::lang { |
12 | |
13 | // Dead Instruction Elimination |
14 | class DIE : public IRVisitor { |
15 | public: |
16 | std::unordered_set<int> used; |
17 | int phase; // 0: mark usage 1: eliminate |
18 | DelayedIRModifier modifier; |
19 | bool modified_ir; |
20 | |
21 | explicit DIE(IRNode *node) { |
22 | allow_undefined_visitor = true; |
23 | invoke_default_visitor = true; |
24 | modified_ir = false; |
25 | while (true) { |
26 | bool modified = false; |
27 | phase = 0; |
28 | used.clear(); |
29 | node->accept(this); |
30 | phase = 1; |
31 | while (true) { |
32 | node->accept(this); |
33 | if (modifier.modify_ir()) { |
34 | modified = true; |
35 | modified_ir = true; |
36 | continue; |
37 | } |
38 | break; |
39 | } |
40 | if (!modified) |
41 | break; |
42 | } |
43 | } |
44 | |
45 | void register_usage(Stmt *stmt) { |
46 | for (auto op : stmt->get_operands()) { |
47 | if (op) { // might be nullptr |
48 | if (used.find(op->instance_id) == used.end()) { |
49 | used.insert(op->instance_id); |
50 | } |
51 | } |
52 | } |
53 | } |
54 | |
55 | void visit(Stmt *stmt) override { |
56 | TI_ASSERT(!stmt->erased); |
57 | if (phase == 0) { |
58 | register_usage(stmt); |
59 | } else { |
60 | if (stmt->dead_instruction_eliminable() && |
61 | used.find(stmt->instance_id) == used.end()) { |
62 | modifier.erase(stmt); |
63 | } |
64 | } |
65 | } |
66 | |
67 | void visit(Block *stmt_list) override { |
68 | for (auto &stmt : stmt_list->statements) { |
69 | stmt->accept(this); |
70 | } |
71 | } |
72 | |
73 | void visit(IfStmt *if_stmt) override { |
74 | register_usage(if_stmt); |
75 | if (if_stmt->true_statements) |
76 | if_stmt->true_statements->accept(this); |
77 | if (if_stmt->false_statements) { |
78 | if_stmt->false_statements->accept(this); |
79 | } |
80 | } |
81 | |
82 | void visit(WhileStmt *stmt) override { |
83 | register_usage(stmt); |
84 | stmt->body->accept(this); |
85 | } |
86 | |
87 | void visit(RangeForStmt *for_stmt) override { |
88 | register_usage(for_stmt); |
89 | for_stmt->body->accept(this); |
90 | } |
91 | |
92 | void visit(StructForStmt *for_stmt) override { |
93 | register_usage(for_stmt); |
94 | for_stmt->body->accept(this); |
95 | } |
96 | |
97 | void visit(MeshForStmt *for_stmt) override { |
98 | register_usage(for_stmt); |
99 | for_stmt->body->accept(this); |
100 | } |
101 | |
102 | void visit(OffloadedStmt *stmt) override { |
103 | // TODO: A hack to make sure end_stmt is registered. |
104 | // Ideally end_stmt should be its own Block instead. |
105 | if (stmt->end_stmt && |
106 | used.find(stmt->end_stmt->instance_id) == used.end()) { |
107 | used.insert(stmt->end_stmt->instance_id); |
108 | } |
109 | stmt->all_blocks_accept(this, true); |
110 | } |
111 | }; |
112 | |
113 | namespace irpass { |
114 | |
115 | bool die(IRNode *root) { |
116 | TI_AUTO_PROF; |
117 | DIE instance(root); |
118 | return instance.modified_ir; |
119 | } |
120 | |
121 | } // namespace irpass |
122 | |
123 | } // namespace taichi::lang |
124 |