1 | #include "taichi/ir/ir_builder.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/common/logging.h" |
4 | |
5 | namespace taichi::lang { |
6 | |
7 | namespace { |
8 | |
9 | inline bool stmt_location_did_not_change(Stmt *stmt, int location) { |
10 | return location >= 0 && location < stmt->parent->size() && |
11 | stmt->parent->statements[location].get() == stmt; |
12 | } |
13 | |
14 | } // namespace |
15 | |
16 | IRBuilder::IRBuilder() { |
17 | reset(); |
18 | } |
19 | |
20 | void IRBuilder::reset() { |
21 | root_ = std::make_unique<Block>(); |
22 | insert_point_.block = root_->as<Block>(); |
23 | insert_point_.position = 0; |
24 | } |
25 | |
26 | std::unique_ptr<Block> IRBuilder::() { |
27 | auto result = std::move(root_); |
28 | reset(); |
29 | return result; |
30 | } |
31 | |
32 | void IRBuilder::set_insertion_point(InsertPoint new_insert_point) { |
33 | insert_point_ = new_insert_point; |
34 | } |
35 | |
36 | void IRBuilder::set_insertion_point_to_after(Stmt *stmt) { |
37 | set_insertion_point({stmt->parent, stmt->parent->locate(stmt) + 1}); |
38 | } |
39 | |
40 | void IRBuilder::set_insertion_point_to_before(Stmt *stmt) { |
41 | set_insertion_point({stmt->parent, stmt->parent->locate(stmt)}); |
42 | } |
43 | |
44 | void IRBuilder::set_insertion_point_to_true_branch(IfStmt *if_stmt) { |
45 | if (!if_stmt->true_statements) |
46 | if_stmt->set_true_statements(std::make_unique<Block>()); |
47 | set_insertion_point({if_stmt->true_statements.get(), 0}); |
48 | } |
49 | |
50 | void IRBuilder::set_insertion_point_to_false_branch(IfStmt *if_stmt) { |
51 | if (!if_stmt->false_statements) |
52 | if_stmt->set_false_statements(std::make_unique<Block>()); |
53 | set_insertion_point({if_stmt->false_statements.get(), 0}); |
54 | } |
55 | |
56 | IRBuilder::LoopGuard::~LoopGuard() { |
57 | if (stmt_location_did_not_change(loop_, location_)) { |
58 | // faster than set_insertion_point_to_after() |
59 | builder_.set_insertion_point({loop_->parent, location_ + 1}); |
60 | } else { |
61 | builder_.set_insertion_point_to_after(loop_); |
62 | } |
63 | } |
64 | |
65 | IRBuilder::IfGuard::IfGuard(IRBuilder &builder, |
66 | IfStmt *if_stmt, |
67 | bool true_branch) |
68 | : builder_(builder), if_stmt_(if_stmt) { |
69 | location_ = (int)if_stmt_->parent->size() - 1; |
70 | if (true_branch) { |
71 | builder_.set_insertion_point_to_true_branch(if_stmt_); |
72 | } else { |
73 | builder_.set_insertion_point_to_false_branch(if_stmt_); |
74 | } |
75 | } |
76 | |
77 | IRBuilder::IfGuard::~IfGuard() { |
78 | if (stmt_location_did_not_change(if_stmt_, location_)) { |
79 | // faster than set_insertion_point_to_after() |
80 | builder_.set_insertion_point({if_stmt_->parent, location_ + 1}); |
81 | } else { |
82 | builder_.set_insertion_point_to_after(if_stmt_); |
83 | } |
84 | } |
85 | |
86 | RangeForStmt *IRBuilder::create_range_for(Stmt *begin, |
87 | Stmt *end, |
88 | bool is_bit_vectorized, |
89 | int num_cpu_threads, |
90 | int block_dim, |
91 | bool strictly_serialized) { |
92 | return insert(Stmt::make_typed<RangeForStmt>( |
93 | begin, end, std::make_unique<Block>(), is_bit_vectorized, num_cpu_threads, |
94 | block_dim, strictly_serialized)); |
95 | } |
96 | |
97 | StructForStmt *IRBuilder::create_struct_for(SNode *snode, |
98 | bool is_bit_vectorized, |
99 | int num_cpu_threads, |
100 | int block_dim) { |
101 | return insert(Stmt::make_typed<StructForStmt>( |
102 | snode, std::make_unique<Block>(), is_bit_vectorized, num_cpu_threads, |
103 | block_dim)); |
104 | } |
105 | |
106 | MeshForStmt *IRBuilder::create_mesh_for(mesh::Mesh *mesh, |
107 | mesh::MeshElementType element_type, |
108 | bool is_bit_vectorized, |
109 | int num_cpu_threads, |
110 | int block_dim) { |
111 | return insert(Stmt::make_typed<MeshForStmt>( |
112 | mesh, element_type, std::make_unique<Block>(), is_bit_vectorized, |
113 | num_cpu_threads, block_dim)); |
114 | } |
115 | |
116 | WhileStmt *IRBuilder::create_while_true() { |
117 | return insert(Stmt::make_typed<WhileStmt>(std::make_unique<Block>())); |
118 | } |
119 | |
120 | IfStmt *IRBuilder::create_if(Stmt *cond) { |
121 | return insert(Stmt::make_typed<IfStmt>(cond)); |
122 | } |
123 | |
124 | WhileControlStmt *IRBuilder::create_break() { |
125 | return insert(Stmt::make_typed<WhileControlStmt>(nullptr, get_int32(0))); |
126 | } |
127 | |
128 | ContinueStmt *IRBuilder::create_continue() { |
129 | return insert(Stmt::make_typed<ContinueStmt>()); |
130 | } |
131 | |
132 | FuncCallStmt *IRBuilder::create_func_call(Function *func, |
133 | const std::vector<Stmt *> &args) { |
134 | return insert(Stmt::make_typed<FuncCallStmt>(func, args)); |
135 | } |
136 | |
137 | LoopIndexStmt *IRBuilder::get_loop_index(Stmt *loop, int index) { |
138 | return insert(Stmt::make_typed<LoopIndexStmt>(loop, index)); |
139 | } |
140 | |
141 | ConstStmt *IRBuilder::get_int32(int32 value) { |
142 | return insert(Stmt::make_typed<ConstStmt>(TypedConstant( |
143 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32), |
144 | value))); |
145 | } |
146 | |
147 | ConstStmt *IRBuilder::get_int64(int64 value) { |
148 | return insert(Stmt::make_typed<ConstStmt>(TypedConstant( |
149 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i64), |
150 | value))); |
151 | } |
152 | |
153 | ConstStmt *IRBuilder::get_uint32(uint32 value) { |
154 | return insert(Stmt::make_typed<ConstStmt>(TypedConstant( |
155 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u32), |
156 | value))); |
157 | } |
158 | |
159 | ConstStmt *IRBuilder::get_uint64(uint64 value) { |
160 | return insert(Stmt::make_typed<ConstStmt>(TypedConstant( |
161 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u64), |
162 | value))); |
163 | } |
164 | |
165 | ConstStmt *IRBuilder::get_float32(float32 value) { |
166 | return insert(Stmt::make_typed<ConstStmt>(TypedConstant( |
167 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f32), |
168 | value))); |
169 | } |
170 | |
171 | ConstStmt *IRBuilder::get_float64(float64 value) { |
172 | return insert(Stmt::make_typed<ConstStmt>(TypedConstant( |
173 | TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f64), |
174 | value))); |
175 | } |
176 | |
177 | RandStmt *IRBuilder::create_rand(DataType value_type) { |
178 | return insert(Stmt::make_typed<RandStmt>(value_type)); |
179 | } |
180 | |
181 | ArgLoadStmt *IRBuilder::create_arg_load(int arg_id, DataType dt, bool is_ptr) { |
182 | return insert(Stmt::make_typed<ArgLoadStmt>(arg_id, dt, is_ptr)); |
183 | } |
184 | |
185 | ReturnStmt *IRBuilder::create_return(Stmt *value) { |
186 | return insert(Stmt::make_typed<ReturnStmt>(value)); |
187 | } |
188 | |
189 | UnaryOpStmt *IRBuilder::create_cast(Stmt *value, DataType output_type) { |
190 | auto &&result = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, value); |
191 | result->cast_type = output_type; |
192 | return insert(std::move(result)); |
193 | } |
194 | |
195 | UnaryOpStmt *IRBuilder::create_bit_cast(Stmt *value, DataType output_type) { |
196 | auto &&result = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_bits, value); |
197 | result->cast_type = output_type; |
198 | return insert(std::move(result)); |
199 | } |
200 | |
201 | UnaryOpStmt *IRBuilder::create_neg(Stmt *value) { |
202 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::neg, value)); |
203 | } |
204 | |
205 | UnaryOpStmt *IRBuilder::create_not(Stmt *value) { |
206 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::bit_not, value)); |
207 | } |
208 | |
209 | UnaryOpStmt *IRBuilder::create_logical_not(Stmt *value) { |
210 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::logic_not, value)); |
211 | } |
212 | |
213 | UnaryOpStmt *IRBuilder::create_round(Stmt *value) { |
214 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::round, value)); |
215 | } |
216 | |
217 | UnaryOpStmt *IRBuilder::create_floor(Stmt *value) { |
218 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::floor, value)); |
219 | } |
220 | |
221 | UnaryOpStmt *IRBuilder::create_ceil(Stmt *value) { |
222 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::ceil, value)); |
223 | } |
224 | |
225 | UnaryOpStmt *IRBuilder::create_abs(Stmt *value) { |
226 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::abs, value)); |
227 | } |
228 | |
229 | UnaryOpStmt *IRBuilder::create_sgn(Stmt *value) { |
230 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::sgn, value)); |
231 | } |
232 | |
233 | UnaryOpStmt *IRBuilder::create_sqrt(Stmt *value) { |
234 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::sqrt, value)); |
235 | } |
236 | |
237 | UnaryOpStmt *IRBuilder::create_rsqrt(Stmt *value) { |
238 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::rsqrt, value)); |
239 | } |
240 | |
241 | UnaryOpStmt *IRBuilder::create_sin(Stmt *value) { |
242 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::sin, value)); |
243 | } |
244 | |
245 | UnaryOpStmt *IRBuilder::create_asin(Stmt *value) { |
246 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::asin, value)); |
247 | } |
248 | |
249 | UnaryOpStmt *IRBuilder::create_cos(Stmt *value) { |
250 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cos, value)); |
251 | } |
252 | |
253 | UnaryOpStmt *IRBuilder::create_acos(Stmt *value) { |
254 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::acos, value)); |
255 | } |
256 | |
257 | UnaryOpStmt *IRBuilder::create_tan(Stmt *value) { |
258 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::tan, value)); |
259 | } |
260 | |
261 | UnaryOpStmt *IRBuilder::create_tanh(Stmt *value) { |
262 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::tanh, value)); |
263 | } |
264 | |
265 | UnaryOpStmt *IRBuilder::create_exp(Stmt *value) { |
266 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::exp, value)); |
267 | } |
268 | |
269 | UnaryOpStmt *IRBuilder::create_log(Stmt *value) { |
270 | return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::log, value)); |
271 | } |
272 | |
273 | BinaryOpStmt *IRBuilder::create_add(Stmt *l, Stmt *r) { |
274 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::add, l, r)); |
275 | } |
276 | |
277 | BinaryOpStmt *IRBuilder::create_sub(Stmt *l, Stmt *r) { |
278 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::sub, l, r)); |
279 | } |
280 | |
281 | BinaryOpStmt *IRBuilder::create_mul(Stmt *l, Stmt *r) { |
282 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::mul, l, r)); |
283 | } |
284 | |
285 | BinaryOpStmt *IRBuilder::create_div(Stmt *l, Stmt *r) { |
286 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::div, l, r)); |
287 | } |
288 | |
289 | BinaryOpStmt *IRBuilder::create_floordiv(Stmt *l, Stmt *r) { |
290 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::floordiv, l, r)); |
291 | } |
292 | |
293 | BinaryOpStmt *IRBuilder::create_truediv(Stmt *l, Stmt *r) { |
294 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::truediv, l, r)); |
295 | } |
296 | |
297 | BinaryOpStmt *IRBuilder::create_mod(Stmt *l, Stmt *r) { |
298 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::mod, l, r)); |
299 | } |
300 | |
301 | BinaryOpStmt *IRBuilder::create_max(Stmt *l, Stmt *r) { |
302 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::max, l, r)); |
303 | } |
304 | |
305 | BinaryOpStmt *IRBuilder::create_min(Stmt *l, Stmt *r) { |
306 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::min, l, r)); |
307 | } |
308 | |
309 | BinaryOpStmt *IRBuilder::create_atan2(Stmt *l, Stmt *r) { |
310 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::atan2, l, r)); |
311 | } |
312 | |
313 | BinaryOpStmt *IRBuilder::create_pow(Stmt *l, Stmt *r) { |
314 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::pow, l, r)); |
315 | } |
316 | |
317 | BinaryOpStmt *IRBuilder::create_and(Stmt *l, Stmt *r) { |
318 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_and, l, r)); |
319 | } |
320 | |
321 | BinaryOpStmt *IRBuilder::create_or(Stmt *l, Stmt *r) { |
322 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_or, l, r)); |
323 | } |
324 | |
325 | BinaryOpStmt *IRBuilder::create_xor(Stmt *l, Stmt *r) { |
326 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_xor, l, r)); |
327 | } |
328 | |
329 | BinaryOpStmt *IRBuilder::create_shl(Stmt *l, Stmt *r) { |
330 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_shl, l, r)); |
331 | } |
332 | |
333 | BinaryOpStmt *IRBuilder::create_shr(Stmt *l, Stmt *r) { |
334 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_shr, l, r)); |
335 | } |
336 | |
337 | BinaryOpStmt *IRBuilder::create_sar(Stmt *l, Stmt *r) { |
338 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_sar, l, r)); |
339 | } |
340 | |
341 | BinaryOpStmt *IRBuilder::create_cmp_lt(Stmt *l, Stmt *r) { |
342 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_lt, l, r)); |
343 | } |
344 | |
345 | BinaryOpStmt *IRBuilder::create_cmp_le(Stmt *l, Stmt *r) { |
346 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_le, l, r)); |
347 | } |
348 | |
349 | BinaryOpStmt *IRBuilder::create_cmp_gt(Stmt *l, Stmt *r) { |
350 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_gt, l, r)); |
351 | } |
352 | |
353 | BinaryOpStmt *IRBuilder::create_cmp_ge(Stmt *l, Stmt *r) { |
354 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_ge, l, r)); |
355 | } |
356 | |
357 | BinaryOpStmt *IRBuilder::create_cmp_eq(Stmt *l, Stmt *r) { |
358 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_eq, l, r)); |
359 | } |
360 | |
361 | BinaryOpStmt *IRBuilder::create_cmp_ne(Stmt *l, Stmt *r) { |
362 | return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_ne, l, r)); |
363 | } |
364 | |
365 | AtomicOpStmt *IRBuilder::create_atomic_add(Stmt *dest, Stmt *val) { |
366 | return insert(Stmt::make_typed<AtomicOpStmt>(AtomicOpType::add, dest, val)); |
367 | } |
368 | |
369 | AtomicOpStmt *IRBuilder::create_atomic_sub(Stmt *dest, Stmt *val) { |
370 | return insert(Stmt::make_typed<AtomicOpStmt>(AtomicOpType::sub, dest, val)); |
371 | } |
372 | |
373 | AtomicOpStmt *IRBuilder::create_atomic_max(Stmt *dest, Stmt *val) { |
374 | return insert(Stmt::make_typed<AtomicOpStmt>(AtomicOpType::max, dest, val)); |
375 | } |
376 | |
377 | AtomicOpStmt *IRBuilder::create_atomic_min(Stmt *dest, Stmt *val) { |
378 | return insert(Stmt::make_typed<AtomicOpStmt>(AtomicOpType::min, dest, val)); |
379 | } |
380 | |
381 | AtomicOpStmt *IRBuilder::create_atomic_and(Stmt *dest, Stmt *val) { |
382 | return insert( |
383 | Stmt::make_typed<AtomicOpStmt>(AtomicOpType::bit_and, dest, val)); |
384 | } |
385 | |
386 | AtomicOpStmt *IRBuilder::create_atomic_or(Stmt *dest, Stmt *val) { |
387 | return insert( |
388 | Stmt::make_typed<AtomicOpStmt>(AtomicOpType::bit_or, dest, val)); |
389 | } |
390 | |
391 | AtomicOpStmt *IRBuilder::create_atomic_xor(Stmt *dest, Stmt *val) { |
392 | return insert( |
393 | Stmt::make_typed<AtomicOpStmt>(AtomicOpType::bit_xor, dest, val)); |
394 | } |
395 | |
396 | TernaryOpStmt *IRBuilder::create_select(Stmt *cond, |
397 | Stmt *true_result, |
398 | Stmt *false_result) { |
399 | return insert(Stmt::make_typed<TernaryOpStmt>(TernaryOpType::select, cond, |
400 | true_result, false_result)); |
401 | } |
402 | |
403 | AllocaStmt *IRBuilder::create_local_var(DataType dt) { |
404 | return insert(Stmt::make_typed<AllocaStmt>(dt)); |
405 | } |
406 | |
407 | LocalLoadStmt *IRBuilder::create_local_load(AllocaStmt *ptr) { |
408 | return insert(Stmt::make_typed<LocalLoadStmt>(ptr)); |
409 | } |
410 | |
411 | void IRBuilder::create_local_store(AllocaStmt *ptr, Stmt *data) { |
412 | insert(Stmt::make_typed<LocalStoreStmt>(ptr, data)); |
413 | } |
414 | |
415 | GlobalPtrStmt *IRBuilder::create_global_ptr( |
416 | SNode *snode, |
417 | const std::vector<Stmt *> &indices) { |
418 | return insert(Stmt::make_typed<GlobalPtrStmt>(snode, indices)); |
419 | } |
420 | |
421 | ExternalPtrStmt *IRBuilder::create_external_ptr( |
422 | ArgLoadStmt *ptr, |
423 | const std::vector<Stmt *> &indices) { |
424 | return insert( |
425 | Stmt::make_typed<ExternalPtrStmt>(ptr, indices, std::vector<int>(), 0)); |
426 | } |
427 | |
428 | AdStackAllocaStmt *IRBuilder::create_ad_stack(const DataType &dt, |
429 | std::size_t max_size) { |
430 | return insert(Stmt::make_typed<AdStackAllocaStmt>(dt, max_size)); |
431 | } |
432 | |
433 | void IRBuilder::ad_stack_push(AdStackAllocaStmt *stack, Stmt *val) { |
434 | insert(Stmt::make_typed<AdStackPushStmt>(stack, val)); |
435 | } |
436 | |
437 | void IRBuilder::ad_stack_pop(AdStackAllocaStmt *stack) { |
438 | insert(Stmt::make_typed<AdStackPopStmt>(stack)); |
439 | } |
440 | |
441 | AdStackLoadTopStmt *IRBuilder::ad_stack_load_top(AdStackAllocaStmt *stack) { |
442 | return insert(Stmt::make_typed<AdStackLoadTopStmt>(stack)); |
443 | } |
444 | |
445 | AdStackLoadTopAdjStmt *IRBuilder::ad_stack_load_top_adjoint( |
446 | AdStackAllocaStmt *stack) { |
447 | return insert(Stmt::make_typed<AdStackLoadTopAdjStmt>(stack)); |
448 | } |
449 | |
450 | MatrixInitStmt *IRBuilder::create_matrix_init(std::vector<Stmt *> elements) { |
451 | return insert(Stmt::make_typed<MatrixInitStmt>(elements)); |
452 | } |
453 | |
454 | void IRBuilder::ad_stack_accumulate_adjoint(AdStackAllocaStmt *stack, |
455 | Stmt *val) { |
456 | insert(Stmt::make_typed<AdStackAccAdjointStmt>(stack, val)); |
457 | } |
458 | |
459 | // Mesh related. |
460 | |
461 | MeshRelationAccessStmt *IRBuilder::get_relation_size( |
462 | mesh::Mesh *mesh, |
463 | Stmt *mesh_idx, |
464 | mesh::MeshElementType to_type) { |
465 | return insert( |
466 | Stmt::make_typed<MeshRelationAccessStmt>(mesh, mesh_idx, to_type)); |
467 | } |
468 | |
469 | MeshRelationAccessStmt *IRBuilder::get_relation_access( |
470 | mesh::Mesh *mesh, |
471 | Stmt *mesh_idx, |
472 | mesh::MeshElementType to_type, |
473 | Stmt *neighbor_idx) { |
474 | return insert(Stmt::make_typed<MeshRelationAccessStmt>( |
475 | mesh, mesh_idx, to_type, neighbor_idx)); |
476 | } |
477 | |
478 | MeshPatchIndexStmt *IRBuilder::get_patch_index() { |
479 | return insert(Stmt::make_typed<MeshPatchIndexStmt>()); |
480 | } |
481 | |
482 | } // namespace taichi::lang |
483 | |