#include "taichi/ir/visitors.h"
#include "taichi/ir/statements.h"

#include <functional>
#include <unordered_set>
#include <utility>
6 | |
7 | namespace taichi::lang { |
8 | |
9 | namespace { |
10 | |
11 | // A statement is considered constexpr in this pass, iff both its value and the |
12 | // control flow reaching it are constant w.r.t. const_seed(...) |
13 | |
14 | class ConstExprPropagation : public IRVisitor { |
15 | public: |
16 | using is_const_seed_func = std::function<bool(Stmt *)>; |
17 | |
18 | explicit ConstExprPropagation(const is_const_seed_func &is_const_seed) |
19 | : is_const_seed_(is_const_seed) { |
20 | allow_undefined_visitor = true; |
21 | invoke_default_visitor = true; |
22 | } |
23 | |
24 | bool generic_test(Stmt *stmt) { |
25 | if (is_const_seed_(stmt)) { |
26 | const_stmts_.insert(stmt); |
27 | return true; |
28 | } else { |
29 | return false; |
30 | } |
31 | } |
32 | |
33 | void visit(Stmt *stmt) override { |
34 | generic_test(stmt); |
35 | } |
36 | |
37 | bool is_inferred_const(Stmt *stmt) { |
38 | // Note: every statement that tests true by "is_const_seed_" should have |
39 | // already been included in const_stmts_. |
40 | return const_stmts_.find(stmt) != const_stmts_.end(); |
41 | }; |
42 | |
43 | void visit(UnaryOpStmt *stmt) override { |
44 | if (generic_test(stmt)) |
45 | return; |
46 | if (is_inferred_const(stmt->operand)) { |
47 | const_stmts_.insert(stmt); |
48 | } |
49 | } |
50 | |
51 | void visit(BinaryOpStmt *stmt) override { |
52 | if (generic_test(stmt)) |
53 | return; |
54 | if (is_inferred_const(stmt->lhs) && is_inferred_const(stmt->rhs)) { |
55 | const_stmts_.insert(stmt); |
56 | } |
57 | } |
58 | |
59 | void visit(TernaryOpStmt *stmt) override { |
60 | if (generic_test(stmt)) |
61 | return; |
62 | if (is_inferred_const(stmt->op1) && is_inferred_const(stmt->op2) && |
63 | is_inferred_const(stmt->op3)) { |
64 | const_stmts_.insert(stmt); |
65 | } |
66 | } |
67 | |
68 | void visit(IfStmt *stmt) override { |
69 | // If the condition is constexpr, then the control flow is also considered |
70 | // const. |
71 | if (is_inferred_const(stmt->cond)) { |
72 | if (stmt->true_statements) |
73 | stmt->true_statements->accept(this); |
74 | if (stmt->false_statements) |
75 | stmt->false_statements->accept(this); |
76 | } |
77 | } |
78 | |
79 | void visit(Block *block) override { |
80 | for (auto &stmt : block->statements) |
81 | stmt->accept(this); |
82 | } |
83 | |
84 | // TODO: how do we rigorously define constexpr in RangeFor loops? |
85 | |
86 | static std::unordered_set<Stmt *> run( |
87 | Block *block, |
88 | const std::function<bool(Stmt *)> &is_const_seed) { |
89 | ConstExprPropagation prop(is_const_seed); |
90 | block->accept(&prop); |
91 | return prop.const_stmts_; |
92 | } |
93 | |
94 | private: |
95 | is_const_seed_func is_const_seed_; |
96 | std::unordered_set<Stmt *> const_stmts_; |
97 | }; |
98 | |
99 | } // namespace |
100 | |
101 | namespace irpass::analysis { |
102 | std::unordered_set<Stmt *> constexpr_prop( |
103 | Block *block, |
104 | std::function<bool(Stmt *)> is_const_seed) { |
105 | return ConstExprPropagation::run(block, is_const_seed); |
106 | } |
107 | } // namespace irpass::analysis |
108 | |
109 | } // namespace taichi::lang |
110 |