1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5#include "taichi/system/profiler.h"
6
7#include <stack>
8
9namespace taichi::lang {
10
11class LoopInvariantDetector : public BasicStmtVisitor {
12 public:
13 using BasicStmtVisitor::visit;
14
15 std::vector<Block *> loop_blocks;
16
17 const CompileConfig &config;
18
19 explicit LoopInvariantDetector(const CompileConfig &config) : config(config) {
20 allow_undefined_visitor = true;
21 }
22
23 bool is_operand_loop_invariant_impl(Stmt *operand,
24 Block *current_scope,
25 Block *loop_block = nullptr) {
26 if (!loop_block) {
27 loop_block = loop_blocks.back();
28 }
29 if (operand->parent == current_scope) {
30 // This statement has an operand that is in the current scope,
31 // so it can not be moved out of the scope.
32 return false;
33 }
34 if (current_scope != loop_block) {
35 // If we enable moving code from a nested if block, we need to check
36 // visibility. Example:
37 // for i in range(10):
38 // a = x[0]
39 // if b:
40 // c = a + 1
41 // Since we are moving statements outside the closest for scope,
42 // We need to check the scope of the operand
43 Stmt *operand_parent = operand;
44 while (operand_parent->parent) {
45 operand_parent = operand_parent->parent->parent_stmt;
46 if (!operand_parent)
47 break;
48 // If the one of the current_scope of the operand is the top loop
49 // scope Then it will not be visible if we move it outside the top
50 // loop scope
51 if (operand_parent == loop_block->parent_stmt) {
52 return false;
53 }
54 }
55 }
56 return true;
57 }
58
59 bool is_operand_loop_invariant(Stmt *operand,
60 Block *current_scope,
61 int depth = -1) {
62 if (depth == -1) {
63 depth = loop_blocks.size() - 1;
64 }
65 if (depth <= 0)
66 return false;
67 return is_operand_loop_invariant_impl(operand, current_scope,
68 loop_blocks[depth]);
69 }
70
71 bool is_loop_invariant(Stmt *stmt, Block *current_scope) {
72 if (loop_blocks.size() <= 1 || (!config.move_loop_invariant_outside_if &&
73 current_scope != loop_blocks.back()))
74 return false;
75
76 bool is_invariant = true;
77
78 for (Stmt *operand : stmt->get_operands()) {
79 if (operand == nullptr)
80 continue;
81 is_invariant &= is_operand_loop_invariant_impl(operand, current_scope);
82 }
83
84 return is_invariant;
85 }
86
87 Stmt *get_loop_stmt(int depth) {
88 return loop_blocks[depth]->parent_stmt;
89 }
90
91 Stmt *current_loop_stmt() {
92 return get_loop_stmt(loop_blocks.size() - 1);
93 }
94 void visit(Block *stmt_list) override {
95 for (auto &stmt : stmt_list->statements)
96 stmt->accept(this);
97 }
98
99 virtual void visit_loop(Block *body) {
100 loop_blocks.push_back(body);
101
102 body->accept(this);
103
104 loop_blocks.pop_back();
105 }
106
107 void visit(RangeForStmt *stmt) override {
108 visit_loop(stmt->body.get());
109 }
110
111 void visit(StructForStmt *stmt) override {
112 visit_loop(stmt->body.get());
113 }
114
115 void visit(MeshForStmt *stmt) override {
116 visit_loop(stmt->body.get());
117 }
118
119 void visit(WhileStmt *stmt) override {
120 visit_loop(stmt->body.get());
121 }
122
123 void visit(OffloadedStmt *stmt) override {
124 if (stmt->tls_prologue)
125 stmt->tls_prologue->accept(this);
126
127 if (stmt->mesh_prologue)
128 stmt->mesh_prologue->accept(this);
129
130 if (stmt->bls_prologue)
131 stmt->bls_prologue->accept(this);
132
133 if (stmt->body) {
134 if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
135 stmt->task_type == OffloadedTaskType::mesh_for ||
136 stmt->task_type == OffloadedStmt::TaskType::struct_for)
137 visit_loop(stmt->body.get());
138 else
139 stmt->body->accept(this);
140 }
141
142 if (stmt->bls_epilogue)
143 stmt->bls_epilogue->accept(this);
144
145 if (stmt->tls_epilogue)
146 stmt->tls_epilogue->accept(this);
147 }
148};
149
150} // namespace taichi::lang
151