// TODO: gradually cppize statements.h
#include "taichi/ir/statements.h"
#include "taichi/util/bit.h"

namespace taichi::lang {

UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand)
    : op_type(op_type), operand(operand) {
  TI_ASSERT(!operand->is<AllocaStmt>());
  cast_type = PrimitiveType::unknown;
  TI_STMT_REG_FIELDS;
}

DecorationStmt::DecorationStmt(Stmt *operand,
                               const std::vector<uint32_t> &decoration)
    : operand(operand), decoration(decoration) {
  TI_STMT_REG_FIELDS;
}

bool UnaryOpStmt::is_cast() const {
  return unary_op_is_cast(op_type);
}

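// Two unary ops count as the same operation if their op types match and, for
// casts, their cast types match as well.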
bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const {
  if (op_type == o->op_type) {
    if (is_cast()) {
      return cast_type == o->cast_type;
    } else {
      return true;
    }
  }
  return false;
}

ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,
                                 const std::vector<Stmt *> &indices)
    : base_ptr(base_ptr), indices(indices) {
  TI_ASSERT(base_ptr != nullptr);
  TI_ASSERT(base_ptr->is<ArgLoadStmt>());
  TI_STMT_REG_FIELDS;
}

ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,
                                 const std::vector<Stmt *> &indices,
                                 const std::vector<int> &element_shape,
                                 int element_dim)
    : ExternalPtrStmt(base_ptr, indices) {
  this->element_shape = element_shape;
  this->element_dim = element_dim;
}

GlobalPtrStmt::GlobalPtrStmt(SNode *snode,
                             const std::vector<Stmt *> &indices,
                             bool activate,
                             bool is_cell_access)
    : snode(snode),
      indices(indices),
      activate(activate),
      is_cell_access(is_cell_access),
      is_bit_vectorized(false) {
  TI_ASSERT(snode != nullptr);
  element_type() = snode->dt;
  TI_STMT_REG_FIELDS;
}

MatrixOfGlobalPtrStmt::MatrixOfGlobalPtrStmt(const std::vector<SNode *> &snodes,
                                             const std::vector<Stmt *> &indices,
                                             bool dynamic_indexable,
                                             int dynamic_index_stride,
                                             DataType dt,
                                             bool activate)
    : snodes(snodes),
      indices(indices),
      dynamic_indexable(dynamic_indexable),
      dynamic_index_stride(dynamic_index_stride),
      activate(activate) {
  ret_type = dt;
  TI_STMT_REG_FIELDS;
}

MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts,
                                             DataType dt)
    : stmts(stmts) {
  ret_type = dt;
  TI_STMT_REG_FIELDS;
}

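// MatrixPtrStmt addresses a single element of a tensor value via an origin
// pointer plus an element offset. The element type is derived from the
// origin's TensorType, or taken directly from the origin's return type when
// the origin is a GlobalPtrStmt.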
MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input,
                             Stmt *offset_input,
                             const std::string &tb) {
  origin = origin_input;
  offset = offset_input;
  this->tb = tb;
  if (origin->is<AllocaStmt>() || origin->is<GlobalTemporaryStmt>() ||
      origin->is<ExternalPtrStmt>() || origin->is<MatrixOfGlobalPtrStmt>() ||
      origin->is<MatrixOfMatrixPtrStmt>()) {
    auto tensor_type = origin->ret_type.ptr_removed()->cast<TensorType>();
    TI_ASSERT(tensor_type != nullptr);
    element_type() = tensor_type->get_element_type();
    element_type().set_is_pointer(true);
  } else if (origin->is<GlobalPtrStmt>()) {
    element_type() = origin->cast<GlobalPtrStmt>()->ret_type;
  } else {
    TI_ERROR(
        "MatrixPtrStmt requires an AllocaStmt, GlobalTemporaryStmt, "
        "ExternalPtrStmt, MatrixOfGlobalPtrStmt, MatrixOfMatrixPtrStmt, or "
        "GlobalPtrStmt origin.");
  }
  TI_STMT_REG_FIELDS;
}

SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type,
                         SNode *snode,
                         Stmt *ptr,
                         Stmt *val)
    : op_type(op_type), snode(snode), ptr(ptr), val(val) {
  element_type() = PrimitiveType::i32;
  TI_STMT_REG_FIELDS;
}

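// Whether the SNode op reads or writes activation state.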
bool SNodeOpStmt::activation_related(SNodeOpType op) {
  return op == SNodeOpType::activate || op == SNodeOpType::deactivate ||
         op == SNodeOpType::is_active;
}

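// Whether the SNode op may activate new cells.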
bool SNodeOpStmt::need_activation(SNodeOpType op) {
  return op == SNodeOpType::activate || op == SNodeOpType::append ||
         op == SNodeOpType::allocate;
}

ExternalTensorShapeAlongAxisStmt::ExternalTensorShapeAlongAxisStmt(int axis,
                                                                   int arg_id)
    : axis(axis), arg_id(arg_id) {
  TI_STMT_REG_FIELDS;
}

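// Records the ids of the SNodes covered by a 'ti.loop_unique' hint. Place
// SNodes are promoted to their parent container, with an informational log
// suggesting that the user pass the parent directly.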
LoopUniqueStmt::LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers)
    : input(input) {
  for (const auto &sn : covers) {
    if (sn->is_place()) {
      TI_INFO(
          "A place SNode {} appears in the 'covers' parameter "
          "of 'ti.loop_unique'. It is recommended to use its parent "
          "(x.parent()) instead.",
          sn->get_node_type_name_hinted());
      this->covers.insert(sn->parent->id);
    } else {
      this->covers.insert(sn->id);
    }
  }
  TI_STMT_REG_FIELDS;
}

IfStmt::IfStmt(Stmt *cond) : cond(cond) {
  TI_STMT_REG_FIELDS;
}

void IfStmt::set_true_statements(std::unique_ptr<Block> &&new_true_statements) {
  true_statements = std::move(new_true_statements);
  if (true_statements)
    true_statements->parent_stmt = this;
}

void IfStmt::set_false_statements(
    std::unique_ptr<Block> &&new_false_statements) {
  false_statements = std::move(new_false_statements);
  if (false_statements)
    false_statements->parent_stmt = this;
}

std::unique_ptr<Stmt> IfStmt::clone() const {
  auto new_stmt = std::make_unique<IfStmt>(cond);
  if (true_statements)
    new_stmt->set_true_statements(true_statements->clone());
  if (false_statements)
    new_stmt->set_false_statements(false_statements->clone());
  return new_stmt;
}

RangeForStmt::RangeForStmt(Stmt *begin,
                           Stmt *end,
                           std::unique_ptr<Block> &&body,
                           bool is_bit_vectorized,
                           int num_cpu_threads,
                           int block_dim,
                           bool strictly_serialized,
                           std::string range_hint)
    : begin(begin),
      end(end),
      body(std::move(body)),
      is_bit_vectorized(is_bit_vectorized),
      num_cpu_threads(num_cpu_threads),
      block_dim(block_dim),
      strictly_serialized(strictly_serialized),
      range_hint(range_hint) {
  reversed = false;
  this->body->parent_stmt = this;
  TI_STMT_REG_FIELDS;
}

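// Deep-copies the loop body and carries over the loop configuration,
// including the reversed flag and the range hint.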
std::unique_ptr<Stmt> RangeForStmt::clone() const {
  auto new_stmt = std::make_unique<RangeForStmt>(
      begin, end, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim,
      strictly_serialized, range_hint);
  new_stmt->reversed = reversed;
  return new_stmt;
}

StructForStmt::StructForStmt(SNode *snode,
                             std::unique_ptr<Block> &&body,
                             bool is_bit_vectorized,
                             int num_cpu_threads,
                             int block_dim)
    : snode(snode),
      body(std::move(body)),
      is_bit_vectorized(is_bit_vectorized),
      num_cpu_threads(num_cpu_threads),
      block_dim(block_dim) {
  this->body->parent_stmt = this;
  TI_STMT_REG_FIELDS;
}

std::unique_ptr<Stmt> StructForStmt::clone() const {
  auto new_stmt = std::make_unique<StructForStmt>(
      snode, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim);
  new_stmt->mem_access_opt = mem_access_opt;
  return new_stmt;
}

MeshForStmt::MeshForStmt(mesh::Mesh *mesh,
                         mesh::MeshElementType element_type,
                         std::unique_ptr<Block> &&body,
                         bool is_bit_vectorized,
                         int num_cpu_threads,
                         int block_dim)
    : mesh(mesh),
      body(std::move(body)),
      is_bit_vectorized(is_bit_vectorized),
      num_cpu_threads(num_cpu_threads),
      block_dim(block_dim),
      major_from_type(element_type) {
  this->body->parent_stmt = this;
  TI_STMT_REG_FIELDS;
}

std::unique_ptr<Stmt> MeshForStmt::clone() const {
  auto new_stmt = std::make_unique<MeshForStmt>(
      mesh, major_from_type, body->clone(), is_bit_vectorized, num_cpu_threads,
      block_dim);
  new_stmt->major_to_types = major_to_types;
  new_stmt->minor_relation_types = minor_relation_types;
  new_stmt->mem_access_opt = mem_access_opt;
  return new_stmt;
}

FuncCallStmt::FuncCallStmt(Function *func, const std::vector<Stmt *> &args)
    : func(func), args(args) {
  TI_STMT_REG_FIELDS;
}

WhileStmt::WhileStmt(std::unique_ptr<Block> &&body)
    : mask(nullptr), body(std::move(body)) {
  this->body->parent_stmt = this;
  TI_STMT_REG_FIELDS;
}

std::unique_ptr<Stmt> WhileStmt::clone() const {
  auto new_stmt = std::make_unique<WhileStmt>(body->clone());
  new_stmt->mask = mask;
  return new_stmt;
}

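// GetChStmt fetches the chid-th child SNode of the cell addressed by
// input_ptr. This overload requires input_ptr to be an SNodeLookupStmt so the
// parent SNode can be read off the lookup; the second overload takes the
// parent SNode explicitly.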
GetChStmt::GetChStmt(Stmt *input_ptr, int chid, bool is_bit_vectorized)
    : input_ptr(input_ptr), chid(chid), is_bit_vectorized(is_bit_vectorized) {
  TI_ASSERT(input_ptr->is<SNodeLookupStmt>());
  input_snode = input_ptr->as<SNodeLookupStmt>()->snode;
  output_snode = input_snode->ch[chid].get();
  TI_STMT_REG_FIELDS;
}

GetChStmt::GetChStmt(Stmt *input_ptr,
                     SNode *snode,
                     int chid,
                     bool is_bit_vectorized)
    : input_ptr(input_ptr), chid(chid), is_bit_vectorized(is_bit_vectorized) {
  input_snode = snode;
  output_snode = input_snode->ch[chid].get();
  TI_STMT_REG_FIELDS;
}

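// An OffloadedStmt is a top-level task handed to a backend. Task types that
// carry a body (e.g. serial and the *_for tasks) get a fresh Block whose
// parent is this statement.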
OffloadedStmt::OffloadedStmt(TaskType task_type, Arch arch)
    : task_type(task_type), device(arch) {
  if (has_body()) {
    body = std::make_unique<Block>();
    body->parent_stmt = this;
  }
  TI_STMT_REG_FIELDS;
}

std::string OffloadedStmt::task_name() const {
  if (task_type == TaskType::serial) {
    return "serial";
  } else if (task_type == TaskType::range_for) {
    return "range_for";
  } else if (task_type == TaskType::struct_for) {
    return "struct_for";
  } else if (task_type == TaskType::mesh_for) {
    return "mesh_for";
  } else if (task_type == TaskType::listgen) {
    TI_ASSERT(snode);
    return fmt::format("listgen_{}", snode->get_node_type_name_hinted());
  } else if (task_type == TaskType::gc) {
    TI_ASSERT(snode);
    return fmt::format("gc_{}", snode->name);
  } else if (task_type == TaskType::gc_rc) {
    return "gc_rc";
  } else {
    TI_NOT_IMPLEMENTED
  }
}

// static
std::string OffloadedStmt::task_type_name(TaskType tt) {
  return offloaded_task_type_name(tt);
}

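// Produces a deep copy: scalar configuration is copied field by field, and
// every sub-block (TLS/BLS prologues and epilogues, the mesh prologue, and
// the body) is cloned and re-parented to the new statement.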
std::unique_ptr<Stmt> OffloadedStmt::clone() const {
  auto new_stmt = std::make_unique<OffloadedStmt>(task_type, device);
  new_stmt->snode = snode;
  new_stmt->begin_offset = begin_offset;
  new_stmt->end_offset = end_offset;
  new_stmt->const_begin = const_begin;
  new_stmt->const_end = const_end;
  new_stmt->begin_value = begin_value;
  new_stmt->end_value = end_value;
  new_stmt->grid_dim = grid_dim;
  new_stmt->block_dim = block_dim;
  new_stmt->reversed = reversed;
  new_stmt->is_bit_vectorized = is_bit_vectorized;
  new_stmt->num_cpu_threads = num_cpu_threads;
  new_stmt->index_offsets = index_offsets;

  new_stmt->mesh = mesh;
  new_stmt->major_from_type = major_from_type;
  new_stmt->major_to_types = major_to_types;
  new_stmt->minor_relation_types = minor_relation_types;

  new_stmt->owned_offset_local = owned_offset_local;
  new_stmt->total_offset_local = total_offset_local;
  new_stmt->owned_num_local = owned_num_local;
  new_stmt->total_num_local = total_num_local;

  if (tls_prologue) {
    new_stmt->tls_prologue = tls_prologue->clone();
    new_stmt->tls_prologue->parent_stmt = new_stmt.get();
  }
  if (mesh_prologue) {
    new_stmt->mesh_prologue = mesh_prologue->clone();
    new_stmt->mesh_prologue->parent_stmt = new_stmt.get();
  }
  if (bls_prologue) {
    new_stmt->bls_prologue = bls_prologue->clone();
    new_stmt->bls_prologue->parent_stmt = new_stmt.get();
  }
  if (body) {
    new_stmt->body = body->clone();
    new_stmt->body->parent_stmt = new_stmt.get();
  }
  if (bls_epilogue) {
    new_stmt->bls_epilogue = bls_epilogue->clone();
    new_stmt->bls_epilogue->parent_stmt = new_stmt.get();
  }
  if (tls_epilogue) {
    new_stmt->tls_epilogue = tls_epilogue->clone();
    new_stmt->tls_epilogue->parent_stmt = new_stmt.get();
  }
  new_stmt->tls_size = tls_size;
  new_stmt->bls_size = bls_size;
  new_stmt->mem_access_opt = mem_access_opt;
  return new_stmt;
}

void OffloadedStmt::all_blocks_accept(IRVisitor *visitor,
                                      bool skip_mesh_prologue) {
  if (tls_prologue)
    tls_prologue->accept(visitor);
  if (mesh_prologue && !skip_mesh_prologue)
    mesh_prologue->accept(visitor);
  if (bls_prologue)
    bls_prologue->accept(visitor);
  if (body)
    body->accept(visitor);
  if (bls_epilogue)
    bls_epilogue->accept(visitor);
  if (tls_epilogue)
    tls_epilogue->accept(visitor);
}

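// A clear-list task is a serial offload whose body consists of exactly one
// ClearListStmt.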
bool is_clear_list_task(const OffloadedStmt *stmt) {
  return (stmt->task_type == OffloadedStmt::TaskType::serial) &&
         (stmt->body->size() == 1) && stmt->body->back()->is<ClearListStmt>();
}

ClearListStmt::ClearListStmt(SNode *snode) : snode(snode) {
  TI_STMT_REG_FIELDS;
}

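// The destination pointer of a BitStructStoreStmt is an SNodeLookupStmt whose
// SNode's data type is the BitStructType being stored into.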
BitStructType *BitStructStoreStmt::get_bit_struct() const {
  return ptr->as<SNodeLookupStmt>()->snode->dt->as<BitStructType>();
}

}  // namespace taichi::lang