1#pragma once
2
3#include "taichi/ir/ir.h"
4#include "taichi/ir/mesh.h"
5
6namespace taichi::lang {
7
8class Function;
9
10class IRBuilder {
11 public:
12 struct InsertPoint {
13 Block *block{nullptr};
14 int position{0};
15 };
16
17 IRBuilder();
18
19 // Clear the IR and the insertion point.
20 void reset();
21
22 // Extract the IR.
23 std::unique_ptr<Block> extract_ir();
24
25 // General inserter. Returns stmt.get().
26 template <typename XStmt>
27 XStmt *insert(std::unique_ptr<XStmt> &&stmt) {
28 return insert(std::move(stmt), &insert_point_);
29 }
30
31 // Insert to a specific insertion point.
32 template <typename XStmt>
33 static XStmt *insert(std::unique_ptr<XStmt> &&stmt,
34 InsertPoint *insert_point) {
35 return insert_point->block
36 ->insert(std::move(stmt), insert_point->position++)
37 ->template as<XStmt>();
38 }
39
40 void set_insertion_point(InsertPoint new_insert_point);
41 void set_insertion_point_to_after(Stmt *stmt);
42 void set_insertion_point_to_before(Stmt *stmt);
43 void set_insertion_point_to_true_branch(IfStmt *if_stmt);
44 void set_insertion_point_to_false_branch(IfStmt *if_stmt);
45 template <typename XStmt>
46 void set_insertion_point_to_loop_begin(XStmt *loop) {
47 using DecayedType = typename std::decay_t<XStmt>;
48 if constexpr (!std::is_base_of_v<Stmt, DecayedType>) {
49 TI_ERROR("The argument is not a statement.");
50 }
51 if constexpr (std::is_same_v<DecayedType, RangeForStmt> ||
52 std::is_same_v<DecayedType, StructForStmt> ||
53 std::is_same_v<DecayedType, MeshForStmt> ||
54 std::is_same_v<DecayedType, WhileStmt>) {
55 set_insertion_point({loop->body.get(), 0});
56 } else {
57 TI_ERROR("Statement {} is not a loop.", loop->name());
58 }
59 }
60
61 // RAII handles insertion points automatically.
62 class LoopGuard {
63 public:
64 // Set the insertion point to the beginning of the loop body.
65 template <typename XStmt>
66 explicit LoopGuard(IRBuilder &builder, XStmt *loop)
67 : builder_(builder), loop_(loop) {
68 location_ = (int)loop->parent->size() - 1;
69 builder_.set_insertion_point_to_loop_begin(loop);
70 }
71
72 // Set the insertion point to the point after the loop.
73 ~LoopGuard();
74
75 private:
76 IRBuilder &builder_;
77 Stmt *loop_;
78 int location_;
79 };
80 class IfGuard {
81 public:
82 // Set the insertion point to the beginning of the true/false branch.
83 explicit IfGuard(IRBuilder &builder, IfStmt *if_stmt, bool true_branch);
84
85 // Set the insertion point to the point after the if statement.
86 ~IfGuard();
87
88 private:
89 IRBuilder &builder_;
90 IfStmt *if_stmt_;
91 int location_;
92 };
93
94 template <typename XStmt>
95 [[nodiscard]] LoopGuard get_loop_guard(XStmt *loop) {
96 return LoopGuard(*this, loop);
97 }
98
99 [[nodiscard]] IfGuard get_if_guard(IfStmt *if_stmt, bool true_branch) {
100 return IfGuard(*this, if_stmt, true_branch);
101 }
102
103 // Control flows.
104 RangeForStmt *create_range_for(Stmt *begin,
105 Stmt *end,
106 bool is_bit_vectorized = false,
107 int num_cpu_threads = 0,
108 int block_dim = 0,
109 bool strictly_serialized = false);
110 StructForStmt *create_struct_for(SNode *snode,
111 bool is_bit_vectorized = false,
112 int num_cpu_threads = 0,
113 int block_dim = 0);
114 MeshForStmt *create_mesh_for(mesh::Mesh *mesh,
115 mesh::MeshElementType element_type,
116 bool is_bit_vectorized = false,
117 int num_cpu_threads = 0,
118 int block_dim = 0);
119 WhileStmt *create_while_true();
120 IfStmt *create_if(Stmt *cond);
121 WhileControlStmt *create_break();
122 ContinueStmt *create_continue();
123
124 // Function.
125 FuncCallStmt *create_func_call(Function *func,
126 const std::vector<Stmt *> &args);
127
128 // Loop index.
129 LoopIndexStmt *get_loop_index(Stmt *loop, int index = 0);
130
131 // Constants. TODO: add more types
132 ConstStmt *get_int32(int32 value);
133 ConstStmt *get_int64(int64 value);
134 ConstStmt *get_uint32(uint32 value);
135 ConstStmt *get_uint64(uint64 value);
136 ConstStmt *get_float32(float32 value);
137 ConstStmt *get_float64(float64 value);
138
139 template <typename T>
140 ConstStmt *get_constant(DataType dt, const T &value) {
141 return insert(Stmt::make_typed<ConstStmt>(TypedConstant(dt, value)));
142 }
143
144 RandStmt *create_rand(DataType value_type);
145
146 // Load kernel arguments.
147 ArgLoadStmt *create_arg_load(int arg_id, DataType dt, bool is_ptr);
148
149 // The return value of the kernel.
150 ReturnStmt *create_return(Stmt *value);
151
152 // Unary operations. Returns the result.
153 UnaryOpStmt *create_cast(Stmt *value, DataType output_type); // cast by value
154 UnaryOpStmt *create_bit_cast(Stmt *value, DataType output_type);
155 UnaryOpStmt *create_neg(Stmt *value);
156 UnaryOpStmt *create_not(Stmt *value); // bitwise
157 UnaryOpStmt *create_logical_not(Stmt *value);
158 UnaryOpStmt *create_round(Stmt *value);
159 UnaryOpStmt *create_floor(Stmt *value);
160 UnaryOpStmt *create_ceil(Stmt *value);
161 UnaryOpStmt *create_abs(Stmt *value);
162 UnaryOpStmt *create_sgn(Stmt *value);
163 UnaryOpStmt *create_sqrt(Stmt *value);
164 UnaryOpStmt *create_rsqrt(Stmt *value);
165 UnaryOpStmt *create_sin(Stmt *value);
166 UnaryOpStmt *create_asin(Stmt *value);
167 UnaryOpStmt *create_cos(Stmt *value);
168 UnaryOpStmt *create_acos(Stmt *value);
169 UnaryOpStmt *create_tan(Stmt *value);
170 UnaryOpStmt *create_tanh(Stmt *value);
171 UnaryOpStmt *create_exp(Stmt *value);
172 UnaryOpStmt *create_log(Stmt *value);
173
174 // Binary operations. Returns the result.
175 BinaryOpStmt *create_add(Stmt *l, Stmt *r);
176 BinaryOpStmt *create_sub(Stmt *l, Stmt *r);
177 BinaryOpStmt *create_mul(Stmt *l, Stmt *r);
178 // l / r in C++
179 BinaryOpStmt *create_div(Stmt *l, Stmt *r);
180 // floor(1.0 * l / r) in C++
181 BinaryOpStmt *create_floordiv(Stmt *l, Stmt *r);
182 // 1.0 * l / r in C++
183 BinaryOpStmt *create_truediv(Stmt *l, Stmt *r);
184 BinaryOpStmt *create_mod(Stmt *l, Stmt *r);
185 BinaryOpStmt *create_max(Stmt *l, Stmt *r);
186 BinaryOpStmt *create_min(Stmt *l, Stmt *r);
187 BinaryOpStmt *create_atan2(Stmt *l, Stmt *r);
188 BinaryOpStmt *create_pow(Stmt *l, Stmt *r);
189 // Bitwise operations. TODO: add logical operations when we support them
190 BinaryOpStmt *create_and(Stmt *l, Stmt *r);
191 BinaryOpStmt *create_or(Stmt *l, Stmt *r);
192 BinaryOpStmt *create_xor(Stmt *l, Stmt *r);
193 BinaryOpStmt *create_shl(Stmt *l, Stmt *r);
194 BinaryOpStmt *create_shr(Stmt *l, Stmt *r);
195 BinaryOpStmt *create_sar(Stmt *l, Stmt *r);
196 // Comparisons.
197 BinaryOpStmt *create_cmp_lt(Stmt *l, Stmt *r);
198 BinaryOpStmt *create_cmp_le(Stmt *l, Stmt *r);
199 BinaryOpStmt *create_cmp_gt(Stmt *l, Stmt *r);
200 BinaryOpStmt *create_cmp_ge(Stmt *l, Stmt *r);
201 BinaryOpStmt *create_cmp_eq(Stmt *l, Stmt *r);
202 BinaryOpStmt *create_cmp_ne(Stmt *l, Stmt *r);
203
204 // Atomic operations.
205 AtomicOpStmt *create_atomic_add(Stmt *dest, Stmt *val);
206 AtomicOpStmt *create_atomic_sub(Stmt *dest, Stmt *val);
207 AtomicOpStmt *create_atomic_max(Stmt *dest, Stmt *val);
208 AtomicOpStmt *create_atomic_min(Stmt *dest, Stmt *val);
209 // Atomic bitwise operations.
210 AtomicOpStmt *create_atomic_and(Stmt *dest, Stmt *val);
211 AtomicOpStmt *create_atomic_or(Stmt *dest, Stmt *val);
212 AtomicOpStmt *create_atomic_xor(Stmt *dest, Stmt *val);
213
214 // Ternary operations. Returns the result.
215 TernaryOpStmt *create_select(Stmt *cond,
216 Stmt *true_result,
217 Stmt *false_result);
218
219 // Matrix Initialization
220 MatrixInitStmt *create_matrix_init(std::vector<Stmt *> elements);
221
222 // Print values and strings. Arguments can be Stmt* or std::string.
223 template <typename... Args>
224 PrintStmt *create_print(Args &&...args) {
225 return insert(Stmt::make_typed<PrintStmt>(std::forward<Args>(args)...));
226 }
227
228 // Local variables.
229 AllocaStmt *create_local_var(DataType dt);
230 LocalLoadStmt *create_local_load(AllocaStmt *ptr);
231 void create_local_store(AllocaStmt *ptr, Stmt *data);
232
233 // Global variables.
234 GlobalPtrStmt *create_global_ptr(SNode *snode,
235 const std::vector<Stmt *> &indices);
236 ExternalPtrStmt *create_external_ptr(ArgLoadStmt *ptr,
237 const std::vector<Stmt *> &indices);
238 template <typename XStmt>
239 GlobalLoadStmt *create_global_load(XStmt *ptr) {
240 using DecayedType = typename std::decay_t<XStmt>;
241 if constexpr (!std::is_base_of_v<Stmt, DecayedType>) {
242 TI_ERROR("The argument is not a statement.");
243 }
244 if constexpr (std::is_same_v<DecayedType, GlobalPtrStmt> ||
245 std::is_same_v<DecayedType, ExternalPtrStmt>) {
246 return insert(Stmt::make_typed<GlobalLoadStmt>(ptr));
247 } else {
248 TI_ERROR("Statement {} is not a global pointer.", ptr->name());
249 }
250 }
251 template <typename XStmt>
252 void create_global_store(XStmt *ptr, Stmt *data) {
253 using DecayedType = typename std::decay_t<XStmt>;
254 if constexpr (!std::is_base_of_v<Stmt, DecayedType>) {
255 TI_ERROR("The argument is not a statement.");
256 }
257 if constexpr (std::is_same_v<DecayedType, GlobalPtrStmt> ||
258 std::is_same_v<DecayedType, ExternalPtrStmt>) {
259 insert(Stmt::make_typed<GlobalStoreStmt>(ptr, data));
260 } else {
261 TI_ERROR("Statement {} is not a global pointer.", ptr->name());
262 }
263 }
264
265 // Autodiff stack operations.
266 AdStackAllocaStmt *create_ad_stack(const DataType &dt, std::size_t max_size);
267 void ad_stack_push(AdStackAllocaStmt *stack, Stmt *val);
268 void ad_stack_pop(AdStackAllocaStmt *stack);
269 AdStackLoadTopStmt *ad_stack_load_top(AdStackAllocaStmt *stack);
270 AdStackLoadTopAdjStmt *ad_stack_load_top_adjoint(AdStackAllocaStmt *stack);
271 void ad_stack_accumulate_adjoint(AdStackAllocaStmt *stack, Stmt *val);
272
273 // Mesh related.
274 MeshRelationAccessStmt *get_relation_size(mesh::Mesh *mesh,
275 Stmt *mesh_idx,
276 mesh::MeshElementType to_type);
277 MeshRelationAccessStmt *get_relation_access(mesh::Mesh *mesh,
278 Stmt *mesh_idx,
279 mesh::MeshElementType to_type,
280 Stmt *neighbor_idx);
281 MeshPatchIndexStmt *get_patch_index();
282
283 private:
284 std::unique_ptr<Block> root_{nullptr};
285 InsertPoint insert_point_;
286};
287
288} // namespace taichi::lang
289