1 | #pragma once |
2 | |
3 | #include "taichi/program/compile_config.h" |
4 | #include "taichi/util/str.h" |
5 | #include "taichi/ir/ir.h" |
6 | #include "taichi/ir/expr.h" |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | class ExpressionVisitor; |
11 | |
12 | // always a tree - used as rvalues |
13 | class Expression { |
14 | protected: |
15 | Stmt *stmt; |
16 | |
17 | public: |
18 | std::string tb; |
19 | std::map<std::string, std::string> attributes; |
20 | DataType ret_type; |
21 | |
22 | struct FlattenContext { |
23 | VecStatement stmts; |
24 | Block *current_block = nullptr; |
25 | |
26 | inline Stmt *push_back(pStmt &&stmt) { |
27 | return stmts.push_back(std::move(stmt)); |
28 | } |
29 | |
30 | template <typename T, typename... Args> |
31 | T *push_back(Args &&...args) { |
32 | return stmts.push_back<T>(std::forward<Args>(args)...); |
33 | } |
34 | |
35 | Stmt *back_stmt() { |
36 | return stmts.back().get(); |
37 | } |
38 | }; |
39 | |
40 | Expression() { |
41 | stmt = nullptr; |
42 | } |
43 | |
44 | virtual void type_check(const CompileConfig *config) = 0; |
45 | |
46 | virtual void accept(ExpressionVisitor *visitor) = 0; |
47 | |
48 | virtual void flatten(FlattenContext *ctx) { |
49 | TI_NOT_IMPLEMENTED; |
50 | }; |
51 | |
52 | virtual bool is_lvalue() const { |
53 | return false; |
54 | } |
55 | |
56 | virtual ~Expression() { |
57 | } |
58 | |
59 | Stmt *get_flattened_stmt() const { |
60 | return stmt; |
61 | } |
62 | }; |
63 | |
64 | class ExprGroup { |
65 | public: |
66 | std::vector<Expr> exprs; |
67 | |
68 | ExprGroup() { |
69 | } |
70 | |
71 | explicit ExprGroup(const Expr &a) { |
72 | exprs.emplace_back(a); |
73 | } |
74 | |
75 | ExprGroup(const Expr &a, const Expr &b) { |
76 | exprs.emplace_back(a); |
77 | exprs.emplace_back(b); |
78 | } |
79 | |
80 | ExprGroup(const ExprGroup &a, const Expr &b) { |
81 | exprs.resize(a.size() + 1); |
82 | |
83 | for (int i = 0; i < a.size(); ++i) { |
84 | exprs[i].set(a.exprs[i]); |
85 | } |
86 | exprs.back().set(b); |
87 | } |
88 | |
89 | ExprGroup(const Expr &a, const ExprGroup &b) { |
90 | exprs.resize(b.size() + 1); |
91 | exprs.front().set(a); |
92 | for (int i = 0; i < b.size(); i++) { |
93 | exprs[i + 1].set(b.exprs[i]); |
94 | } |
95 | } |
96 | |
97 | void push_back(const Expr &expr) { |
98 | exprs.emplace_back(expr); |
99 | } |
100 | |
101 | std::size_t size() const { |
102 | return exprs.size(); |
103 | } |
104 | |
105 | const Expr &operator[](int i) const { |
106 | return exprs[i]; |
107 | } |
108 | |
109 | Expr &operator[](int i) { |
110 | return exprs[i]; |
111 | } |
112 | }; |
113 | |
114 | inline ExprGroup operator,(const Expr &a, const Expr &b) { |
115 | return ExprGroup(a, b); |
116 | } |
117 | |
118 | inline ExprGroup operator,(const ExprGroup &a, const Expr &b) { |
119 | return ExprGroup(a, b); |
120 | } |
121 | |
122 | #define PER_EXPRESSION(x) class x; |
123 | #include "taichi/inc/expressions.inc.h" |
124 | #undef PER_EXPRESSION |
125 | |
126 | class ExpressionVisitor { |
127 | public: |
128 | explicit ExpressionVisitor(bool allow_undefined_visitor = false, |
129 | bool invoke_default_visitor = false) |
130 | : allow_undefined_visitor_(allow_undefined_visitor), |
131 | invoke_default_visitor_(invoke_default_visitor) { |
132 | } |
133 | |
134 | virtual ~ExpressionVisitor() = default; |
135 | |
136 | virtual void visit(ExprGroup &expr_group) = 0; |
137 | |
138 | void visit(Expr &expr) { |
139 | expr.expr->accept(this); |
140 | } |
141 | |
142 | virtual void visit(Expression *expr) { |
143 | if (!allow_undefined_visitor_) { |
144 | TI_ERROR("missing visitor function" ); |
145 | } |
146 | } |
147 | |
148 | #define DEFINE_VISIT(T) \ |
149 | virtual void visit(T *expr) { \ |
150 | if (allow_undefined_visitor_) { \ |
151 | if (invoke_default_visitor_) \ |
152 | visit((Expression *)expr); \ |
153 | } else \ |
154 | TI_NOT_IMPLEMENTED; \ |
155 | } |
156 | |
157 | #define PER_EXPRESSION(x) DEFINE_VISIT(x) |
158 | #include "taichi/inc/expressions.inc.h" |
159 | #undef PER_EXPRESSION |
160 | #undef DEFINE_VISIT |
161 | private: |
162 | bool allow_undefined_visitor_{false}; |
163 | bool invoke_default_visitor_{false}; |
164 | }; |
165 | |
// Expand inside a concrete expression class body to get the accept() override.
// The call visitor->visit(this) picks the visit() overload for the enclosing
// class's own type.
#define TI_DEFINE_ACCEPT_FOR_EXPRESSION              \
  void accept(ExpressionVisitor *visitor) override { \
    visitor->visit(this);                            \
  }
168 | |
169 | } // namespace taichi::lang |
170 | |