1#include "taichi/ir/ir_builder.h"
2#include "taichi/ir/statements.h"
3#include "taichi/common/logging.h"
4
5namespace taichi::lang {
6
7namespace {
8
9inline bool stmt_location_did_not_change(Stmt *stmt, int location) {
10 return location >= 0 && location < stmt->parent->size() &&
11 stmt->parent->statements[location].get() == stmt;
12}
13
14} // namespace
15
16IRBuilder::IRBuilder() {
17 reset();
18}
19
20void IRBuilder::reset() {
21 root_ = std::make_unique<Block>();
22 insert_point_.block = root_->as<Block>();
23 insert_point_.position = 0;
24}
25
26std::unique_ptr<Block> IRBuilder::extract_ir() {
27 auto result = std::move(root_);
28 reset();
29 return result;
30}
31
32void IRBuilder::set_insertion_point(InsertPoint new_insert_point) {
33 insert_point_ = new_insert_point;
34}
35
36void IRBuilder::set_insertion_point_to_after(Stmt *stmt) {
37 set_insertion_point({stmt->parent, stmt->parent->locate(stmt) + 1});
38}
39
40void IRBuilder::set_insertion_point_to_before(Stmt *stmt) {
41 set_insertion_point({stmt->parent, stmt->parent->locate(stmt)});
42}
43
44void IRBuilder::set_insertion_point_to_true_branch(IfStmt *if_stmt) {
45 if (!if_stmt->true_statements)
46 if_stmt->set_true_statements(std::make_unique<Block>());
47 set_insertion_point({if_stmt->true_statements.get(), 0});
48}
49
50void IRBuilder::set_insertion_point_to_false_branch(IfStmt *if_stmt) {
51 if (!if_stmt->false_statements)
52 if_stmt->set_false_statements(std::make_unique<Block>());
53 set_insertion_point({if_stmt->false_statements.get(), 0});
54}
55
56IRBuilder::LoopGuard::~LoopGuard() {
57 if (stmt_location_did_not_change(loop_, location_)) {
58 // faster than set_insertion_point_to_after()
59 builder_.set_insertion_point({loop_->parent, location_ + 1});
60 } else {
61 builder_.set_insertion_point_to_after(loop_);
62 }
63}
64
65IRBuilder::IfGuard::IfGuard(IRBuilder &builder,
66 IfStmt *if_stmt,
67 bool true_branch)
68 : builder_(builder), if_stmt_(if_stmt) {
69 location_ = (int)if_stmt_->parent->size() - 1;
70 if (true_branch) {
71 builder_.set_insertion_point_to_true_branch(if_stmt_);
72 } else {
73 builder_.set_insertion_point_to_false_branch(if_stmt_);
74 }
75}
76
77IRBuilder::IfGuard::~IfGuard() {
78 if (stmt_location_did_not_change(if_stmt_, location_)) {
79 // faster than set_insertion_point_to_after()
80 builder_.set_insertion_point({if_stmt_->parent, location_ + 1});
81 } else {
82 builder_.set_insertion_point_to_after(if_stmt_);
83 }
84}
85
86RangeForStmt *IRBuilder::create_range_for(Stmt *begin,
87 Stmt *end,
88 bool is_bit_vectorized,
89 int num_cpu_threads,
90 int block_dim,
91 bool strictly_serialized) {
92 return insert(Stmt::make_typed<RangeForStmt>(
93 begin, end, std::make_unique<Block>(), is_bit_vectorized, num_cpu_threads,
94 block_dim, strictly_serialized));
95}
96
97StructForStmt *IRBuilder::create_struct_for(SNode *snode,
98 bool is_bit_vectorized,
99 int num_cpu_threads,
100 int block_dim) {
101 return insert(Stmt::make_typed<StructForStmt>(
102 snode, std::make_unique<Block>(), is_bit_vectorized, num_cpu_threads,
103 block_dim));
104}
105
106MeshForStmt *IRBuilder::create_mesh_for(mesh::Mesh *mesh,
107 mesh::MeshElementType element_type,
108 bool is_bit_vectorized,
109 int num_cpu_threads,
110 int block_dim) {
111 return insert(Stmt::make_typed<MeshForStmt>(
112 mesh, element_type, std::make_unique<Block>(), is_bit_vectorized,
113 num_cpu_threads, block_dim));
114}
115
116WhileStmt *IRBuilder::create_while_true() {
117 return insert(Stmt::make_typed<WhileStmt>(std::make_unique<Block>()));
118}
119
120IfStmt *IRBuilder::create_if(Stmt *cond) {
121 return insert(Stmt::make_typed<IfStmt>(cond));
122}
123
124WhileControlStmt *IRBuilder::create_break() {
125 return insert(Stmt::make_typed<WhileControlStmt>(nullptr, get_int32(0)));
126}
127
128ContinueStmt *IRBuilder::create_continue() {
129 return insert(Stmt::make_typed<ContinueStmt>());
130}
131
132FuncCallStmt *IRBuilder::create_func_call(Function *func,
133 const std::vector<Stmt *> &args) {
134 return insert(Stmt::make_typed<FuncCallStmt>(func, args));
135}
136
137LoopIndexStmt *IRBuilder::get_loop_index(Stmt *loop, int index) {
138 return insert(Stmt::make_typed<LoopIndexStmt>(loop, index));
139}
140
141ConstStmt *IRBuilder::get_int32(int32 value) {
142 return insert(Stmt::make_typed<ConstStmt>(TypedConstant(
143 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32),
144 value)));
145}
146
147ConstStmt *IRBuilder::get_int64(int64 value) {
148 return insert(Stmt::make_typed<ConstStmt>(TypedConstant(
149 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i64),
150 value)));
151}
152
153ConstStmt *IRBuilder::get_uint32(uint32 value) {
154 return insert(Stmt::make_typed<ConstStmt>(TypedConstant(
155 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u32),
156 value)));
157}
158
159ConstStmt *IRBuilder::get_uint64(uint64 value) {
160 return insert(Stmt::make_typed<ConstStmt>(TypedConstant(
161 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u64),
162 value)));
163}
164
165ConstStmt *IRBuilder::get_float32(float32 value) {
166 return insert(Stmt::make_typed<ConstStmt>(TypedConstant(
167 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f32),
168 value)));
169}
170
171ConstStmt *IRBuilder::get_float64(float64 value) {
172 return insert(Stmt::make_typed<ConstStmt>(TypedConstant(
173 TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f64),
174 value)));
175}
176
177RandStmt *IRBuilder::create_rand(DataType value_type) {
178 return insert(Stmt::make_typed<RandStmt>(value_type));
179}
180
181ArgLoadStmt *IRBuilder::create_arg_load(int arg_id, DataType dt, bool is_ptr) {
182 return insert(Stmt::make_typed<ArgLoadStmt>(arg_id, dt, is_ptr));
183}
184
185ReturnStmt *IRBuilder::create_return(Stmt *value) {
186 return insert(Stmt::make_typed<ReturnStmt>(value));
187}
188
189UnaryOpStmt *IRBuilder::create_cast(Stmt *value, DataType output_type) {
190 auto &&result = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, value);
191 result->cast_type = output_type;
192 return insert(std::move(result));
193}
194
195UnaryOpStmt *IRBuilder::create_bit_cast(Stmt *value, DataType output_type) {
196 auto &&result = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_bits, value);
197 result->cast_type = output_type;
198 return insert(std::move(result));
199}
200
201UnaryOpStmt *IRBuilder::create_neg(Stmt *value) {
202 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::neg, value));
203}
204
205UnaryOpStmt *IRBuilder::create_not(Stmt *value) {
206 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::bit_not, value));
207}
208
209UnaryOpStmt *IRBuilder::create_logical_not(Stmt *value) {
210 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::logic_not, value));
211}
212
213UnaryOpStmt *IRBuilder::create_round(Stmt *value) {
214 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::round, value));
215}
216
217UnaryOpStmt *IRBuilder::create_floor(Stmt *value) {
218 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::floor, value));
219}
220
221UnaryOpStmt *IRBuilder::create_ceil(Stmt *value) {
222 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::ceil, value));
223}
224
225UnaryOpStmt *IRBuilder::create_abs(Stmt *value) {
226 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::abs, value));
227}
228
229UnaryOpStmt *IRBuilder::create_sgn(Stmt *value) {
230 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::sgn, value));
231}
232
233UnaryOpStmt *IRBuilder::create_sqrt(Stmt *value) {
234 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::sqrt, value));
235}
236
237UnaryOpStmt *IRBuilder::create_rsqrt(Stmt *value) {
238 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::rsqrt, value));
239}
240
241UnaryOpStmt *IRBuilder::create_sin(Stmt *value) {
242 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::sin, value));
243}
244
245UnaryOpStmt *IRBuilder::create_asin(Stmt *value) {
246 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::asin, value));
247}
248
249UnaryOpStmt *IRBuilder::create_cos(Stmt *value) {
250 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cos, value));
251}
252
253UnaryOpStmt *IRBuilder::create_acos(Stmt *value) {
254 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::acos, value));
255}
256
257UnaryOpStmt *IRBuilder::create_tan(Stmt *value) {
258 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::tan, value));
259}
260
261UnaryOpStmt *IRBuilder::create_tanh(Stmt *value) {
262 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::tanh, value));
263}
264
265UnaryOpStmt *IRBuilder::create_exp(Stmt *value) {
266 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::exp, value));
267}
268
269UnaryOpStmt *IRBuilder::create_log(Stmt *value) {
270 return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::log, value));
271}
272
273BinaryOpStmt *IRBuilder::create_add(Stmt *l, Stmt *r) {
274 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::add, l, r));
275}
276
277BinaryOpStmt *IRBuilder::create_sub(Stmt *l, Stmt *r) {
278 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::sub, l, r));
279}
280
281BinaryOpStmt *IRBuilder::create_mul(Stmt *l, Stmt *r) {
282 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::mul, l, r));
283}
284
285BinaryOpStmt *IRBuilder::create_div(Stmt *l, Stmt *r) {
286 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::div, l, r));
287}
288
289BinaryOpStmt *IRBuilder::create_floordiv(Stmt *l, Stmt *r) {
290 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::floordiv, l, r));
291}
292
293BinaryOpStmt *IRBuilder::create_truediv(Stmt *l, Stmt *r) {
294 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::truediv, l, r));
295}
296
297BinaryOpStmt *IRBuilder::create_mod(Stmt *l, Stmt *r) {
298 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::mod, l, r));
299}
300
301BinaryOpStmt *IRBuilder::create_max(Stmt *l, Stmt *r) {
302 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::max, l, r));
303}
304
305BinaryOpStmt *IRBuilder::create_min(Stmt *l, Stmt *r) {
306 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::min, l, r));
307}
308
309BinaryOpStmt *IRBuilder::create_atan2(Stmt *l, Stmt *r) {
310 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::atan2, l, r));
311}
312
313BinaryOpStmt *IRBuilder::create_pow(Stmt *l, Stmt *r) {
314 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::pow, l, r));
315}
316
317BinaryOpStmt *IRBuilder::create_and(Stmt *l, Stmt *r) {
318 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_and, l, r));
319}
320
321BinaryOpStmt *IRBuilder::create_or(Stmt *l, Stmt *r) {
322 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_or, l, r));
323}
324
325BinaryOpStmt *IRBuilder::create_xor(Stmt *l, Stmt *r) {
326 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_xor, l, r));
327}
328
329BinaryOpStmt *IRBuilder::create_shl(Stmt *l, Stmt *r) {
330 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_shl, l, r));
331}
332
333BinaryOpStmt *IRBuilder::create_shr(Stmt *l, Stmt *r) {
334 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_shr, l, r));
335}
336
337BinaryOpStmt *IRBuilder::create_sar(Stmt *l, Stmt *r) {
338 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::bit_sar, l, r));
339}
340
341BinaryOpStmt *IRBuilder::create_cmp_lt(Stmt *l, Stmt *r) {
342 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_lt, l, r));
343}
344
345BinaryOpStmt *IRBuilder::create_cmp_le(Stmt *l, Stmt *r) {
346 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_le, l, r));
347}
348
349BinaryOpStmt *IRBuilder::create_cmp_gt(Stmt *l, Stmt *r) {
350 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_gt, l, r));
351}
352
353BinaryOpStmt *IRBuilder::create_cmp_ge(Stmt *l, Stmt *r) {
354 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_ge, l, r));
355}
356
357BinaryOpStmt *IRBuilder::create_cmp_eq(Stmt *l, Stmt *r) {
358 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_eq, l, r));
359}
360
361BinaryOpStmt *IRBuilder::create_cmp_ne(Stmt *l, Stmt *r) {
362 return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::cmp_ne, l, r));
363}
364
365AtomicOpStmt *IRBuilder::create_atomic_add(Stmt *dest, Stmt *val) {
366 return insert(Stmt::make_typed<AtomicOpStmt>(AtomicOpType::add, dest, val));
367}
368
369AtomicOpStmt *IRBuilder::create_atomic_sub(Stmt *dest, Stmt *val) {
370 return insert(Stmt::make_typed<AtomicOpStmt>(AtomicOpType::sub, dest, val));
371}
372
373AtomicOpStmt *IRBuilder::create_atomic_max(Stmt *dest, Stmt *val) {
374 return insert(Stmt::make_typed<AtomicOpStmt>(AtomicOpType::max, dest, val));
375}
376
377AtomicOpStmt *IRBuilder::create_atomic_min(Stmt *dest, Stmt *val) {
378 return insert(Stmt::make_typed<AtomicOpStmt>(AtomicOpType::min, dest, val));
379}
380
381AtomicOpStmt *IRBuilder::create_atomic_and(Stmt *dest, Stmt *val) {
382 return insert(
383 Stmt::make_typed<AtomicOpStmt>(AtomicOpType::bit_and, dest, val));
384}
385
386AtomicOpStmt *IRBuilder::create_atomic_or(Stmt *dest, Stmt *val) {
387 return insert(
388 Stmt::make_typed<AtomicOpStmt>(AtomicOpType::bit_or, dest, val));
389}
390
391AtomicOpStmt *IRBuilder::create_atomic_xor(Stmt *dest, Stmt *val) {
392 return insert(
393 Stmt::make_typed<AtomicOpStmt>(AtomicOpType::bit_xor, dest, val));
394}
395
396TernaryOpStmt *IRBuilder::create_select(Stmt *cond,
397 Stmt *true_result,
398 Stmt *false_result) {
399 return insert(Stmt::make_typed<TernaryOpStmt>(TernaryOpType::select, cond,
400 true_result, false_result));
401}
402
403AllocaStmt *IRBuilder::create_local_var(DataType dt) {
404 return insert(Stmt::make_typed<AllocaStmt>(dt));
405}
406
407LocalLoadStmt *IRBuilder::create_local_load(AllocaStmt *ptr) {
408 return insert(Stmt::make_typed<LocalLoadStmt>(ptr));
409}
410
411void IRBuilder::create_local_store(AllocaStmt *ptr, Stmt *data) {
412 insert(Stmt::make_typed<LocalStoreStmt>(ptr, data));
413}
414
415GlobalPtrStmt *IRBuilder::create_global_ptr(
416 SNode *snode,
417 const std::vector<Stmt *> &indices) {
418 return insert(Stmt::make_typed<GlobalPtrStmt>(snode, indices));
419}
420
421ExternalPtrStmt *IRBuilder::create_external_ptr(
422 ArgLoadStmt *ptr,
423 const std::vector<Stmt *> &indices) {
424 return insert(
425 Stmt::make_typed<ExternalPtrStmt>(ptr, indices, std::vector<int>(), 0));
426}
427
428AdStackAllocaStmt *IRBuilder::create_ad_stack(const DataType &dt,
429 std::size_t max_size) {
430 return insert(Stmt::make_typed<AdStackAllocaStmt>(dt, max_size));
431}
432
433void IRBuilder::ad_stack_push(AdStackAllocaStmt *stack, Stmt *val) {
434 insert(Stmt::make_typed<AdStackPushStmt>(stack, val));
435}
436
437void IRBuilder::ad_stack_pop(AdStackAllocaStmt *stack) {
438 insert(Stmt::make_typed<AdStackPopStmt>(stack));
439}
440
441AdStackLoadTopStmt *IRBuilder::ad_stack_load_top(AdStackAllocaStmt *stack) {
442 return insert(Stmt::make_typed<AdStackLoadTopStmt>(stack));
443}
444
445AdStackLoadTopAdjStmt *IRBuilder::ad_stack_load_top_adjoint(
446 AdStackAllocaStmt *stack) {
447 return insert(Stmt::make_typed<AdStackLoadTopAdjStmt>(stack));
448}
449
450MatrixInitStmt *IRBuilder::create_matrix_init(std::vector<Stmt *> elements) {
451 return insert(Stmt::make_typed<MatrixInitStmt>(elements));
452}
453
454void IRBuilder::ad_stack_accumulate_adjoint(AdStackAllocaStmt *stack,
455 Stmt *val) {
456 insert(Stmt::make_typed<AdStackAccAdjointStmt>(stack, val));
457}
458
459// Mesh related.
460
461MeshRelationAccessStmt *IRBuilder::get_relation_size(
462 mesh::Mesh *mesh,
463 Stmt *mesh_idx,
464 mesh::MeshElementType to_type) {
465 return insert(
466 Stmt::make_typed<MeshRelationAccessStmt>(mesh, mesh_idx, to_type));
467}
468
469MeshRelationAccessStmt *IRBuilder::get_relation_access(
470 mesh::Mesh *mesh,
471 Stmt *mesh_idx,
472 mesh::MeshElementType to_type,
473 Stmt *neighbor_idx) {
474 return insert(Stmt::make_typed<MeshRelationAccessStmt>(
475 mesh, mesh_idx, to_type, neighbor_idx));
476}
477
478MeshPatchIndexStmt *IRBuilder::get_patch_index() {
479 return insert(Stmt::make_typed<MeshPatchIndexStmt>());
480}
481
482} // namespace taichi::lang
483