1#pragma once
2
3#include "taichi/program/compile_config.h"
4#include "taichi/util/str.h"
5#include "taichi/ir/ir.h"
6#include "taichi/ir/expr.h"
7
8namespace taichi::lang {
9
10class ExpressionVisitor;
11
12// always a tree - used as rvalues
13class Expression {
14 protected:
15 Stmt *stmt;
16
17 public:
18 std::string tb;
19 std::map<std::string, std::string> attributes;
20 DataType ret_type;
21
22 struct FlattenContext {
23 VecStatement stmts;
24 Block *current_block = nullptr;
25
26 inline Stmt *push_back(pStmt &&stmt) {
27 return stmts.push_back(std::move(stmt));
28 }
29
30 template <typename T, typename... Args>
31 T *push_back(Args &&...args) {
32 return stmts.push_back<T>(std::forward<Args>(args)...);
33 }
34
35 Stmt *back_stmt() {
36 return stmts.back().get();
37 }
38 };
39
40 Expression() {
41 stmt = nullptr;
42 }
43
44 virtual void type_check(const CompileConfig *config) = 0;
45
46 virtual void accept(ExpressionVisitor *visitor) = 0;
47
48 virtual void flatten(FlattenContext *ctx) {
49 TI_NOT_IMPLEMENTED;
50 };
51
52 virtual bool is_lvalue() const {
53 return false;
54 }
55
56 virtual ~Expression() {
57 }
58
59 Stmt *get_flattened_stmt() const {
60 return stmt;
61 }
62};
63
64class ExprGroup {
65 public:
66 std::vector<Expr> exprs;
67
68 ExprGroup() {
69 }
70
71 explicit ExprGroup(const Expr &a) {
72 exprs.emplace_back(a);
73 }
74
75 ExprGroup(const Expr &a, const Expr &b) {
76 exprs.emplace_back(a);
77 exprs.emplace_back(b);
78 }
79
80 ExprGroup(const ExprGroup &a, const Expr &b) {
81 exprs.resize(a.size() + 1);
82
83 for (int i = 0; i < a.size(); ++i) {
84 exprs[i].set(a.exprs[i]);
85 }
86 exprs.back().set(b);
87 }
88
89 ExprGroup(const Expr &a, const ExprGroup &b) {
90 exprs.resize(b.size() + 1);
91 exprs.front().set(a);
92 for (int i = 0; i < b.size(); i++) {
93 exprs[i + 1].set(b.exprs[i]);
94 }
95 }
96
97 void push_back(const Expr &expr) {
98 exprs.emplace_back(expr);
99 }
100
101 std::size_t size() const {
102 return exprs.size();
103 }
104
105 const Expr &operator[](int i) const {
106 return exprs[i];
107 }
108
109 Expr &operator[](int i) {
110 return exprs[i];
111 }
112};
113
114inline ExprGroup operator,(const Expr &a, const Expr &b) {
115 return ExprGroup(a, b);
116}
117
118inline ExprGroup operator,(const ExprGroup &a, const Expr &b) {
119 return ExprGroup(a, b);
120}
121
122#define PER_EXPRESSION(x) class x;
123#include "taichi/inc/expressions.inc.h"
124#undef PER_EXPRESSION
125
126class ExpressionVisitor {
127 public:
128 explicit ExpressionVisitor(bool allow_undefined_visitor = false,
129 bool invoke_default_visitor = false)
130 : allow_undefined_visitor_(allow_undefined_visitor),
131 invoke_default_visitor_(invoke_default_visitor) {
132 }
133
134 virtual ~ExpressionVisitor() = default;
135
136 virtual void visit(ExprGroup &expr_group) = 0;
137
138 void visit(Expr &expr) {
139 expr.expr->accept(this);
140 }
141
142 virtual void visit(Expression *expr) {
143 if (!allow_undefined_visitor_) {
144 TI_ERROR("missing visitor function");
145 }
146 }
147
148#define DEFINE_VISIT(T) \
149 virtual void visit(T *expr) { \
150 if (allow_undefined_visitor_) { \
151 if (invoke_default_visitor_) \
152 visit((Expression *)expr); \
153 } else \
154 TI_NOT_IMPLEMENTED; \
155 }
156
157#define PER_EXPRESSION(x) DEFINE_VISIT(x)
158#include "taichi/inc/expressions.inc.h"
159#undef PER_EXPRESSION
160#undef DEFINE_VISIT
161 private:
162 bool allow_undefined_visitor_{false};
163 bool invoke_default_visitor_{false};
164};
165
// Place inside a concrete Expression subclass to implement accept().
// Overload resolution on the static type of `this` picks the matching
// ExpressionVisitor::visit overload.
#define TI_DEFINE_ACCEPT_FOR_EXPRESSION              \
  void accept(ExpressionVisitor *visitor) override { \
    visitor->visit(this);                            \
  }
168
169} // namespace taichi::lang
170