1#include "taichi/transforms/loop_invariant_detector.h"
2
3namespace taichi::lang {
4
5class LoopInvariantCodeMotion : public LoopInvariantDetector {
6 public:
7 using LoopInvariantDetector::visit;
8
9 DelayedIRModifier modifier;
10
11 explicit LoopInvariantCodeMotion(const CompileConfig &config)
12 : LoopInvariantDetector(config) {
13 }
14
15 void visit(BinaryOpStmt *stmt) override {
16 if (is_loop_invariant(stmt, stmt->parent)) {
17 auto replacement = stmt->clone();
18 stmt->replace_usages_with(replacement.get());
19
20 modifier.insert_before(current_loop_stmt(), std::move(replacement));
21 modifier.erase(stmt);
22 }
23 }
24
25 void visit(UnaryOpStmt *stmt) override {
26 if (is_loop_invariant(stmt, stmt->parent)) {
27 auto replacement = stmt->clone();
28 stmt->replace_usages_with(replacement.get());
29
30 modifier.insert_before(current_loop_stmt(), std::move(replacement));
31 modifier.erase(stmt);
32 }
33 }
34
35 void visit(GlobalPtrStmt *stmt) override {
36 if (config.cache_loop_invariant_global_vars &&
37 is_loop_invariant(stmt, stmt->parent)) {
38 auto replacement = stmt->clone();
39 stmt->replace_usages_with(replacement.get());
40
41 modifier.insert_before(current_loop_stmt(), std::move(replacement));
42 modifier.erase(stmt);
43 }
44 }
45
46 void visit(ExternalPtrStmt *stmt) override {
47 if (config.cache_loop_invariant_global_vars &&
48 is_loop_invariant(stmt, stmt->parent)) {
49 auto replacement = stmt->clone();
50 stmt->replace_usages_with(replacement.get());
51
52 modifier.insert_before(current_loop_stmt(), std::move(replacement));
53 modifier.erase(stmt);
54 }
55 }
56
57 void visit(ArgLoadStmt *stmt) override {
58 if (config.cache_loop_invariant_global_vars &&
59 is_loop_invariant(stmt, stmt->parent)) {
60 auto replacement = stmt->clone();
61 stmt->replace_usages_with(replacement.get());
62
63 modifier.insert_before(current_loop_stmt(), std::move(replacement));
64 modifier.erase(stmt);
65 }
66 }
67
68 static bool run(IRNode *node, const CompileConfig &config) {
69 bool modified = false;
70
71 while (true) {
72 LoopInvariantCodeMotion eliminator(config);
73 node->accept(&eliminator);
74 if (eliminator.modifier.modify_ir())
75 modified = true;
76 else
77 break;
78 };
79
80 return modified;
81 }
82};
83
84namespace irpass {
85bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config) {
86 TI_AUTO_PROF;
87 return LoopInvariantCodeMotion::run(root, config);
88}
89} // namespace irpass
90
91} // namespace taichi::lang
92