1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/system/profiler.h" |
6 | |
7 | #include <stack> |
8 | |
9 | namespace taichi::lang { |
10 | |
11 | class LoopInvariantDetector : public BasicStmtVisitor { |
12 | public: |
13 | using BasicStmtVisitor::visit; |
14 | |
15 | std::vector<Block *> loop_blocks; |
16 | |
17 | const CompileConfig &config; |
18 | |
19 | explicit LoopInvariantDetector(const CompileConfig &config) : config(config) { |
20 | allow_undefined_visitor = true; |
21 | } |
22 | |
23 | bool is_operand_loop_invariant_impl(Stmt *operand, |
24 | Block *current_scope, |
25 | Block *loop_block = nullptr) { |
26 | if (!loop_block) { |
27 | loop_block = loop_blocks.back(); |
28 | } |
29 | if (operand->parent == current_scope) { |
30 | // This statement has an operand that is in the current scope, |
31 | // so it can not be moved out of the scope. |
32 | return false; |
33 | } |
34 | if (current_scope != loop_block) { |
35 | // If we enable moving code from a nested if block, we need to check |
36 | // visibility. Example: |
37 | // for i in range(10): |
38 | // a = x[0] |
39 | // if b: |
40 | // c = a + 1 |
41 | // Since we are moving statements outside the closest for scope, |
42 | // We need to check the scope of the operand |
43 | Stmt *operand_parent = operand; |
44 | while (operand_parent->parent) { |
45 | operand_parent = operand_parent->parent->parent_stmt; |
46 | if (!operand_parent) |
47 | break; |
48 | // If the one of the current_scope of the operand is the top loop |
49 | // scope Then it will not be visible if we move it outside the top |
50 | // loop scope |
51 | if (operand_parent == loop_block->parent_stmt) { |
52 | return false; |
53 | } |
54 | } |
55 | } |
56 | return true; |
57 | } |
58 | |
59 | bool is_operand_loop_invariant(Stmt *operand, |
60 | Block *current_scope, |
61 | int depth = -1) { |
62 | if (depth == -1) { |
63 | depth = loop_blocks.size() - 1; |
64 | } |
65 | if (depth <= 0) |
66 | return false; |
67 | return is_operand_loop_invariant_impl(operand, current_scope, |
68 | loop_blocks[depth]); |
69 | } |
70 | |
71 | bool is_loop_invariant(Stmt *stmt, Block *current_scope) { |
72 | if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if && |
73 | current_scope != loop_blocks.back())) |
74 | return false; |
75 | |
76 | bool is_invariant = true; |
77 | |
78 | for (Stmt *operand : stmt->get_operands()) { |
79 | if (operand == nullptr) |
80 | continue; |
81 | is_invariant &= is_operand_loop_invariant_impl(operand, current_scope); |
82 | } |
83 | |
84 | return is_invariant; |
85 | } |
86 | |
87 | Stmt *get_loop_stmt(int depth) { |
88 | return loop_blocks[depth]->parent_stmt; |
89 | } |
90 | |
91 | Stmt *current_loop_stmt() { |
92 | return get_loop_stmt(loop_blocks.size() - 1); |
93 | } |
94 | void visit(Block *stmt_list) override { |
95 | for (auto &stmt : stmt_list->statements) |
96 | stmt->accept(this); |
97 | } |
98 | |
99 | virtual void visit_loop(Block *body) { |
100 | loop_blocks.push_back(body); |
101 | |
102 | body->accept(this); |
103 | |
104 | loop_blocks.pop_back(); |
105 | } |
106 | |
107 | void visit(RangeForStmt *stmt) override { |
108 | visit_loop(stmt->body.get()); |
109 | } |
110 | |
111 | void visit(StructForStmt *stmt) override { |
112 | visit_loop(stmt->body.get()); |
113 | } |
114 | |
115 | void visit(MeshForStmt *stmt) override { |
116 | visit_loop(stmt->body.get()); |
117 | } |
118 | |
119 | void visit(WhileStmt *stmt) override { |
120 | visit_loop(stmt->body.get()); |
121 | } |
122 | |
123 | void visit(OffloadedStmt *stmt) override { |
124 | if (stmt->tls_prologue) |
125 | stmt->tls_prologue->accept(this); |
126 | |
127 | if (stmt->mesh_prologue) |
128 | stmt->mesh_prologue->accept(this); |
129 | |
130 | if (stmt->bls_prologue) |
131 | stmt->bls_prologue->accept(this); |
132 | |
133 | if (stmt->body) { |
134 | if (stmt->task_type == OffloadedStmt::TaskType::range_for || |
135 | stmt->task_type == OffloadedTaskType::mesh_for || |
136 | stmt->task_type == OffloadedStmt::TaskType::struct_for) |
137 | visit_loop(stmt->body.get()); |
138 | else |
139 | stmt->body->accept(this); |
140 | } |
141 | |
142 | if (stmt->bls_epilogue) |
143 | stmt->bls_epilogue->accept(this); |
144 | |
145 | if (stmt->tls_epilogue) |
146 | stmt->tls_epilogue->accept(this); |
147 | } |
148 | }; |
149 | |
150 | } // namespace taichi::lang |
151 | |