#pragma once

#include <memory>
#include <utility>

#include "taichi/util/str.h"
#include "taichi/ir/type_utils.h"
5 | |
6 | namespace taichi::lang { |
7 | |
8 | struct CompileConfig; |
9 | class Expression; |
10 | class Identifier; |
11 | class ExprGroup; |
12 | class SNode; |
13 | class ASTBuilder; |
14 | |
15 | class Expr { |
16 | public: |
17 | std::shared_ptr<Expression> expr; |
18 | bool const_value; |
19 | bool atomic; |
20 | |
21 | Expr() { |
22 | const_value = false; |
23 | atomic = false; |
24 | } |
25 | |
26 | explicit Expr(int16 x); |
27 | |
28 | explicit Expr(int32 x); |
29 | |
30 | explicit Expr(int64 x); |
31 | |
32 | explicit Expr(float32 x); |
33 | |
34 | explicit Expr(float64 x); |
35 | |
36 | explicit Expr(std::shared_ptr<Expression> expr) : Expr() { |
37 | this->expr = expr; |
38 | } |
39 | |
40 | Expr(const Expr &o) : Expr() { |
41 | set(o); |
42 | const_value = o.const_value; |
43 | } |
44 | |
45 | Expr(Expr &&o) : Expr() { |
46 | set(o); |
47 | const_value = o.const_value; |
48 | atomic = o.atomic; |
49 | } |
50 | |
51 | explicit Expr(const Identifier &id); |
52 | |
53 | void set(const Expr &o) { |
54 | expr = o.expr; |
55 | } |
56 | |
57 | // NOLINTNEXTLINE(google-explicit-constructor) |
58 | operator bool() const { |
59 | return expr.get() != nullptr; |
60 | } |
61 | |
62 | Expression *operator->() { |
63 | return expr.get(); |
64 | } |
65 | |
66 | Expression const *operator->() const { |
67 | return expr.get(); |
68 | } |
69 | |
70 | template <typename T> |
71 | std::shared_ptr<T> cast() const { |
72 | TI_ASSERT(expr != nullptr); |
73 | return std::dynamic_pointer_cast<T>(expr); |
74 | } |
75 | |
76 | template <typename T> |
77 | bool is() const { |
78 | return cast<T>() != nullptr; |
79 | } |
80 | |
81 | // FIXME: We really should disable it completely, |
82 | // but we can't. This is because the usage of |
83 | // std::variant<Expr, std::string> in FrontendPrintStmt. |
84 | Expr &operator=(const Expr &o); |
85 | |
86 | template <typename T, typename... Args> |
87 | static Expr make(Args &&...args) { |
88 | return Expr(std::make_shared<T>(std::forward<Args>(args)...)); |
89 | } |
90 | |
91 | SNode *snode() const; |
92 | |
93 | // traceback for type checking error message |
94 | void set_tb(const std::string &tb); |
95 | |
96 | void set_adjoint(const Expr &o); |
97 | |
98 | void set_dual(const Expr &o); |
99 | |
100 | void set_adjoint_checkbit(const Expr &o); |
101 | |
102 | DataType get_ret_type() const; |
103 | |
104 | void type_check(const CompileConfig *config); |
105 | }; |
106 | |
107 | // Value cast |
108 | Expr cast(const Expr &input, DataType dt); |
109 | |
110 | template <typename T> |
111 | Expr cast(const Expr &input) { |
112 | return taichi::lang::cast(input, get_data_type<T>()); |
113 | } |
114 | |
115 | Expr bit_cast(const Expr &input, DataType dt); |
116 | |
117 | template <typename T> |
118 | Expr bit_cast(const Expr &input) { |
119 | return taichi::lang::bit_cast(input, get_data_type<T>()); |
120 | } |
121 | |
122 | // like Expr::Expr, but allows to explicitly specify the type |
123 | template <typename T> |
124 | Expr value(const T &val) { |
125 | return Expr(val); |
126 | } |
127 | |
128 | Expr expr_rand(DataType dt); |
129 | |
130 | template <typename T> |
131 | Expr expr_rand() { |
132 | return taichi::lang::expr_rand(get_data_type<T>()); |
133 | } |
134 | |
135 | Expr assume_range(const Expr &expr, const Expr &base, int low, int high); |
136 | |
137 | Expr loop_unique(const Expr &input, const std::vector<SNode *> &covers); |
138 | |
139 | Expr expr_field(Expr id_expr, DataType dt); |
140 | |
141 | Expr expr_matrix_field(const std::vector<Expr> &fields, |
142 | const std::vector<int> &element_shape); |
143 | |
144 | } // namespace taichi::lang |
145 | |