1 | #include "taichi/transforms/loop_invariant_detector.h" |
2 | |
3 | namespace taichi::lang { |
4 | |
5 | class LoopInvariantCodeMotion : public LoopInvariantDetector { |
6 | public: |
7 | using LoopInvariantDetector::visit; |
8 | |
9 | DelayedIRModifier modifier; |
10 | |
11 | explicit LoopInvariantCodeMotion(const CompileConfig &config) |
12 | : LoopInvariantDetector(config) { |
13 | } |
14 | |
15 | void visit(BinaryOpStmt *stmt) override { |
16 | if (is_loop_invariant(stmt, stmt->parent)) { |
17 | auto replacement = stmt->clone(); |
18 | stmt->replace_usages_with(replacement.get()); |
19 | |
20 | modifier.insert_before(current_loop_stmt(), std::move(replacement)); |
21 | modifier.erase(stmt); |
22 | } |
23 | } |
24 | |
25 | void visit(UnaryOpStmt *stmt) override { |
26 | if (is_loop_invariant(stmt, stmt->parent)) { |
27 | auto replacement = stmt->clone(); |
28 | stmt->replace_usages_with(replacement.get()); |
29 | |
30 | modifier.insert_before(current_loop_stmt(), std::move(replacement)); |
31 | modifier.erase(stmt); |
32 | } |
33 | } |
34 | |
35 | void visit(GlobalPtrStmt *stmt) override { |
36 | if (config.cache_loop_invariant_global_vars && |
37 | is_loop_invariant(stmt, stmt->parent)) { |
38 | auto replacement = stmt->clone(); |
39 | stmt->replace_usages_with(replacement.get()); |
40 | |
41 | modifier.insert_before(current_loop_stmt(), std::move(replacement)); |
42 | modifier.erase(stmt); |
43 | } |
44 | } |
45 | |
46 | void visit(ExternalPtrStmt *stmt) override { |
47 | if (config.cache_loop_invariant_global_vars && |
48 | is_loop_invariant(stmt, stmt->parent)) { |
49 | auto replacement = stmt->clone(); |
50 | stmt->replace_usages_with(replacement.get()); |
51 | |
52 | modifier.insert_before(current_loop_stmt(), std::move(replacement)); |
53 | modifier.erase(stmt); |
54 | } |
55 | } |
56 | |
57 | void visit(ArgLoadStmt *stmt) override { |
58 | if (config.cache_loop_invariant_global_vars && |
59 | is_loop_invariant(stmt, stmt->parent)) { |
60 | auto replacement = stmt->clone(); |
61 | stmt->replace_usages_with(replacement.get()); |
62 | |
63 | modifier.insert_before(current_loop_stmt(), std::move(replacement)); |
64 | modifier.erase(stmt); |
65 | } |
66 | } |
67 | |
68 | static bool run(IRNode *node, const CompileConfig &config) { |
69 | bool modified = false; |
70 | |
71 | while (true) { |
72 | LoopInvariantCodeMotion eliminator(config); |
73 | node->accept(&eliminator); |
74 | if (eliminator.modifier.modify_ir()) |
75 | modified = true; |
76 | else |
77 | break; |
78 | }; |
79 | |
80 | return modified; |
81 | } |
82 | }; |
83 | |
84 | namespace irpass { |
85 | bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config) { |
86 | TI_AUTO_PROF; |
87 | return LoopInvariantCodeMotion::run(root, config); |
88 | } |
89 | } // namespace irpass |
90 | |
91 | } // namespace taichi::lang |
92 | |