#include "taichi/ir/visitors.h"
#include "taichi/ir/statements.h"

#include <functional>
#include <unordered_set>
#include <utility>
6
7namespace taichi::lang {
8
9namespace {
10
11// A statement is considered constexpr in this pass, iff both its value and the
12// control flow reaching it are constant w.r.t. const_seed(...)
13
14class ConstExprPropagation : public IRVisitor {
15 public:
16 using is_const_seed_func = std::function<bool(Stmt *)>;
17
18 explicit ConstExprPropagation(const is_const_seed_func &is_const_seed)
19 : is_const_seed_(is_const_seed) {
20 allow_undefined_visitor = true;
21 invoke_default_visitor = true;
22 }
23
24 bool generic_test(Stmt *stmt) {
25 if (is_const_seed_(stmt)) {
26 const_stmts_.insert(stmt);
27 return true;
28 } else {
29 return false;
30 }
31 }
32
33 void visit(Stmt *stmt) override {
34 generic_test(stmt);
35 }
36
37 bool is_inferred_const(Stmt *stmt) {
38 // Note: every statement that tests true by "is_const_seed_" should have
39 // already been included in const_stmts_.
40 return const_stmts_.find(stmt) != const_stmts_.end();
41 };
42
43 void visit(UnaryOpStmt *stmt) override {
44 if (generic_test(stmt))
45 return;
46 if (is_inferred_const(stmt->operand)) {
47 const_stmts_.insert(stmt);
48 }
49 }
50
51 void visit(BinaryOpStmt *stmt) override {
52 if (generic_test(stmt))
53 return;
54 if (is_inferred_const(stmt->lhs) && is_inferred_const(stmt->rhs)) {
55 const_stmts_.insert(stmt);
56 }
57 }
58
59 void visit(TernaryOpStmt *stmt) override {
60 if (generic_test(stmt))
61 return;
62 if (is_inferred_const(stmt->op1) && is_inferred_const(stmt->op2) &&
63 is_inferred_const(stmt->op3)) {
64 const_stmts_.insert(stmt);
65 }
66 }
67
68 void visit(IfStmt *stmt) override {
69 // If the condition is constexpr, then the control flow is also considered
70 // const.
71 if (is_inferred_const(stmt->cond)) {
72 if (stmt->true_statements)
73 stmt->true_statements->accept(this);
74 if (stmt->false_statements)
75 stmt->false_statements->accept(this);
76 }
77 }
78
79 void visit(Block *block) override {
80 for (auto &stmt : block->statements)
81 stmt->accept(this);
82 }
83
84 // TODO: how do we rigorously define constexpr in RangeFor loops?
85
86 static std::unordered_set<Stmt *> run(
87 Block *block,
88 const std::function<bool(Stmt *)> &is_const_seed) {
89 ConstExprPropagation prop(is_const_seed);
90 block->accept(&prop);
91 return prop.const_stmts_;
92 }
93
94 private:
95 is_const_seed_func is_const_seed_;
96 std::unordered_set<Stmt *> const_stmts_;
97};
98
99} // namespace
100
101namespace irpass::analysis {
102std::unordered_set<Stmt *> constexpr_prop(
103 Block *block,
104 std::function<bool(Stmt *)> is_const_seed) {
105 return ConstExprPropagation::run(block, is_const_seed);
106}
107} // namespace irpass::analysis
108
109} // namespace taichi::lang
110