1 | // TODO: gradually cppize statements.h |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/util/bit.h" |
4 | |
5 | namespace taichi::lang { |
6 | |
// A unary operation `op_type` applied to the value produced by `operand`.
// The operand must be a computed value, not a raw alloca (allocas must be
// loaded first).  `cast_type` starts as `unknown`; presumably it is assigned
// a concrete type later for cast ops — set outside this constructor.
UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand)
    : op_type(op_type), operand(operand) {
  TI_ASSERT(!operand->is<AllocaStmt>());
  cast_type = PrimitiveType::unknown;
  TI_STMT_REG_FIELDS;
}
13 | |
// Attaches a list of opaque uint32 decoration words to `operand`.
// The meaning of the words is defined by the consumers of this statement,
// not here.
DecorationStmt::DecorationStmt(Stmt *operand,
                               const std::vector<uint32_t> &decoration)
    : operand(operand), decoration(decoration) {
  TI_STMT_REG_FIELDS;
}
19 | |
// Whether this unary op is a type cast (delegates to unary_op_is_cast).
bool UnaryOpStmt::is_cast() const {
  return unary_op_is_cast(op_type);
}
23 | |
24 | bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const { |
25 | if (op_type == o->op_type) { |
26 | if (is_cast()) { |
27 | return cast_type == o->cast_type; |
28 | } else { |
29 | return true; |
30 | } |
31 | } |
32 | return false; |
33 | } |
34 | |
// Pointer into an external (kernel-argument) array at `indices`.
// The base pointer must come from an argument load.
ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,
                                 const std::vector<Stmt *> &indices)
    : base_ptr(base_ptr), indices(indices) {
  TI_ASSERT(base_ptr != nullptr);
  TI_ASSERT(base_ptr->is<ArgLoadStmt>());
  TI_STMT_REG_FIELDS;
}
42 | |
// Same as the two-argument constructor, but additionally records the
// element shape and element dimension of the external array (used for
// matrix/vector element access into ndarrays — TODO confirm exact use site).
ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr,
                                 const std::vector<Stmt *> &indices,
                                 const std::vector<int> &element_shape,
                                 int element_dim)
    : ExternalPtrStmt(base_ptr, indices) {
  this->element_shape = element_shape;
  this->element_dim = element_dim;
}
51 | |
// Pointer into the SNode-backed field `snode` at the given `indices`.
// `activate` requests activation semantics on access; `is_cell_access`
// semantics are defined by downstream passes (not visible here).
// Bit vectorization is off until a later pass enables it.
GlobalPtrStmt::GlobalPtrStmt(SNode *snode,
                             const std::vector<Stmt *> &indices,
                             bool activate,
                             bool is_cell_access)
    : snode(snode),
      indices(indices),
      activate(activate),
      is_cell_access(is_cell_access),
      is_bit_vectorized(false) {
  TI_ASSERT(snode != nullptr);
  // The pointee type is whatever the SNode stores.
  element_type() = snode->dt;
  TI_STMT_REG_FIELDS;
}
65 | |
// A matrix whose elements are global pointers: one SNode per matrix element,
// all indexed by the same `indices`.  `dynamic_indexable` /
// `dynamic_index_stride` describe whether (and how) the elements can be
// addressed with a runtime index.  `dt` becomes the statement's result type.
MatrixOfGlobalPtrStmt::MatrixOfGlobalPtrStmt(const std::vector<SNode *> &snodes,
                                             const std::vector<Stmt *> &indices,
                                             bool dynamic_indexable,
                                             int dynamic_index_stride,
                                             DataType dt,
                                             bool activate)
    : snodes(snodes),
      indices(indices),
      dynamic_indexable(dynamic_indexable),
      dynamic_index_stride(dynamic_index_stride),
      activate(activate) {
  ret_type = dt;
  TI_STMT_REG_FIELDS;
}
80 | |
// A matrix assembled from a list of existing statements (one per element),
// with result type `dt`.
MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts,
                                             DataType dt)
    : stmts(stmts) {
  ret_type = dt;
  TI_STMT_REG_FIELDS;
}
87 | |
88 | MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input, |
89 | Stmt *offset_input, |
90 | const std::string &tb) { |
91 | origin = origin_input; |
92 | offset = offset_input; |
93 | this->tb = tb; |
94 | if (origin->is<AllocaStmt>() || origin->is<GlobalTemporaryStmt>() || |
95 | origin->is<ExternalPtrStmt>() || origin->is<MatrixOfGlobalPtrStmt>() || |
96 | origin->is<MatrixOfMatrixPtrStmt>()) { |
97 | auto tensor_type = origin->ret_type.ptr_removed()->cast<TensorType>(); |
98 | TI_ASSERT(tensor_type != nullptr); |
99 | element_type() = tensor_type->get_element_type(); |
100 | element_type().set_is_pointer(true); |
101 | } else if (origin->is<GlobalPtrStmt>()) { |
102 | element_type() = origin->cast<GlobalPtrStmt>()->ret_type; |
103 | } else { |
104 | TI_ERROR( |
105 | "MatrixPtrStmt must be used for AllocaStmt / GlobalTemporaryStmt " |
106 | "(locally) or GlobalPtrStmt / MatrixOfGlobalPtrStmt / ExternalPtrStmt " |
107 | "(globally)." ) |
108 | } |
109 | TI_STMT_REG_FIELDS; |
110 | } |
111 | |
// A structural operation `op_type` on `snode` through pointer `ptr`, with an
// optional value operand `val` (may be nullptr for ops that take none).
SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type,
                         SNode *snode,
                         Stmt *ptr,
                         Stmt *val)
    : op_type(op_type), snode(snode), ptr(ptr), val(val) {
  // The result (when the op produces one, e.g. is_active) is reported as i32.
  element_type() = PrimitiveType::i32;
  TI_STMT_REG_FIELDS;
}
120 | |
121 | bool SNodeOpStmt::activation_related(SNodeOpType op) { |
122 | return op == SNodeOpType::activate || op == SNodeOpType::deactivate || |
123 | op == SNodeOpType::is_active; |
124 | } |
125 | |
126 | bool SNodeOpStmt::need_activation(SNodeOpType op) { |
127 | return op == SNodeOpType::activate || op == SNodeOpType::append || |
128 | op == SNodeOpType::allocate; |
129 | } |
130 | |
// Queries the runtime extent of external-array argument `arg_id` along
// dimension `axis`.
ExternalTensorShapeAlongAxisStmt::ExternalTensorShapeAlongAxisStmt(int axis,
                                                                   int arg_id)
    : axis(axis), arg_id(arg_id) {
  TI_STMT_REG_FIELDS;
}
136 | |
137 | LoopUniqueStmt::LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers) |
138 | : input(input) { |
139 | for (const auto &sn : covers) { |
140 | if (sn->is_place()) { |
141 | TI_INFO( |
142 | "A place SNode {} appears in the 'covers' parameter " |
143 | "of 'ti.loop_unique'. It is recommended to use its parent " |
144 | "(x.parent()) instead." , |
145 | sn->get_node_type_name_hinted()); |
146 | this->covers.insert(sn->parent->id); |
147 | } else |
148 | this->covers.insert(sn->id); |
149 | } |
150 | TI_STMT_REG_FIELDS; |
151 | } |
152 | |
// Conditional on `cond`; the true/false branch blocks are attached later via
// set_true_statements() / set_false_statements().
IfStmt::IfStmt(Stmt *cond) : cond(cond) {
  TI_STMT_REG_FIELDS;
}
156 | |
157 | void IfStmt::set_true_statements(std::unique_ptr<Block> &&new_true_statements) { |
158 | true_statements = std::move(new_true_statements); |
159 | if (true_statements) |
160 | true_statements->parent_stmt = this; |
161 | } |
162 | |
163 | void IfStmt::set_false_statements( |
164 | std::unique_ptr<Block> &&new_false_statements) { |
165 | false_statements = std::move(new_false_statements); |
166 | if (false_statements) |
167 | false_statements->parent_stmt = this; |
168 | } |
169 | |
170 | std::unique_ptr<Stmt> IfStmt::clone() const { |
171 | auto new_stmt = std::make_unique<IfStmt>(cond); |
172 | if (true_statements) |
173 | new_stmt->set_true_statements(true_statements->clone()); |
174 | if (false_statements) |
175 | new_stmt->set_false_statements(false_statements->clone()); |
176 | return new_stmt; |
177 | } |
178 | |
179 | RangeForStmt::RangeForStmt(Stmt *begin, |
180 | Stmt *end, |
181 | std::unique_ptr<Block> &&body, |
182 | bool is_bit_vectorized, |
183 | int num_cpu_threads, |
184 | int block_dim, |
185 | bool strictly_serialized, |
186 | std::string range_hint) |
187 | : begin(begin), |
188 | end(end), |
189 | body(std::move(body)), |
190 | is_bit_vectorized(is_bit_vectorized), |
191 | num_cpu_threads(num_cpu_threads), |
192 | block_dim(block_dim), |
193 | strictly_serialized(strictly_serialized), |
194 | range_hint(range_hint) { |
195 | reversed = false; |
196 | this->body->parent_stmt = this; |
197 | TI_STMT_REG_FIELDS; |
198 | } |
199 | |
200 | std::unique_ptr<Stmt> RangeForStmt::clone() const { |
201 | auto new_stmt = std::make_unique<RangeForStmt>( |
202 | begin, end, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim, |
203 | strictly_serialized); |
204 | new_stmt->reversed = reversed; |
205 | return new_stmt; |
206 | } |
207 | |
// A struct-for loop over the active cells of `snode`, with parallelization
// parameters.  Takes ownership of `body` and re-parents it to this statement.
StructForStmt::StructForStmt(SNode *snode,
                             std::unique_ptr<Block> &&body,
                             bool is_bit_vectorized,
                             int num_cpu_threads,
                             int block_dim)
    : snode(snode),
      body(std::move(body)),
      is_bit_vectorized(is_bit_vectorized),
      num_cpu_threads(num_cpu_threads),
      block_dim(block_dim) {
  this->body->parent_stmt = this;
  TI_STMT_REG_FIELDS;
}
221 | |
222 | std::unique_ptr<Stmt> StructForStmt::clone() const { |
223 | auto new_stmt = std::make_unique<StructForStmt>( |
224 | snode, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim); |
225 | new_stmt->mem_access_opt = mem_access_opt; |
226 | return new_stmt; |
227 | } |
228 | |
// A mesh-for loop over elements of type `element_type` in `mesh`, with
// parallelization parameters.  Takes ownership of `body` and re-parents it.
MeshForStmt::MeshForStmt(mesh::Mesh *mesh,
                         mesh::MeshElementType element_type,
                         std::unique_ptr<Block> &&body,
                         bool is_bit_vectorized,
                         int num_cpu_threads,
                         int block_dim)
    : mesh(mesh),
      body(std::move(body)),
      is_bit_vectorized(is_bit_vectorized),
      num_cpu_threads(num_cpu_threads),
      block_dim(block_dim),
      major_from_type(element_type) {
  this->body->parent_stmt = this;
  TI_STMT_REG_FIELDS;
}
244 | |
245 | std::unique_ptr<Stmt> MeshForStmt::clone() const { |
246 | auto new_stmt = std::make_unique<MeshForStmt>( |
247 | mesh, major_from_type, body->clone(), is_bit_vectorized, num_cpu_threads, |
248 | block_dim); |
249 | new_stmt->major_to_types = major_to_types; |
250 | new_stmt->minor_relation_types = minor_relation_types; |
251 | new_stmt->mem_access_opt = mem_access_opt; |
252 | return new_stmt; |
253 | } |
254 | |
// A call to compiled function `func` with argument statements `args`.
FuncCallStmt::FuncCallStmt(Function *func, const std::vector<Stmt *> &args)
    : func(func), args(args) {
  TI_STMT_REG_FIELDS;
}
259 | |
// A while-loop owning `body` (re-parented here).  `mask` starts null;
// presumably a later pass fills it in — set outside this constructor.
WhileStmt::WhileStmt(std::unique_ptr<Block> &&body)
    : mask(nullptr), body(std::move(body)) {
  this->body->parent_stmt = this;
  TI_STMT_REG_FIELDS;
}
265 | |
266 | std::unique_ptr<Stmt> WhileStmt::clone() const { |
267 | auto new_stmt = std::make_unique<WhileStmt>(body->clone()); |
268 | new_stmt->mask = mask; |
269 | return new_stmt; |
270 | } |
271 | |
// Fetches child channel `chid` of the SNode that `input_ptr` looks up.
// `input_ptr` must be a SNodeLookupStmt, from which the input SNode is taken;
// the output SNode is that node's chid-th child.
GetChStmt::GetChStmt(Stmt *input_ptr, int chid, bool is_bit_vectorized)
    : input_ptr(input_ptr), chid(chid), is_bit_vectorized(is_bit_vectorized) {
  TI_ASSERT(input_ptr->is<SNodeLookupStmt>());
  input_snode = input_ptr->as<SNodeLookupStmt>()->snode;
  output_snode = input_snode->ch[chid].get();
  TI_STMT_REG_FIELDS;
}
279 | |
// Same as above, but the input SNode is supplied explicitly, so `input_ptr`
// is not required to be a SNodeLookupStmt.
GetChStmt::GetChStmt(Stmt *input_ptr,
                     SNode *snode,
                     int chid,
                     bool is_bit_vectorized)
    : input_ptr(input_ptr), chid(chid), is_bit_vectorized(is_bit_vectorized) {
  input_snode = snode;
  output_snode = input_snode->ch[chid].get();
  TI_STMT_REG_FIELDS;
}
289 | |
// An offloaded task of kind `task_type` targeting backend `arch`.
// Task types that carry a body (per has_body()) get an empty Block allocated
// and parented here; the other child blocks (tls/bls/mesh) stay null until
// later passes create them.
OffloadedStmt::OffloadedStmt(TaskType task_type, Arch arch)
    : task_type(task_type), device(arch) {
  if (has_body()) {
    body = std::make_unique<Block>();
    body->parent_stmt = this;
  }
  TI_STMT_REG_FIELDS;
}
298 | |
299 | std::string OffloadedStmt::task_name() const { |
300 | if (task_type == TaskType::serial) { |
301 | return "serial" ; |
302 | } else if (task_type == TaskType::range_for) { |
303 | return "range_for" ; |
304 | } else if (task_type == TaskType::struct_for) { |
305 | return "struct_for" ; |
306 | } else if (task_type == TaskType::mesh_for) { |
307 | return "mesh_for" ; |
308 | } else if (task_type == TaskType::listgen) { |
309 | TI_ASSERT(snode); |
310 | return fmt::format("listgen_{}" , snode->get_node_type_name_hinted()); |
311 | } else if (task_type == TaskType::gc) { |
312 | TI_ASSERT(snode); |
313 | return fmt::format("gc_{}" , snode->name); |
314 | } else if (task_type == TaskType::gc_rc) { |
315 | return fmt::format("gc_rc" ); |
316 | } else { |
317 | TI_NOT_IMPLEMENTED |
318 | } |
319 | } |
320 | |
// static
// Name of an arbitrary task type; thin wrapper over the free function
// offloaded_task_type_name().
std::string OffloadedStmt::task_type_name(TaskType tt) {
  return offloaded_task_type_name(tt);
}
325 | |
326 | std::unique_ptr<Stmt> OffloadedStmt::clone() const { |
327 | auto new_stmt = std::make_unique<OffloadedStmt>(task_type, device); |
328 | new_stmt->snode = snode; |
329 | new_stmt->begin_offset = begin_offset; |
330 | new_stmt->end_offset = end_offset; |
331 | new_stmt->const_begin = const_begin; |
332 | new_stmt->const_end = const_end; |
333 | new_stmt->begin_value = begin_value; |
334 | new_stmt->end_value = end_value; |
335 | new_stmt->grid_dim = grid_dim; |
336 | new_stmt->block_dim = block_dim; |
337 | new_stmt->reversed = reversed; |
338 | new_stmt->is_bit_vectorized = is_bit_vectorized; |
339 | new_stmt->num_cpu_threads = num_cpu_threads; |
340 | new_stmt->index_offsets = index_offsets; |
341 | |
342 | new_stmt->mesh = mesh; |
343 | new_stmt->major_from_type = major_from_type; |
344 | new_stmt->major_to_types = major_to_types; |
345 | new_stmt->minor_relation_types = minor_relation_types; |
346 | |
347 | new_stmt->owned_offset_local = owned_offset_local; |
348 | new_stmt->total_offset_local = total_offset_local; |
349 | new_stmt->owned_num_local = owned_num_local; |
350 | new_stmt->total_num_local = total_num_local; |
351 | |
352 | if (tls_prologue) { |
353 | new_stmt->tls_prologue = tls_prologue->clone(); |
354 | new_stmt->tls_prologue->parent_stmt = new_stmt.get(); |
355 | } |
356 | if (mesh_prologue) { |
357 | new_stmt->mesh_prologue = mesh_prologue->clone(); |
358 | new_stmt->mesh_prologue->parent_stmt = new_stmt.get(); |
359 | } |
360 | if (bls_prologue) { |
361 | new_stmt->bls_prologue = bls_prologue->clone(); |
362 | new_stmt->bls_prologue->parent_stmt = new_stmt.get(); |
363 | } |
364 | if (body) { |
365 | new_stmt->body = body->clone(); |
366 | new_stmt->body->parent_stmt = new_stmt.get(); |
367 | } |
368 | if (bls_epilogue) { |
369 | new_stmt->bls_epilogue = bls_epilogue->clone(); |
370 | new_stmt->bls_epilogue->parent_stmt = new_stmt.get(); |
371 | } |
372 | if (tls_epilogue) { |
373 | new_stmt->tls_epilogue = tls_epilogue->clone(); |
374 | new_stmt->tls_epilogue->parent_stmt = new_stmt.get(); |
375 | } |
376 | new_stmt->tls_size = tls_size; |
377 | new_stmt->bls_size = bls_size; |
378 | new_stmt->mem_access_opt = mem_access_opt; |
379 | return new_stmt; |
380 | } |
381 | |
382 | void OffloadedStmt::all_blocks_accept(IRVisitor *visitor, |
383 | bool skip_mesh_prologue) { |
384 | if (tls_prologue) |
385 | tls_prologue->accept(visitor); |
386 | if (mesh_prologue && !skip_mesh_prologue) |
387 | mesh_prologue->accept(visitor); |
388 | if (bls_prologue) |
389 | bls_prologue->accept(visitor); |
390 | if (body) |
391 | body->accept(visitor); |
392 | if (bls_epilogue) |
393 | bls_epilogue->accept(visitor); |
394 | if (tls_epilogue) |
395 | tls_epilogue->accept(visitor); |
396 | } |
397 | |
398 | bool is_clear_list_task(const OffloadedStmt *stmt) { |
399 | return (stmt->task_type == OffloadedStmt::TaskType::serial) && |
400 | (stmt->body->size() == 1) && stmt->body->back()->is<ClearListStmt>(); |
401 | } |
402 | |
// Clears the element list associated with `snode`.
ClearListStmt::ClearListStmt(SNode *snode) : snode(snode) {
  TI_STMT_REG_FIELDS;
}
406 | |
// The BitStructType stored by the SNode that `ptr` looks up.
// Assumes `ptr` is a SNodeLookupStmt whose SNode's data type is a bit struct
// (the as<> casts enforce this).
BitStructType *BitStructStoreStmt::get_bit_struct() const {
  return ptr->as<SNodeLookupStmt>()->snode->dt->as<BitStructType>();
}
410 | |
411 | } // namespace taichi::lang |
412 | |