1 | #pragma once |
2 | |
3 | #include "taichi/ir/ir.h" |
4 | #include "taichi/ir/mesh.h" |
5 | |
6 | namespace taichi::lang { |
7 | |
8 | class Function; |
9 | |
10 | class IRBuilder { |
11 | public: |
12 | struct InsertPoint { |
13 | Block *block{nullptr}; |
14 | int position{0}; |
15 | }; |
16 | |
17 | IRBuilder(); |
18 | |
19 | // Clear the IR and the insertion point. |
20 | void reset(); |
21 | |
22 | // Extract the IR. |
23 | std::unique_ptr<Block> (); |
24 | |
25 | // General inserter. Returns stmt.get(). |
26 | template <typename XStmt> |
27 | XStmt *insert(std::unique_ptr<XStmt> &&stmt) { |
28 | return insert(std::move(stmt), &insert_point_); |
29 | } |
30 | |
31 | // Insert to a specific insertion point. |
32 | template <typename XStmt> |
33 | static XStmt *insert(std::unique_ptr<XStmt> &&stmt, |
34 | InsertPoint *insert_point) { |
35 | return insert_point->block |
36 | ->insert(std::move(stmt), insert_point->position++) |
37 | ->template as<XStmt>(); |
38 | } |
39 | |
40 | void set_insertion_point(InsertPoint new_insert_point); |
41 | void set_insertion_point_to_after(Stmt *stmt); |
42 | void set_insertion_point_to_before(Stmt *stmt); |
43 | void set_insertion_point_to_true_branch(IfStmt *if_stmt); |
44 | void set_insertion_point_to_false_branch(IfStmt *if_stmt); |
45 | template <typename XStmt> |
46 | void set_insertion_point_to_loop_begin(XStmt *loop) { |
47 | using DecayedType = typename std::decay_t<XStmt>; |
48 | if constexpr (!std::is_base_of_v<Stmt, DecayedType>) { |
49 | TI_ERROR("The argument is not a statement." ); |
50 | } |
51 | if constexpr (std::is_same_v<DecayedType, RangeForStmt> || |
52 | std::is_same_v<DecayedType, StructForStmt> || |
53 | std::is_same_v<DecayedType, MeshForStmt> || |
54 | std::is_same_v<DecayedType, WhileStmt>) { |
55 | set_insertion_point({loop->body.get(), 0}); |
56 | } else { |
57 | TI_ERROR("Statement {} is not a loop." , loop->name()); |
58 | } |
59 | } |
60 | |
61 | // RAII handles insertion points automatically. |
62 | class LoopGuard { |
63 | public: |
64 | // Set the insertion point to the beginning of the loop body. |
65 | template <typename XStmt> |
66 | explicit LoopGuard(IRBuilder &builder, XStmt *loop) |
67 | : builder_(builder), loop_(loop) { |
68 | location_ = (int)loop->parent->size() - 1; |
69 | builder_.set_insertion_point_to_loop_begin(loop); |
70 | } |
71 | |
72 | // Set the insertion point to the point after the loop. |
73 | ~LoopGuard(); |
74 | |
75 | private: |
76 | IRBuilder &builder_; |
77 | Stmt *loop_; |
78 | int location_; |
79 | }; |
80 | class IfGuard { |
81 | public: |
82 | // Set the insertion point to the beginning of the true/false branch. |
83 | explicit IfGuard(IRBuilder &builder, IfStmt *if_stmt, bool true_branch); |
84 | |
85 | // Set the insertion point to the point after the if statement. |
86 | ~IfGuard(); |
87 | |
88 | private: |
89 | IRBuilder &builder_; |
90 | IfStmt *if_stmt_; |
91 | int location_; |
92 | }; |
93 | |
94 | template <typename XStmt> |
95 | [[nodiscard]] LoopGuard get_loop_guard(XStmt *loop) { |
96 | return LoopGuard(*this, loop); |
97 | } |
98 | |
99 | [[nodiscard]] IfGuard get_if_guard(IfStmt *if_stmt, bool true_branch) { |
100 | return IfGuard(*this, if_stmt, true_branch); |
101 | } |
102 | |
103 | // Control flows. |
104 | RangeForStmt *create_range_for(Stmt *begin, |
105 | Stmt *end, |
106 | bool is_bit_vectorized = false, |
107 | int num_cpu_threads = 0, |
108 | int block_dim = 0, |
109 | bool strictly_serialized = false); |
110 | StructForStmt *create_struct_for(SNode *snode, |
111 | bool is_bit_vectorized = false, |
112 | int num_cpu_threads = 0, |
113 | int block_dim = 0); |
114 | MeshForStmt *create_mesh_for(mesh::Mesh *mesh, |
115 | mesh::MeshElementType element_type, |
116 | bool is_bit_vectorized = false, |
117 | int num_cpu_threads = 0, |
118 | int block_dim = 0); |
119 | WhileStmt *create_while_true(); |
120 | IfStmt *create_if(Stmt *cond); |
121 | WhileControlStmt *create_break(); |
122 | ContinueStmt *create_continue(); |
123 | |
124 | // Function. |
125 | FuncCallStmt *create_func_call(Function *func, |
126 | const std::vector<Stmt *> &args); |
127 | |
128 | // Loop index. |
129 | LoopIndexStmt *get_loop_index(Stmt *loop, int index = 0); |
130 | |
131 | // Constants. TODO: add more types |
132 | ConstStmt *get_int32(int32 value); |
133 | ConstStmt *get_int64(int64 value); |
134 | ConstStmt *get_uint32(uint32 value); |
135 | ConstStmt *get_uint64(uint64 value); |
136 | ConstStmt *get_float32(float32 value); |
137 | ConstStmt *get_float64(float64 value); |
138 | |
139 | template <typename T> |
140 | ConstStmt *get_constant(DataType dt, const T &value) { |
141 | return insert(Stmt::make_typed<ConstStmt>(TypedConstant(dt, value))); |
142 | } |
143 | |
144 | RandStmt *create_rand(DataType value_type); |
145 | |
146 | // Load kernel arguments. |
147 | ArgLoadStmt *create_arg_load(int arg_id, DataType dt, bool is_ptr); |
148 | |
149 | // The return value of the kernel. |
150 | ReturnStmt *create_return(Stmt *value); |
151 | |
152 | // Unary operations. Returns the result. |
153 | UnaryOpStmt *create_cast(Stmt *value, DataType output_type); // cast by value |
154 | UnaryOpStmt *create_bit_cast(Stmt *value, DataType output_type); |
155 | UnaryOpStmt *create_neg(Stmt *value); |
156 | UnaryOpStmt *create_not(Stmt *value); // bitwise |
157 | UnaryOpStmt *create_logical_not(Stmt *value); |
158 | UnaryOpStmt *create_round(Stmt *value); |
159 | UnaryOpStmt *create_floor(Stmt *value); |
160 | UnaryOpStmt *create_ceil(Stmt *value); |
161 | UnaryOpStmt *create_abs(Stmt *value); |
162 | UnaryOpStmt *create_sgn(Stmt *value); |
163 | UnaryOpStmt *create_sqrt(Stmt *value); |
164 | UnaryOpStmt *create_rsqrt(Stmt *value); |
165 | UnaryOpStmt *create_sin(Stmt *value); |
166 | UnaryOpStmt *create_asin(Stmt *value); |
167 | UnaryOpStmt *create_cos(Stmt *value); |
168 | UnaryOpStmt *create_acos(Stmt *value); |
169 | UnaryOpStmt *create_tan(Stmt *value); |
170 | UnaryOpStmt *create_tanh(Stmt *value); |
171 | UnaryOpStmt *create_exp(Stmt *value); |
172 | UnaryOpStmt *create_log(Stmt *value); |
173 | |
174 | // Binary operations. Returns the result. |
175 | BinaryOpStmt *create_add(Stmt *l, Stmt *r); |
176 | BinaryOpStmt *create_sub(Stmt *l, Stmt *r); |
177 | BinaryOpStmt *create_mul(Stmt *l, Stmt *r); |
178 | // l / r in C++ |
179 | BinaryOpStmt *create_div(Stmt *l, Stmt *r); |
180 | // floor(1.0 * l / r) in C++ |
181 | BinaryOpStmt *create_floordiv(Stmt *l, Stmt *r); |
182 | // 1.0 * l / r in C++ |
183 | BinaryOpStmt *create_truediv(Stmt *l, Stmt *r); |
184 | BinaryOpStmt *create_mod(Stmt *l, Stmt *r); |
185 | BinaryOpStmt *create_max(Stmt *l, Stmt *r); |
186 | BinaryOpStmt *create_min(Stmt *l, Stmt *r); |
187 | BinaryOpStmt *create_atan2(Stmt *l, Stmt *r); |
188 | BinaryOpStmt *create_pow(Stmt *l, Stmt *r); |
189 | // Bitwise operations. TODO: add logical operations when we support them |
190 | BinaryOpStmt *create_and(Stmt *l, Stmt *r); |
191 | BinaryOpStmt *create_or(Stmt *l, Stmt *r); |
192 | BinaryOpStmt *create_xor(Stmt *l, Stmt *r); |
193 | BinaryOpStmt *create_shl(Stmt *l, Stmt *r); |
194 | BinaryOpStmt *create_shr(Stmt *l, Stmt *r); |
195 | BinaryOpStmt *create_sar(Stmt *l, Stmt *r); |
196 | // Comparisons. |
197 | BinaryOpStmt *create_cmp_lt(Stmt *l, Stmt *r); |
198 | BinaryOpStmt *create_cmp_le(Stmt *l, Stmt *r); |
199 | BinaryOpStmt *create_cmp_gt(Stmt *l, Stmt *r); |
200 | BinaryOpStmt *create_cmp_ge(Stmt *l, Stmt *r); |
201 | BinaryOpStmt *create_cmp_eq(Stmt *l, Stmt *r); |
202 | BinaryOpStmt *create_cmp_ne(Stmt *l, Stmt *r); |
203 | |
204 | // Atomic operations. |
205 | AtomicOpStmt *create_atomic_add(Stmt *dest, Stmt *val); |
206 | AtomicOpStmt *create_atomic_sub(Stmt *dest, Stmt *val); |
207 | AtomicOpStmt *create_atomic_max(Stmt *dest, Stmt *val); |
208 | AtomicOpStmt *create_atomic_min(Stmt *dest, Stmt *val); |
209 | // Atomic bitwise operations. |
210 | AtomicOpStmt *create_atomic_and(Stmt *dest, Stmt *val); |
211 | AtomicOpStmt *create_atomic_or(Stmt *dest, Stmt *val); |
212 | AtomicOpStmt *create_atomic_xor(Stmt *dest, Stmt *val); |
213 | |
214 | // Ternary operations. Returns the result. |
215 | TernaryOpStmt *create_select(Stmt *cond, |
216 | Stmt *true_result, |
217 | Stmt *false_result); |
218 | |
219 | // Matrix Initialization |
220 | MatrixInitStmt *create_matrix_init(std::vector<Stmt *> elements); |
221 | |
222 | // Print values and strings. Arguments can be Stmt* or std::string. |
223 | template <typename... Args> |
224 | PrintStmt *create_print(Args &&...args) { |
225 | return insert(Stmt::make_typed<PrintStmt>(std::forward<Args>(args)...)); |
226 | } |
227 | |
228 | // Local variables. |
229 | AllocaStmt *create_local_var(DataType dt); |
230 | LocalLoadStmt *create_local_load(AllocaStmt *ptr); |
231 | void create_local_store(AllocaStmt *ptr, Stmt *data); |
232 | |
233 | // Global variables. |
234 | GlobalPtrStmt *create_global_ptr(SNode *snode, |
235 | const std::vector<Stmt *> &indices); |
236 | ExternalPtrStmt *create_external_ptr(ArgLoadStmt *ptr, |
237 | const std::vector<Stmt *> &indices); |
238 | template <typename XStmt> |
239 | GlobalLoadStmt *create_global_load(XStmt *ptr) { |
240 | using DecayedType = typename std::decay_t<XStmt>; |
241 | if constexpr (!std::is_base_of_v<Stmt, DecayedType>) { |
242 | TI_ERROR("The argument is not a statement." ); |
243 | } |
244 | if constexpr (std::is_same_v<DecayedType, GlobalPtrStmt> || |
245 | std::is_same_v<DecayedType, ExternalPtrStmt>) { |
246 | return insert(Stmt::make_typed<GlobalLoadStmt>(ptr)); |
247 | } else { |
248 | TI_ERROR("Statement {} is not a global pointer." , ptr->name()); |
249 | } |
250 | } |
251 | template <typename XStmt> |
252 | void create_global_store(XStmt *ptr, Stmt *data) { |
253 | using DecayedType = typename std::decay_t<XStmt>; |
254 | if constexpr (!std::is_base_of_v<Stmt, DecayedType>) { |
255 | TI_ERROR("The argument is not a statement." ); |
256 | } |
257 | if constexpr (std::is_same_v<DecayedType, GlobalPtrStmt> || |
258 | std::is_same_v<DecayedType, ExternalPtrStmt>) { |
259 | insert(Stmt::make_typed<GlobalStoreStmt>(ptr, data)); |
260 | } else { |
261 | TI_ERROR("Statement {} is not a global pointer." , ptr->name()); |
262 | } |
263 | } |
264 | |
265 | // Autodiff stack operations. |
266 | AdStackAllocaStmt *create_ad_stack(const DataType &dt, std::size_t max_size); |
267 | void ad_stack_push(AdStackAllocaStmt *stack, Stmt *val); |
268 | void ad_stack_pop(AdStackAllocaStmt *stack); |
269 | AdStackLoadTopStmt *ad_stack_load_top(AdStackAllocaStmt *stack); |
270 | AdStackLoadTopAdjStmt *ad_stack_load_top_adjoint(AdStackAllocaStmt *stack); |
271 | void ad_stack_accumulate_adjoint(AdStackAllocaStmt *stack, Stmt *val); |
272 | |
273 | // Mesh related. |
274 | MeshRelationAccessStmt *get_relation_size(mesh::Mesh *mesh, |
275 | Stmt *mesh_idx, |
276 | mesh::MeshElementType to_type); |
277 | MeshRelationAccessStmt *get_relation_access(mesh::Mesh *mesh, |
278 | Stmt *mesh_idx, |
279 | mesh::MeshElementType to_type, |
280 | Stmt *neighbor_idx); |
281 | MeshPatchIndexStmt *get_patch_index(); |
282 | |
283 | private: |
284 | std::unique_ptr<Block> root_{nullptr}; |
285 | InsertPoint insert_point_; |
286 | }; |
287 | |
288 | } // namespace taichi::lang |
289 | |